1420 lines
54 KiB
Python
Executable file
1420 lines
54 KiB
Python
Executable file
from collections import defaultdict
|
|
from io import open
|
|
import os
|
|
import sys
|
|
import traceback
|
|
import logging
|
|
import threading
|
|
import re
|
|
import stat
|
|
import fileinput
|
|
from collections import namedtuple
|
|
try:
|
|
from pwd import getpwuid
|
|
except ImportError:
|
|
pass
|
|
from time import time
|
|
from datetime import datetime
|
|
from random import choice
|
|
|
|
from pymysql import OperationalError
|
|
from cli_helpers.tabular_output import TabularOutputFormatter
|
|
from cli_helpers.tabular_output import preprocessors
|
|
from cli_helpers.utils import strip_ansi
|
|
import click
|
|
import sqlparse
|
|
from mycli.packages.parseutils import is_dropping_database, is_destructive
|
|
from prompt_toolkit.completion import DynamicCompleter
|
|
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
|
|
from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register
|
|
from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
|
|
from prompt_toolkit.document import Document
|
|
from prompt_toolkit.filters import HasFocus, IsDone
|
|
from prompt_toolkit.formatted_text import ANSI
|
|
from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor,
|
|
ConditionalProcessor)
|
|
from prompt_toolkit.lexers import PygmentsLexer
|
|
from prompt_toolkit.history import FileHistory
|
|
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
|
|
|
from .packages.special.main import NO_QUERY
|
|
from .packages.prompt_utils import confirm, confirm_destructive_query
|
|
from .packages.tabular_output import sql_format
|
|
from .packages import special
|
|
from .packages.special.favoritequeries import FavoriteQueries
|
|
from .sqlcompleter import SQLCompleter
|
|
from .clitoolbar import create_toolbar_tokens_func
|
|
from .clistyle import style_factory, style_factory_output
|
|
from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
|
|
from .clibuffer import cli_is_multiline
|
|
from .completion_refresher import CompletionRefresher
|
|
from .config import (write_default_config, get_mylogin_cnf_path,
|
|
open_mylogin_cnf, read_config_files, str_to_bool,
|
|
strip_matching_quotes)
|
|
from .key_bindings import mycli_bindings
|
|
from .lexer import MyCliLexer
|
|
from . import __version__
|
|
from .compat import WIN
|
|
from .packages.filepaths import dir_path_exists, guess_socket_location
|
|
|
|
import itertools
|
|
|
|
click.disable_unicode_literals_warning = True
|
|
|
|
try:
|
|
from urlparse import urlparse
|
|
from urlparse import unquote
|
|
except ImportError:
|
|
from urllib.parse import urlparse
|
|
from urllib.parse import unquote
|
|
|
|
try:
|
|
import importlib.resources as resources
|
|
except ImportError:
|
|
# Python < 3.7
|
|
import importlib_resources as resources
|
|
|
|
try:
|
|
import paramiko
|
|
except ImportError:
|
|
from mycli.packages.paramiko_stub import paramiko
|
|
|
|
# Query tuples are used for maintaining history
|
|
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
|
|
|
|
SUPPORT_INFO = (
|
|
'Home: http://mycli.net\n'
|
|
'Bug tracker: https://github.com/dbcli/mycli/issues'
|
|
)
|
|
|
|
|
|
class MyCli(object):
|
|
|
|
default_prompt = '\\t \\u@\\h:\\d> '
|
|
max_len_prompt = 45
|
|
defaults_suffix = None
|
|
|
|
# In order of being loaded. Files lower in list override earlier ones.
|
|
cnf_files = [
|
|
'/etc/my.cnf',
|
|
'/etc/mysql/my.cnf',
|
|
'/usr/local/etc/my.cnf',
|
|
os.path.expanduser('~/.my.cnf'),
|
|
]
|
|
|
|
# check XDG_CONFIG_HOME exists and not an empty string
|
|
if os.environ.get("XDG_CONFIG_HOME"):
|
|
xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
|
|
else:
|
|
xdg_config_home = "~/.config"
|
|
system_config_files = [
|
|
'/etc/myclirc',
|
|
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
|
|
]
|
|
|
|
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
|
|
|
|
def __init__(self, sqlexecute=None, prompt=None,
|
|
logfile=None, defaults_suffix=None, defaults_file=None,
|
|
login_path=None, auto_vertical_output=False, warn=None,
|
|
myclirc="~/.myclirc"):
|
|
self.sqlexecute = sqlexecute
|
|
self.logfile = logfile
|
|
self.defaults_suffix = defaults_suffix
|
|
self.login_path = login_path
|
|
|
|
# self.cnf_files is a class variable that stores the list of mysql
|
|
# config files to read in at launch.
|
|
# If defaults_file is specified then override the class variable with
|
|
# defaults_file.
|
|
if defaults_file:
|
|
self.cnf_files = [defaults_file]
|
|
|
|
# Load config.
|
|
config_files = (self.system_config_files +
|
|
[myclirc] + [self.pwd_config_file])
|
|
c = self.config = read_config_files(config_files)
|
|
self.multi_line = c['main'].as_bool('multi_line')
|
|
self.key_bindings = c['main']['key_bindings']
|
|
special.set_timing_enabled(c['main'].as_bool('timing'))
|
|
|
|
FavoriteQueries.instance = FavoriteQueries.from_config(self.config)
|
|
|
|
self.dsn_alias = None
|
|
self.formatter = TabularOutputFormatter(
|
|
format_name=c['main']['table_format'])
|
|
sql_format.register_new_formatter(self.formatter)
|
|
self.formatter.mycli = self
|
|
self.syntax_style = c['main']['syntax_style']
|
|
self.less_chatty = c['main'].as_bool('less_chatty')
|
|
self.cli_style = c['colors']
|
|
self.output_style = style_factory_output(
|
|
self.syntax_style,
|
|
self.cli_style
|
|
)
|
|
self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
|
|
c_dest_warning = c['main'].as_bool('destructive_warning')
|
|
self.destructive_warning = c_dest_warning if warn is None else warn
|
|
self.login_path_as_host = c['main'].as_bool('login_path_as_host')
|
|
|
|
# read from cli argument or user config file
|
|
self.auto_vertical_output = auto_vertical_output or \
|
|
c['main'].as_bool('auto_vertical_output')
|
|
|
|
# Write user config if system config wasn't the last config loaded.
|
|
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
|
|
write_default_config(myclirc)
|
|
|
|
# audit log
|
|
if self.logfile is None and 'audit_log' in c['main']:
|
|
try:
|
|
self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
|
|
except (IOError, OSError) as e:
|
|
self.echo('Error: Unable to open the audit log file. Your queries will not be logged.',
|
|
err=True, fg='red')
|
|
self.logfile = False
|
|
|
|
self.completion_refresher = CompletionRefresher()
|
|
|
|
self.logger = logging.getLogger(__name__)
|
|
self.initialize_logging()
|
|
|
|
prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
|
|
self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
|
|
self.default_prompt
|
|
self.multiline_continuation_char = c['main']['prompt_continuation']
|
|
keyword_casing = c['main'].get('keyword_casing', 'auto')
|
|
|
|
self.query_history = []
|
|
|
|
# Initialize completer.
|
|
self.smart_completion = c['main'].as_bool('smart_completion')
|
|
self.completer = SQLCompleter(
|
|
self.smart_completion,
|
|
supported_formats=self.formatter.supported_formats,
|
|
keyword_casing=keyword_casing)
|
|
self._completer_lock = threading.Lock()
|
|
|
|
# Register custom special commands.
|
|
self.register_special_commands()
|
|
|
|
# Load .mylogin.cnf if it exists.
|
|
mylogin_cnf_path = get_mylogin_cnf_path()
|
|
if mylogin_cnf_path:
|
|
mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
|
|
if mylogin_cnf_path and mylogin_cnf:
|
|
# .mylogin.cnf gets read last, even if defaults_file is specified.
|
|
self.cnf_files.append(mylogin_cnf)
|
|
elif mylogin_cnf_path and not mylogin_cnf:
|
|
# There was an error reading the login path file.
|
|
print('Error: Unable to read login path file.')
|
|
|
|
self.prompt_app = None
|
|
|
|
def register_special_commands(self):
|
|
special.register_special_command(self.change_db, 'use',
|
|
'\\u', 'Change to a new database.', aliases=('\\u',))
|
|
special.register_special_command(self.change_db, 'connect',
|
|
'\\r', 'Reconnect to the database. Optional database argument.',
|
|
aliases=('\\r', ), case_sensitive=True)
|
|
special.register_special_command(self.refresh_completions, 'rehash',
|
|
'\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
|
|
special.register_special_command(
|
|
self.change_table_format, 'tableformat', '\\T',
|
|
'Change the table format used to output results.',
|
|
aliases=('\\T',), case_sensitive=True)
|
|
special.register_special_command(self.execute_from_file, 'source', '\\. filename',
|
|
'Execute commands from file.', aliases=('\\.',))
|
|
special.register_special_command(self.change_prompt_format, 'prompt',
|
|
'\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)
|
|
|
|
def change_table_format(self, arg, **_):
|
|
try:
|
|
self.formatter.format_name = arg
|
|
yield (None, None, None,
|
|
'Changed table format to {}'.format(arg))
|
|
except ValueError:
|
|
msg = 'Table format {} not recognized. Allowed formats:'.format(
|
|
arg)
|
|
for table_type in self.formatter.supported_formats:
|
|
msg += "\n\t{}".format(table_type)
|
|
yield (None, None, None, msg)
|
|
|
|
def change_db(self, arg, **_):
|
|
if not arg:
|
|
click.secho(
|
|
"No database selected",
|
|
err=True, fg="red"
|
|
)
|
|
return
|
|
|
|
if arg.startswith('`') and arg.endswith('`'):
|
|
arg = re.sub(r'^`(.*)`$', r'\1', arg)
|
|
arg = re.sub(r'``', r'`', arg)
|
|
self.sqlexecute.change_db(arg)
|
|
|
|
yield (None, None, None, 'You are now connected to database "%s" as '
|
|
'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))
|
|
|
|
def execute_from_file(self, arg, **_):
|
|
if not arg:
|
|
message = 'Missing required argument, filename.'
|
|
return [(None, None, None, message)]
|
|
try:
|
|
with open(os.path.expanduser(arg)) as f:
|
|
query = f.read()
|
|
except IOError as e:
|
|
return [(None, None, None, str(e))]
|
|
|
|
if (self.destructive_warning and
|
|
confirm_destructive_query(query) is False):
|
|
message = 'Wise choice. Command execution stopped.'
|
|
return [(None, None, None, message)]
|
|
|
|
return self.sqlexecute.run(query)
|
|
|
|
def change_prompt_format(self, arg, **_):
|
|
"""
|
|
Change the prompt format.
|
|
"""
|
|
if not arg:
|
|
message = 'Missing required argument, format.'
|
|
return [(None, None, None, message)]
|
|
|
|
self.prompt_format = self.get_prompt(arg)
|
|
return [(None, None, None, "Changed prompt format to %s" % arg)]
|
|
|
|
def initialize_logging(self):
|
|
|
|
log_file = os.path.expanduser(self.config['main']['log_file'])
|
|
log_level = self.config['main']['log_level']
|
|
|
|
level_map = {'CRITICAL': logging.CRITICAL,
|
|
'ERROR': logging.ERROR,
|
|
'WARNING': logging.WARNING,
|
|
'INFO': logging.INFO,
|
|
'DEBUG': logging.DEBUG
|
|
}
|
|
|
|
# Disable logging if value is NONE by switching to a no-op handler
|
|
# Set log level to a high value so it doesn't even waste cycles getting called.
|
|
if log_level.upper() == "NONE":
|
|
handler = logging.NullHandler()
|
|
log_level = "CRITICAL"
|
|
elif dir_path_exists(log_file):
|
|
handler = logging.FileHandler(log_file)
|
|
else:
|
|
self.echo(
|
|
'Error: Unable to open the log file "{}".'.format(log_file),
|
|
err=True, fg='red')
|
|
return
|
|
|
|
formatter = logging.Formatter(
|
|
'%(asctime)s (%(process)d/%(threadName)s) '
|
|
'%(name)s %(levelname)s - %(message)s')
|
|
|
|
handler.setFormatter(formatter)
|
|
|
|
root_logger = logging.getLogger('mycli')
|
|
root_logger.addHandler(handler)
|
|
root_logger.setLevel(level_map[log_level.upper()])
|
|
|
|
logging.captureWarnings(True)
|
|
|
|
root_logger.debug('Initializing mycli logging.')
|
|
root_logger.debug('Log file %r.', log_file)
|
|
|
|
|
|
def read_my_cnf_files(self, files, keys):
|
|
"""
|
|
Reads a list of config files and merges them. The last one will win.
|
|
:param files: list of files to read
|
|
:param keys: list of keys to retrieve
|
|
:returns: tuple, with None for missing keys.
|
|
"""
|
|
cnf = read_config_files(files, list_values=False)
|
|
|
|
sections = ['client', 'mysqld']
|
|
key_transformations = {
|
|
'mysqld': {
|
|
'socket': 'default_socket',
|
|
'port': 'default_port',
|
|
},
|
|
}
|
|
|
|
if self.login_path and self.login_path != 'client':
|
|
sections.append(self.login_path)
|
|
|
|
if self.defaults_suffix:
|
|
sections.extend([sect + self.defaults_suffix for sect in sections])
|
|
|
|
configuration = defaultdict(lambda: None)
|
|
for key in keys:
|
|
for section in cnf:
|
|
if (
|
|
section not in sections or
|
|
key not in cnf[section]
|
|
):
|
|
continue
|
|
new_key = key_transformations.get(section, {}).get(key) or key
|
|
configuration[new_key] = strip_matching_quotes(
|
|
cnf[section][key])
|
|
|
|
return configuration
|
|
|
|
|
|
def merge_ssl_with_cnf(self, ssl, cnf):
|
|
"""Merge SSL configuration dict with cnf dict"""
|
|
|
|
merged = {}
|
|
merged.update(ssl)
|
|
prefix = 'ssl-'
|
|
for k, v in cnf.items():
|
|
# skip unrelated options
|
|
if not k.startswith(prefix):
|
|
continue
|
|
if v is None:
|
|
continue
|
|
# special case because PyMySQL argument is significantly different
|
|
# from commandline
|
|
if k == 'ssl-verify-server-cert':
|
|
merged['check_hostname'] = v
|
|
else:
|
|
# use argument name just strip "ssl-" prefix
|
|
arg = k[len(prefix):]
|
|
merged[arg] = v
|
|
|
|
return merged
|
|
|
|
def connect(self, database='', user='', passwd='', host='', port='',
|
|
socket='', charset='', local_infile='', ssl='',
|
|
ssh_user='', ssh_host='', ssh_port='',
|
|
ssh_password='', ssh_key_filename='', init_command='', password_file=''):
|
|
|
|
cnf = {'database': None,
|
|
'user': None,
|
|
'password': None,
|
|
'host': None,
|
|
'port': None,
|
|
'socket': None,
|
|
'default_socket': None,
|
|
'default-character-set': None,
|
|
'local-infile': None,
|
|
'loose-local-infile': None,
|
|
'ssl-ca': None,
|
|
'ssl-cert': None,
|
|
'ssl-key': None,
|
|
'ssl-cipher': None,
|
|
'ssl-verify-serer-cert': None,
|
|
}
|
|
|
|
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
|
|
|
|
# Fall back to config values only if user did not specify a value.
|
|
database = database or cnf['database']
|
|
user = user or cnf['user'] or os.getenv('USER')
|
|
host = host or cnf['host']
|
|
port = port or cnf['port']
|
|
ssl = ssl or {}
|
|
|
|
port = port and int(port)
|
|
if not port:
|
|
port = 3306
|
|
if not host or host == 'localhost':
|
|
socket = (
|
|
cnf['socket'] or
|
|
cnf['default_socket'] or
|
|
guess_socket_location()
|
|
)
|
|
|
|
|
|
passwd = passwd if isinstance(passwd, str) else cnf['password']
|
|
charset = charset or cnf['default-character-set'] or 'utf8'
|
|
|
|
# Favor whichever local_infile option is set.
|
|
for local_infile_option in (local_infile, cnf['local-infile'],
|
|
cnf['loose-local-infile'], False):
|
|
try:
|
|
local_infile = str_to_bool(local_infile_option)
|
|
break
|
|
except (TypeError, ValueError):
|
|
pass
|
|
|
|
ssl = self.merge_ssl_with_cnf(ssl, cnf)
|
|
# prune lone check_hostname=False
|
|
if not any(v for v in ssl.values()):
|
|
ssl = None
|
|
|
|
# if the passwd is not specfied try to set it using the password_file option
|
|
password_from_file = self.get_password_from_file(password_file)
|
|
passwd = passwd or password_from_file
|
|
|
|
# Connect to the database.
|
|
|
|
def _connect():
|
|
try:
|
|
self.sqlexecute = SQLExecute(
|
|
database, user, passwd, host, port, socket, charset,
|
|
local_infile, ssl, ssh_user, ssh_host, ssh_port,
|
|
ssh_password, ssh_key_filename, init_command
|
|
)
|
|
except OperationalError as e:
|
|
if e.args[0] == ERROR_CODE_ACCESS_DENIED:
|
|
if password_from_file:
|
|
new_passwd = password_from_file
|
|
else:
|
|
new_passwd = click.prompt('Password', hide_input=True,
|
|
show_default=False, type=str, err=True)
|
|
self.sqlexecute = SQLExecute(
|
|
database, user, new_passwd, host, port, socket,
|
|
charset, local_infile, ssl, ssh_user, ssh_host,
|
|
ssh_port, ssh_password, ssh_key_filename, init_command
|
|
)
|
|
else:
|
|
raise e
|
|
|
|
try:
|
|
if not WIN and socket:
|
|
socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
|
|
self.echo(
|
|
f"Connecting to socket {socket}, owned by user {socket_owner}", err=True)
|
|
try:
|
|
_connect()
|
|
except OperationalError as e:
|
|
# These are "Can't open socket" and 2x "Can't connect"
|
|
if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
|
|
self.logger.debug('Database connection failed: %r.', e)
|
|
self.logger.error(
|
|
"traceback: %r", traceback.format_exc())
|
|
self.logger.debug('Retrying over TCP/IP')
|
|
self.echo(
|
|
"Failed to connect to local MySQL server through socket '{}':".format(socket))
|
|
self.echo(str(e), err=True)
|
|
self.echo(
|
|
'Retrying over TCP/IP', err=True)
|
|
|
|
# Else fall back to TCP/IP localhost
|
|
socket = ""
|
|
host = 'localhost'
|
|
port = 3306
|
|
_connect()
|
|
else:
|
|
raise e
|
|
else:
|
|
host = host or 'localhost'
|
|
port = port or 3306
|
|
|
|
# Bad ports give particularly daft error messages
|
|
try:
|
|
port = int(port)
|
|
except ValueError as e:
|
|
self.echo("Error: Invalid port number: '{0}'.".format(port),
|
|
err=True, fg='red')
|
|
exit(1)
|
|
|
|
_connect()
|
|
except Exception as e: # Connecting to a database could fail.
|
|
self.logger.debug('Database connection failed: %r.', e)
|
|
self.logger.error("traceback: %r", traceback.format_exc())
|
|
self.echo(str(e), err=True, fg='red')
|
|
exit(1)
|
|
|
|
def get_password_from_file(self, password_file):
|
|
password_from_file = None
|
|
if password_file:
|
|
if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
|
|
and os.access(password_file, os.R_OK):
|
|
with open(password_file) as fp:
|
|
password_from_file = fp.readline()
|
|
password_from_file = password_from_file.rstrip().lstrip()
|
|
|
|
return password_from_file
|
|
|
|
def handle_editor_command(self, text):
|
|
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
|
|
The reason for a while loop is because a user might edit a query
|
|
multiple times. For eg:
|
|
|
|
"select * from \e"<enter> to edit it in vim, then come
|
|
back to the prompt with the edited query "select * from
|
|
blah where q = 'abc'\e" to edit it again.
|
|
:param text: Document
|
|
:return: Document
|
|
|
|
"""
|
|
|
|
while special.editor_command(text):
|
|
filename = special.get_filename(text)
|
|
query = (special.get_editor_query(text) or
|
|
self.get_last_query())
|
|
sql, message = special.open_external_editor(filename, sql=query)
|
|
if message:
|
|
# Something went wrong. Raise an exception and bail.
|
|
raise RuntimeError(message)
|
|
while True:
|
|
try:
|
|
text = self.prompt_app.prompt(default=sql)
|
|
break
|
|
except KeyboardInterrupt:
|
|
sql = ""
|
|
|
|
continue
|
|
return text
|
|
|
|
def handle_clip_command(self, text):
|
|
r"""A clip command is any query that is prefixed or suffixed by a
|
|
'\clip'.
|
|
|
|
:param text: Document
|
|
:return: Boolean
|
|
|
|
"""
|
|
|
|
if special.clip_command(text):
|
|
query = (special.get_clip_query(text) or
|
|
self.get_last_query())
|
|
message = special.copy_query_to_clipboard(sql=query)
|
|
if message:
|
|
raise RuntimeError(message)
|
|
return True
|
|
return False
|
|
|
|
def run_cli(self):
|
|
iterations = 0
|
|
sqlexecute = self.sqlexecute
|
|
logger = self.logger
|
|
self.configure_pager()
|
|
|
|
if self.smart_completion:
|
|
self.refresh_completions()
|
|
|
|
history_file = os.path.expanduser(
|
|
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
|
|
if dir_path_exists(history_file):
|
|
history = FileHistory(history_file)
|
|
else:
|
|
history = None
|
|
self.echo(
|
|
'Error: Unable to open the history file "{}". '
|
|
'Your query history will not be saved.'.format(history_file),
|
|
err=True, fg='red')
|
|
|
|
key_bindings = mycli_bindings(self)
|
|
|
|
if not self.less_chatty:
|
|
print(sqlexecute.server_info)
|
|
print('mycli', __version__)
|
|
print(SUPPORT_INFO)
|
|
print('Thanks to the contributor -', thanks_picker())
|
|
|
|
def get_message():
|
|
prompt = self.get_prompt(self.prompt_format)
|
|
if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
|
|
prompt = self.get_prompt('\\d> ')
|
|
prompt = prompt.replace("\\x1b", "\x1b")
|
|
return ANSI(prompt)
|
|
|
|
def get_continuation(width, *_):
|
|
if self.multiline_continuation_char == '':
|
|
continuation = ''
|
|
elif self.multiline_continuation_char:
|
|
left_padding = width - len(self.multiline_continuation_char)
|
|
continuation = " " * \
|
|
max((left_padding - 1), 0) + \
|
|
self.multiline_continuation_char + " "
|
|
else:
|
|
continuation = " "
|
|
return [('class:continuation', continuation)]
|
|
|
|
def show_suggestion_tip():
|
|
return iterations < 2
|
|
|
|
def one_iteration(text=None):
|
|
if text is None:
|
|
try:
|
|
text = self.prompt_app.prompt()
|
|
except KeyboardInterrupt:
|
|
return
|
|
|
|
special.set_expanded_output(False)
|
|
|
|
try:
|
|
text = self.handle_editor_command(text)
|
|
except RuntimeError as e:
|
|
logger.error("sql: %r, error: %r", text, e)
|
|
logger.error("traceback: %r", traceback.format_exc())
|
|
self.echo(str(e), err=True, fg='red')
|
|
return
|
|
|
|
try:
|
|
if self.handle_clip_command(text):
|
|
return
|
|
except RuntimeError as e:
|
|
logger.error("sql: %r, error: %r", text, e)
|
|
logger.error("traceback: %r", traceback.format_exc())
|
|
self.echo(str(e), err=True, fg='red')
|
|
return
|
|
|
|
if not text.strip():
|
|
return
|
|
|
|
if self.destructive_warning:
|
|
destroy = confirm_destructive_query(text)
|
|
if destroy is None:
|
|
pass # Query was not destructive. Nothing to do here.
|
|
elif destroy is True:
|
|
self.echo('Your call!')
|
|
else:
|
|
self.echo('Wise choice!')
|
|
return
|
|
else:
|
|
destroy = True
|
|
|
|
# Keep track of whether or not the query is mutating. In case
|
|
# of a multi-statement query, the overall query is considered
|
|
# mutating if any one of the component statements is mutating
|
|
mutating = False
|
|
|
|
try:
|
|
logger.debug('sql: %r', text)
|
|
|
|
special.write_tee(self.get_prompt(self.prompt_format) + text)
|
|
if self.logfile:
|
|
self.logfile.write('\n# %s\n' % datetime.now())
|
|
self.logfile.write(text)
|
|
self.logfile.write('\n')
|
|
|
|
successful = False
|
|
start = time()
|
|
res = sqlexecute.run(text)
|
|
self.formatter.query = text
|
|
successful = True
|
|
result_count = 0
|
|
for title, cur, headers, status in res:
|
|
logger.debug("headers: %r", headers)
|
|
logger.debug("rows: %r", cur)
|
|
logger.debug("status: %r", status)
|
|
threshold = 1000
|
|
if (is_select(status) and
|
|
cur and cur.rowcount > threshold):
|
|
self.echo('The result set has more than {} rows.'.format(
|
|
threshold), fg='red')
|
|
if not confirm('Do you want to continue?'):
|
|
self.echo("Aborted!", err=True, fg='red')
|
|
break
|
|
|
|
if self.auto_vertical_output:
|
|
max_width = self.prompt_app.output.get_size().columns
|
|
else:
|
|
max_width = None
|
|
|
|
formatted = self.format_output(
|
|
title, cur, headers, special.is_expanded_output(),
|
|
max_width)
|
|
|
|
t = time() - start
|
|
try:
|
|
if result_count > 0:
|
|
self.echo('')
|
|
try:
|
|
self.output(formatted, status)
|
|
except KeyboardInterrupt:
|
|
pass
|
|
if special.is_timing_enabled():
|
|
self.echo('Time: %0.03fs' % t)
|
|
except KeyboardInterrupt:
|
|
pass
|
|
|
|
start = time()
|
|
result_count += 1
|
|
mutating = mutating or destroy or is_mutating(status)
|
|
special.unset_once_if_written()
|
|
special.unset_pipe_once_if_written()
|
|
except EOFError as e:
|
|
raise e
|
|
except KeyboardInterrupt:
|
|
# get last connection id
|
|
connection_id_to_kill = sqlexecute.connection_id
|
|
logger.debug("connection id to kill: %r", connection_id_to_kill)
|
|
# Restart connection to the database
|
|
sqlexecute.connect()
|
|
try:
|
|
for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill):
|
|
status_str = str(status).lower()
|
|
if status_str.find('ok') > -1:
|
|
logger.debug("cancelled query, connection id: %r, sql: %r",
|
|
connection_id_to_kill, text)
|
|
self.echo("cancelled query", err=True, fg='red')
|
|
except Exception as e:
|
|
self.echo('Encountered error while cancelling query: {}'.format(e),
|
|
err=True, fg='red')
|
|
except NotImplementedError:
|
|
self.echo('Not Yet Implemented.', fg="yellow")
|
|
except OperationalError as e:
|
|
logger.debug("Exception: %r", e)
|
|
if (e.args[0] in (2003, 2006, 2013)):
|
|
logger.debug('Attempting to reconnect.')
|
|
self.echo('Reconnecting...', fg='yellow')
|
|
try:
|
|
sqlexecute.connect()
|
|
logger.debug('Reconnected successfully.')
|
|
one_iteration(text)
|
|
return # OK to just return, cuz the recursion call runs to the end.
|
|
except OperationalError as e:
|
|
logger.debug('Reconnect failed. e: %r', e)
|
|
self.echo(str(e), err=True, fg='red')
|
|
# If reconnection failed, don't proceed further.
|
|
return
|
|
else:
|
|
logger.error("sql: %r, error: %r", text, e)
|
|
logger.error("traceback: %r", traceback.format_exc())
|
|
self.echo(str(e), err=True, fg='red')
|
|
except Exception as e:
|
|
logger.error("sql: %r, error: %r", text, e)
|
|
logger.error("traceback: %r", traceback.format_exc())
|
|
self.echo(str(e), err=True, fg='red')
|
|
else:
|
|
if is_dropping_database(text, self.sqlexecute.dbname):
|
|
self.sqlexecute.dbname = None
|
|
self.sqlexecute.connect()
|
|
|
|
# Refresh the table names and column names if necessary.
|
|
if need_completion_refresh(text):
|
|
self.refresh_completions(
|
|
reset=need_completion_reset(text))
|
|
finally:
|
|
if self.logfile is False:
|
|
self.echo("Warning: This query was not logged.",
|
|
err=True, fg='red')
|
|
query = Query(text, successful, mutating)
|
|
self.query_history.append(query)
|
|
|
|
get_toolbar_tokens = create_toolbar_tokens_func(
|
|
self, show_suggestion_tip)
|
|
if self.wider_completion_menu:
|
|
complete_style = CompleteStyle.MULTI_COLUMN
|
|
else:
|
|
complete_style = CompleteStyle.COLUMN
|
|
|
|
with self._completer_lock:
|
|
|
|
if self.key_bindings == 'vi':
|
|
editing_mode = EditingMode.VI
|
|
else:
|
|
editing_mode = EditingMode.EMACS
|
|
|
|
self.prompt_app = PromptSession(
|
|
lexer=PygmentsLexer(MyCliLexer),
|
|
reserve_space_for_menu=self.get_reserved_space(),
|
|
message=get_message,
|
|
prompt_continuation=get_continuation,
|
|
bottom_toolbar=get_toolbar_tokens,
|
|
complete_style=complete_style,
|
|
input_processors=[ConditionalProcessor(
|
|
processor=HighlightMatchingBracketProcessor(
|
|
chars='[](){}'),
|
|
filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()
|
|
)],
|
|
tempfile_suffix='.sql',
|
|
completer=DynamicCompleter(lambda: self.completer),
|
|
history=history,
|
|
auto_suggest=AutoSuggestFromHistory(),
|
|
complete_while_typing=True,
|
|
multiline=cli_is_multiline(self),
|
|
style=style_factory(self.syntax_style, self.cli_style),
|
|
include_default_pygments_style=False,
|
|
key_bindings=key_bindings,
|
|
enable_open_in_editor=True,
|
|
enable_system_prompt=True,
|
|
enable_suspend=True,
|
|
editing_mode=editing_mode,
|
|
search_ignore_case=True
|
|
)
|
|
|
|
try:
|
|
while True:
|
|
one_iteration()
|
|
iterations += 1
|
|
except EOFError:
|
|
special.close_tee()
|
|
if not self.less_chatty:
|
|
self.echo('Goodbye!')
|
|
|
|
def log_output(self, output):
|
|
"""Log the output in the audit log, if it's enabled."""
|
|
if self.logfile:
|
|
click.echo(output, file=self.logfile)
|
|
|
|
def echo(self, s, **kwargs):
|
|
"""Print a message to stdout.
|
|
|
|
The message will be logged in the audit log, if enabled.
|
|
|
|
All keyword arguments are passed to click.echo().
|
|
|
|
"""
|
|
self.log_output(s)
|
|
click.secho(s, **kwargs)
|
|
|
|
def get_output_margin(self, status=None):
|
|
"""Get the output margin (number of rows for the prompt, footer and
|
|
timing message."""
|
|
margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1
|
|
if special.is_timing_enabled():
|
|
margin += 1
|
|
if status:
|
|
margin += 1 + status.count('\n')
|
|
|
|
return margin
|
|
|
|
|
|
def output(self, output, status=None):
|
|
"""Output text to stdout or a pager command.
|
|
|
|
The status text is not outputted to pager or files.
|
|
|
|
The message will be logged in the audit log, if enabled. The
|
|
message will be written to the tee file, if enabled. The
|
|
message will be written to the output file, if enabled.
|
|
|
|
"""
|
|
if output:
|
|
size = self.prompt_app.output.get_size()
|
|
|
|
margin = self.get_output_margin(status)
|
|
|
|
fits = True
|
|
buf = []
|
|
output_via_pager = self.explicit_pager and special.is_pager_enabled()
|
|
for i, line in enumerate(output, 1):
|
|
self.log_output(line)
|
|
special.write_tee(line)
|
|
special.write_once(line)
|
|
special.write_pipe_once(line)
|
|
|
|
if fits or output_via_pager:
|
|
# buffering
|
|
buf.append(line)
|
|
if len(line) > size.columns or i > (size.rows - margin):
|
|
fits = False
|
|
if not self.explicit_pager and special.is_pager_enabled():
|
|
# doesn't fit, use pager
|
|
output_via_pager = True
|
|
|
|
if not output_via_pager:
|
|
# doesn't fit, flush buffer
|
|
for buf_line in buf:
|
|
click.secho(buf_line)
|
|
buf = []
|
|
else:
|
|
click.secho(line)
|
|
|
|
if buf:
|
|
if output_via_pager:
|
|
def newlinewrapper(text):
|
|
for line in text:
|
|
yield line + "\n"
|
|
click.echo_via_pager(newlinewrapper(buf))
|
|
else:
|
|
for line in buf:
|
|
click.secho(line)
|
|
|
|
if status:
|
|
self.log_output(status)
|
|
click.secho(status)
|
|
|
|
def configure_pager(self):
|
|
# Provide sane defaults for less if they are empty.
|
|
if not os.environ.get('LESS'):
|
|
os.environ['LESS'] = '-RXF'
|
|
|
|
cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
|
|
if cnf['pager']:
|
|
special.set_pager(cnf['pager'])
|
|
self.explicit_pager = True
|
|
else:
|
|
self.explicit_pager = False
|
|
|
|
if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'):
|
|
special.disable_pager()
|
|
|
|
def refresh_completions(self, reset=False):
|
|
if reset:
|
|
with self._completer_lock:
|
|
self.completer.reset_completions()
|
|
self.completion_refresher.refresh(
|
|
self.sqlexecute, self._on_completions_refreshed,
|
|
{'smart_completion': self.smart_completion,
|
|
'supported_formats': self.formatter.supported_formats,
|
|
'keyword_casing': self.completer.keyword_casing})
|
|
|
|
return [(None, None, None,
|
|
'Auto-completion refresh started in the background.')]
|
|
|
|
def _on_completions_refreshed(self, new_completer):
|
|
"""Swap the completer object in cli with the newly created completer.
|
|
"""
|
|
with self._completer_lock:
|
|
self.completer = new_completer
|
|
|
|
if self.prompt_app:
|
|
# After refreshing, redraw the CLI to clear the statusbar
|
|
# "Refreshing completions..." indicator
|
|
self.prompt_app.app.invalidate()
|
|
|
|
def get_completions(self, text, cursor_positition):
|
|
with self._completer_lock:
|
|
return self.completer.get_completions(
|
|
Document(text=text, cursor_position=cursor_positition), None)
|
|
|
|
def get_prompt(self, string):
|
|
sqlexecute = self.sqlexecute
|
|
host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host
|
|
now = datetime.now()
|
|
string = string.replace('\\u', sqlexecute.user or '(none)')
|
|
string = string.replace('\\h', host or '(none)')
|
|
string = string.replace('\\d', sqlexecute.dbname or '(none)')
|
|
string = string.replace('\\t', sqlexecute.server_info.species.name)
|
|
string = string.replace('\\n', "\n")
|
|
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
|
|
string = string.replace('\\m', now.strftime('%M'))
|
|
string = string.replace('\\P', now.strftime('%p'))
|
|
string = string.replace('\\R', now.strftime('%H'))
|
|
string = string.replace('\\r', now.strftime('%I'))
|
|
string = string.replace('\\s', now.strftime('%S'))
|
|
string = string.replace('\\p', str(sqlexecute.port))
|
|
string = string.replace('\\A', self.dsn_alias or '(none)')
|
|
string = string.replace('\\_', ' ')
|
|
return string
|
|
|
|
def run_query(self, query, new_line=True):
|
|
"""Runs *query*."""
|
|
results = self.sqlexecute.run(query)
|
|
for result in results:
|
|
title, cur, headers, status = result
|
|
self.formatter.query = query
|
|
output = self.format_output(title, cur, headers)
|
|
for line in output:
|
|
click.echo(line, nl=new_line)
|
|
|
|
def format_output(self, title, cur, headers, expanded=False,
|
|
max_width=None):
|
|
expanded = expanded or self.formatter.format_name == 'vertical'
|
|
output = []
|
|
|
|
output_kwargs = {
|
|
'dialect': 'unix',
|
|
'disable_numparse': True,
|
|
'preserve_whitespace': True,
|
|
'style': self.output_style
|
|
}
|
|
|
|
if not self.formatter.format_name in sql_format.supported_formats:
|
|
output_kwargs["preprocessors"] = (preprocessors.align_decimals, )
|
|
|
|
if title: # Only print the title if it's not None.
|
|
output = itertools.chain(output, [title])
|
|
|
|
if cur:
|
|
column_types = None
|
|
if hasattr(cur, 'description'):
|
|
def get_col_type(col):
|
|
col_type = FIELD_TYPES.get(col[1], str)
|
|
return col_type if type(col_type) is type else str
|
|
column_types = [get_col_type(col) for col in cur.description]
|
|
|
|
if max_width is not None:
|
|
cur = list(cur)
|
|
|
|
formatted = self.formatter.format_output(
|
|
cur, headers, format_name='vertical' if expanded else None,
|
|
column_types=column_types,
|
|
**output_kwargs)
|
|
|
|
if isinstance(formatted, str):
|
|
formatted = formatted.splitlines()
|
|
formatted = iter(formatted)
|
|
|
|
if (not expanded and max_width and headers and cur):
|
|
first_line = next(formatted)
|
|
if len(strip_ansi(first_line)) > max_width:
|
|
formatted = self.formatter.format_output(
|
|
cur, headers, format_name='vertical', column_types=column_types, **output_kwargs)
|
|
if isinstance(formatted, str):
|
|
formatted = iter(formatted.splitlines())
|
|
else:
|
|
formatted = itertools.chain([first_line], formatted)
|
|
|
|
output = itertools.chain(output, formatted)
|
|
|
|
|
|
return output
|
|
|
|
def get_reserved_space(self):
|
|
"""Get the number of lines to reserve for the completion menu."""
|
|
reserved_space_ratio = .45
|
|
max_reserved_space = 8
|
|
_, height = click.get_terminal_size()
|
|
return min(int(round(height * reserved_space_ratio)), max_reserved_space)
|
|
|
|
def get_last_query(self):
|
|
"""Get the last query executed or None."""
|
|
return self.query_history[-1][0] if self.query_history else None
|
|
|
|
|
|
@click.command()
|
|
@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.')
|
|
@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors '
|
|
'$MYSQL_TCP_PORT.')
|
|
@click.option('-u', '--user', help='User name to connect to the database.')
|
|
@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.')
|
|
@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str,
|
|
help='Password to connect to the database.')
|
|
@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str,
|
|
help='Password to connect to the database.')
|
|
@click.option('--ssh-user', help='User name to connect to ssh server.')
|
|
@click.option('--ssh-host', help='Host name to connect to ssh server.')
|
|
@click.option('--ssh-port', default=22, help='Port to connect to ssh server.')
|
|
@click.option('--ssh-password', help='Password to connect to ssh server.')
|
|
@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.')
|
|
@click.option('--ssh-config-path', help='Path to ssh configuration.',
|
|
default=os.path.expanduser('~') + '/.ssh/config')
|
|
@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.')
|
|
@click.option('--ssl-ca', help='CA file in PEM format.',
|
|
type=click.Path(exists=True))
|
|
@click.option('--ssl-capath', help='CA directory.')
|
|
@click.option('--ssl-cert', help='X509 cert in PEM format.',
|
|
type=click.Path(exists=True))
|
|
@click.option('--ssl-key', help='X509 key in PEM format.',
|
|
type=click.Path(exists=True))
|
|
@click.option('--ssl-cipher', help='SSL cipher to use.')
|
|
@click.option('--ssl-verify-server-cert', is_flag=True,
|
|
help=('Verify server\'s "Common Name" in its cert against '
|
|
'hostname used when connecting. This option is disabled '
|
|
'by default.'))
|
|
# as of 2016-02-15 revocation list is not supported by underling PyMySQL
|
|
# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client)
|
|
@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.')
|
|
@click.option('-v', '--verbose', is_flag=True, help='Verbose output.')
|
|
@click.option('-D', '--database', 'dbname', help='Database to use.')
|
|
@click.option('-d', '--dsn', default='', envvar='DSN',
|
|
help='Use DSN configured into the [alias_dsn] section of myclirc file.')
|
|
@click.option('--list-dsn', 'list_dsn', is_flag=True,
|
|
help='list of DSN configured into the [alias_dsn] section of myclirc file.')
|
|
@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True,
|
|
help='list ssh configurations in the ssh config (requires paramiko).')
|
|
@click.option('-R', '--prompt', 'prompt',
|
|
help='Prompt format (Default: "{0}").'.format(
|
|
MyCli.default_prompt))
|
|
@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'),
|
|
help='Log every query and its results to a file.')
|
|
@click.option('--defaults-group-suffix', type=str,
|
|
help='Read MySQL config groups with the specified suffix.')
|
|
@click.option('--defaults-file', type=click.Path(),
|
|
help='Only read MySQL options from the given file.')
|
|
@click.option('--myclirc', type=click.Path(), default="~/.myclirc",
|
|
help='Location of myclirc file.')
|
|
@click.option('--auto-vertical-output', is_flag=True,
|
|
help='Automatically switch to vertical output mode if the result is wider than the terminal width.')
|
|
@click.option('-t', '--table', is_flag=True,
|
|
help='Display batch output in table format.')
|
|
@click.option('--csv', is_flag=True,
|
|
help='Display batch output in CSV format.')
|
|
@click.option('--warn/--no-warn', default=None,
|
|
help='Warn before running a destructive query.')
|
|
@click.option('--local-infile', type=bool,
|
|
help='Enable/disable LOAD DATA LOCAL INFILE.')
|
|
@click.option('-g', '--login-path', type=str,
|
|
help='Read this path from the login file.')
|
|
@click.option('-e', '--execute', type=str,
|
|
help='Execute command and quit.')
|
|
@click.option('--init-command', type=str,
|
|
help='SQL statement to execute after connecting.')
|
|
@click.option('--charset', type=str,
|
|
help='Character set for MySQL session.')
|
|
@click.option('--password-file', type=click.Path(),
|
|
help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
|
|
@click.argument('database', default='', nargs=1)
|
|
def cli(database, user, host, port, socket, password, dbname,
|
|
version, verbose, prompt, logfile, defaults_group_suffix,
|
|
defaults_file, login_path, auto_vertical_output, local_infile,
|
|
ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
|
|
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
|
|
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
|
|
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
|
|
init_command, charset, password_file):
|
|
"""A MySQL terminal client with auto-completion and syntax highlighting.
|
|
|
|
\b
|
|
Examples:
|
|
- mycli my_database
|
|
- mycli -u my_user -h my_host.com my_database
|
|
- mycli mysql://my_user@my_host.com:3306/my_database
|
|
|
|
"""
|
|
|
|
if version:
|
|
print('Version:', __version__)
|
|
sys.exit(0)
|
|
|
|
mycli = MyCli(prompt=prompt, logfile=logfile,
|
|
defaults_suffix=defaults_group_suffix,
|
|
defaults_file=defaults_file, login_path=login_path,
|
|
auto_vertical_output=auto_vertical_output, warn=warn,
|
|
myclirc=myclirc)
|
|
if list_dsn:
|
|
try:
|
|
alias_dsn = mycli.config['alias_dsn']
|
|
except KeyError as err:
|
|
click.secho('Invalid DSNs found in the config file. '\
|
|
'Please check the "[alias_dsn]" section in myclirc.',
|
|
err=True, fg='red')
|
|
exit(1)
|
|
except Exception as e:
|
|
click.secho(str(e), err=True, fg='red')
|
|
exit(1)
|
|
for alias, value in alias_dsn.items():
|
|
if verbose:
|
|
click.secho("{} : {}".format(alias, value))
|
|
else:
|
|
click.secho(alias)
|
|
sys.exit(0)
|
|
if list_ssh_config:
|
|
ssh_config = read_ssh_config(ssh_config_path)
|
|
for host in ssh_config.get_hostnames():
|
|
if verbose:
|
|
host_config = ssh_config.lookup(host)
|
|
click.secho("{} : {}".format(
|
|
host, host_config.get('hostname')))
|
|
else:
|
|
click.secho(host)
|
|
sys.exit(0)
|
|
# Choose which ever one has a valid value.
|
|
database = dbname or database
|
|
|
|
ssl = {
|
|
'ca': ssl_ca and os.path.expanduser(ssl_ca),
|
|
'cert': ssl_cert and os.path.expanduser(ssl_cert),
|
|
'key': ssl_key and os.path.expanduser(ssl_key),
|
|
'capath': ssl_capath,
|
|
'cipher': ssl_cipher,
|
|
'check_hostname': ssl_verify_server_cert,
|
|
}
|
|
|
|
# remove empty ssl options
|
|
ssl = {k: v for k, v in ssl.items() if v is not None}
|
|
|
|
dsn_uri = None
|
|
|
|
# Treat the database argument as a DSN alias if we're missing
|
|
# other connection information.
|
|
if (mycli.config['alias_dsn'] and database and '://' not in database
|
|
and not any([user, password, host, port, login_path])):
|
|
dsn, database = database, ''
|
|
|
|
if database and '://' in database:
|
|
dsn_uri, database = database, ''
|
|
|
|
if dsn:
|
|
try:
|
|
dsn_uri = mycli.config['alias_dsn'][dsn]
|
|
except KeyError:
|
|
click.secho('Could not find the specified DSN in the config file. '
|
|
'Please check the "[alias_dsn]" section in your '
|
|
'myclirc.', err=True, fg='red')
|
|
exit(1)
|
|
else:
|
|
mycli.dsn_alias = dsn
|
|
|
|
if dsn_uri:
|
|
uri = urlparse(dsn_uri)
|
|
if not database:
|
|
database = uri.path[1:] # ignore the leading fwd slash
|
|
if not user:
|
|
user = unquote(uri.username)
|
|
if not password and uri.password is not None:
|
|
password = unquote(uri.password)
|
|
if not host:
|
|
host = uri.hostname
|
|
if not port:
|
|
port = uri.port
|
|
|
|
if ssh_config_host:
|
|
ssh_config = read_ssh_config(
|
|
ssh_config_path
|
|
).lookup(ssh_config_host)
|
|
ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
|
|
ssh_user = ssh_user if ssh_user else ssh_config.get('user')
|
|
if ssh_config.get('port') and ssh_port == 22:
|
|
# port has a default value, overwrite it if it's in the config
|
|
ssh_port = int(ssh_config.get('port'))
|
|
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
|
|
'identityfile', [None])[0]
|
|
|
|
ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
|
|
|
|
mycli.connect(
|
|
database=database,
|
|
user=user,
|
|
passwd=password,
|
|
host=host,
|
|
port=port,
|
|
socket=socket,
|
|
local_infile=local_infile,
|
|
ssl=ssl,
|
|
ssh_user=ssh_user,
|
|
ssh_host=ssh_host,
|
|
ssh_port=ssh_port,
|
|
ssh_password=ssh_password,
|
|
ssh_key_filename=ssh_key_filename,
|
|
init_command=init_command,
|
|
charset=charset,
|
|
password_file=password_file
|
|
)
|
|
|
|
mycli.logger.debug('Launch Params: \n'
|
|
'\tdatabase: %r'
|
|
'\tuser: %r'
|
|
'\thost: %r'
|
|
'\tport: %r', database, user, host, port)
|
|
|
|
# --execute argument
|
|
if execute:
|
|
try:
|
|
if csv:
|
|
mycli.formatter.format_name = 'csv'
|
|
elif not table:
|
|
mycli.formatter.format_name = 'tsv'
|
|
|
|
mycli.run_query(execute)
|
|
exit(0)
|
|
except Exception as e:
|
|
click.secho(str(e), err=True, fg='red')
|
|
exit(1)
|
|
|
|
if sys.stdin.isatty():
|
|
mycli.run_cli()
|
|
else:
|
|
stdin = click.get_text_stream('stdin')
|
|
try:
|
|
stdin_text = stdin.read()
|
|
except MemoryError:
|
|
click.secho('Failed! Ran out of memory.', err=True, fg='red')
|
|
click.secho('You might want to try the official mysql client.', err=True, fg='red')
|
|
click.secho('Sorry... :(', err=True, fg='red')
|
|
exit(1)
|
|
|
|
if mycli.destructive_warning and is_destructive(stdin_text):
|
|
try:
|
|
sys.stdin = open('/dev/tty')
|
|
warn_confirmed = confirm_destructive_query(stdin_text)
|
|
except (IOError, OSError):
|
|
mycli.logger.warning('Unable to open TTY as stdin.')
|
|
if not warn_confirmed:
|
|
exit(0)
|
|
|
|
try:
|
|
new_line = True
|
|
|
|
if csv:
|
|
mycli.formatter.format_name = 'csv'
|
|
elif not table:
|
|
mycli.formatter.format_name = 'tsv'
|
|
|
|
mycli.run_query(stdin_text, new_line=new_line)
|
|
exit(0)
|
|
except Exception as e:
|
|
click.secho(str(e), err=True, fg='red')
|
|
exit(1)
|
|
|
|
|
|
def need_completion_refresh(queries):
|
|
"""Determines if the completion needs a refresh by checking if the sql
|
|
statement is an alter, create, drop or change db."""
|
|
for query in sqlparse.split(queries):
|
|
try:
|
|
first_token = query.split()[0]
|
|
if first_token.lower() in ('alter', 'create', 'use', '\\r',
|
|
'\\u', 'connect', 'drop', 'rename'):
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def need_completion_reset(queries):
|
|
"""Determines if the statement is a database switch such as 'use' or '\\u'.
|
|
When a database is changed the existing completions must be reset before we
|
|
start the completion refresh for the new database.
|
|
"""
|
|
for query in sqlparse.split(queries):
|
|
try:
|
|
first_token = query.split()[0]
|
|
if first_token.lower() in ('use', '\\u'):
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def is_mutating(status):
|
|
"""Determines if the statement is mutating based on the status."""
|
|
if not status:
|
|
return False
|
|
|
|
mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop',
|
|
'replace', 'truncate', 'load', 'rename'])
|
|
return status.split(None, 1)[0].lower() in mutating
|
|
|
|
|
|
def is_select(status):
|
|
"""Returns true if the first word in status is 'select'."""
|
|
if not status:
|
|
return False
|
|
return status.split(None, 1)[0].lower() == 'select'
|
|
|
|
|
|
def thanks_picker():
|
|
import mycli
|
|
lines = (
|
|
resources.read_text(mycli, 'AUTHORS') +
|
|
resources.read_text(mycli, 'SPONSORS')
|
|
).split('\n')
|
|
|
|
contents = []
|
|
for line in lines:
|
|
m = re.match(r'^ *\* (.*)', line)
|
|
if m:
|
|
contents.append(m.group(1))
|
|
return choice(contents)
|
|
|
|
|
|
@prompt_register('edit-and-execute-command')
|
|
def edit_and_execute(event):
|
|
"""Different from the prompt-toolkit default, we want to have a choice not
|
|
to execute a query after editing, hence validate_and_handle=False."""
|
|
buff = event.current_buffer
|
|
buff.open_in_editor(validate_and_handle=False)
|
|
|
|
|
|
def read_ssh_config(ssh_config_path):
|
|
ssh_config = paramiko.config.SSHConfig()
|
|
try:
|
|
with open(ssh_config_path) as f:
|
|
ssh_config.parse(f)
|
|
except FileNotFoundError as e:
|
|
click.secho(str(e), err=True, fg='red')
|
|
sys.exit(1)
|
|
# Paramiko prior to version 2.7 raises Exception on parse errors.
|
|
# In 2.7 it has become paramiko.ssh_exception.SSHException,
|
|
# but let's catch everything for compatibility
|
|
except Exception as err:
|
|
click.secho(
|
|
f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
|
|
err=True, fg='red'
|
|
)
|
|
sys.exit(1)
|
|
else:
|
|
return ssh_config
|
|
|
|
|
|
if __name__ == "__main__":
|
|
cli()
|