Merging upstream version 1.24.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
570aa52ec2
commit
06dd2aeb28
26 changed files with 565 additions and 169 deletions
|
@ -75,6 +75,8 @@ Contributors:
|
|||
* Zach DeCook
|
||||
* kevinhwang91
|
||||
* KITAGAWA Yasutaka
|
||||
* Nicolas Palumbo
|
||||
* Andy Teijelo Pérez
|
||||
* bitkeen
|
||||
* Morgan Mitchell
|
||||
* Massimiliano Torromeo
|
||||
|
@ -82,6 +84,7 @@ Contributors:
|
|||
* xeron
|
||||
* 0xflotus
|
||||
* Seamile
|
||||
* Jerome Provensal
|
||||
|
||||
Creator:
|
||||
--------
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '1.23.2'
|
||||
__version__ = '1.24.1'
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from prompt_toolkit.enums import DEFAULT_BUFFER
|
||||
from prompt_toolkit.filters import Condition
|
||||
from prompt_toolkit.application import get_app
|
||||
from .packages.parseutils import is_open_quote
|
||||
from .packages import special
|
||||
|
||||
|
||||
|
|
104
mycli/config.py
104
mycli/config.py
|
@ -1,5 +1,3 @@
|
|||
import io
|
||||
import shutil
|
||||
from copy import copy
|
||||
from io import BytesIO, TextIOWrapper
|
||||
import logging
|
||||
|
@ -7,11 +5,16 @@ import os
|
|||
from os.path import exists
|
||||
import struct
|
||||
import sys
|
||||
from typing import Union
|
||||
from typing import Union, IO
|
||||
|
||||
from configobj import ConfigObj, ConfigObjError
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
import pyaes
|
||||
|
||||
try:
|
||||
import importlib.resources as resources
|
||||
except ImportError:
|
||||
# Python < 3.7
|
||||
import importlib_resources as resources
|
||||
|
||||
try:
|
||||
basestring
|
||||
|
@ -49,9 +52,9 @@ def read_config_file(f, list_values=True):
|
|||
config = ConfigObj(f, interpolation=False, encoding='utf8',
|
||||
list_values=list_values)
|
||||
except ConfigObjError as e:
|
||||
log(logger, logging.ERROR, "Unable to parse line {0} of config file "
|
||||
log(logger, logging.WARNING, "Unable to parse line {0} of config file "
|
||||
"'{1}'.".format(e.line_number, f))
|
||||
log(logger, logging.ERROR, "Using successfully parsed config values.")
|
||||
log(logger, logging.WARNING, "Using successfully parsed config values.")
|
||||
return e.config
|
||||
except (IOError, OSError) as e:
|
||||
log(logger, logging.WARNING, "You don't have permission to read "
|
||||
|
@ -61,7 +64,7 @@ def read_config_file(f, list_values=True):
|
|||
return config
|
||||
|
||||
|
||||
def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
||||
def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
|
||||
"""Get a list of configuration files that are included into config_path
|
||||
with !includedir directive.
|
||||
|
||||
|
@ -95,7 +98,7 @@ def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
|||
def read_config_files(files, list_values=True):
|
||||
"""Read and merge a list of config files."""
|
||||
|
||||
config = ConfigObj(list_values=list_values)
|
||||
config = create_default_config(list_values=list_values)
|
||||
_files = copy(files)
|
||||
while _files:
|
||||
_file = _files.pop(0)
|
||||
|
@ -112,12 +115,21 @@ def read_config_files(files, list_values=True):
|
|||
return config
|
||||
|
||||
|
||||
def write_default_config(source, destination, overwrite=False):
|
||||
def create_default_config(list_values=True):
|
||||
import mycli
|
||||
default_config_file = resources.open_text(mycli, 'myclirc')
|
||||
return read_config_file(default_config_file, list_values=list_values)
|
||||
|
||||
|
||||
def write_default_config(destination, overwrite=False):
|
||||
import mycli
|
||||
default_config = resources.read_text(mycli, 'myclirc')
|
||||
destination = os.path.expanduser(destination)
|
||||
if not overwrite and exists(destination):
|
||||
return
|
||||
|
||||
shutil.copyfile(source, destination)
|
||||
with open(destination, 'w') as f:
|
||||
f.write(default_config)
|
||||
|
||||
|
||||
def get_mylogin_cnf_path():
|
||||
|
@ -160,6 +172,58 @@ def open_mylogin_cnf(name):
|
|||
return TextIOWrapper(plaintext)
|
||||
|
||||
|
||||
# TODO reuse code between encryption an decryption
|
||||
def encrypt_mylogin_cnf(plaintext: IO[str]):
|
||||
"""Encryption of .mylogin.cnf file, analogous to calling
|
||||
mysql_config_editor.
|
||||
|
||||
Code is based on the python implementation by Kristian Koehntopp
|
||||
https://github.com/isotopp/mysql-config-coder
|
||||
|
||||
"""
|
||||
def realkey(key):
|
||||
"""Create the AES key from the login key."""
|
||||
rkey = bytearray(16)
|
||||
for i in range(len(key)):
|
||||
rkey[i % 16] ^= key[i]
|
||||
return bytes(rkey)
|
||||
|
||||
def encode_line(plaintext, real_key, buf_len):
|
||||
aes = pyaes.AESModeOfOperationECB(real_key)
|
||||
text_len = len(plaintext)
|
||||
pad_len = buf_len - text_len
|
||||
pad_chr = bytes(chr(pad_len), "utf8")
|
||||
plaintext = plaintext.encode() + pad_chr * pad_len
|
||||
encrypted_text = b''.join(
|
||||
[aes.encrypt(plaintext[i: i + 16])
|
||||
for i in range(0, len(plaintext), 16)]
|
||||
)
|
||||
return encrypted_text
|
||||
|
||||
LOGIN_KEY_LENGTH = 20
|
||||
key = os.urandom(LOGIN_KEY_LENGTH)
|
||||
real_key = realkey(key)
|
||||
|
||||
outfile = BytesIO()
|
||||
|
||||
outfile.write(struct.pack("i", 0))
|
||||
outfile.write(key)
|
||||
|
||||
while True:
|
||||
line = plaintext.readline()
|
||||
if not line:
|
||||
break
|
||||
real_len = len(line)
|
||||
pad_len = (int(real_len / 16) + 1) * 16
|
||||
|
||||
outfile.write(struct.pack("i", pad_len))
|
||||
x = encode_line(line, real_key, pad_len)
|
||||
outfile.write(x)
|
||||
|
||||
outfile.seek(0)
|
||||
return outfile
|
||||
|
||||
|
||||
def read_and_decrypt_mylogin_cnf(f):
|
||||
"""Read and decrypt the contents of .mylogin.cnf.
|
||||
|
||||
|
@ -201,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
return None
|
||||
rkey = struct.pack('16B', *rkey)
|
||||
|
||||
# Create a decryptor object using the key.
|
||||
decryptor = _get_decryptor(rkey)
|
||||
|
||||
# Create a bytes buffer to hold the plaintext.
|
||||
plaintext = BytesIO()
|
||||
aes = pyaes.AESModeOfOperationECB(rkey)
|
||||
|
||||
while True:
|
||||
# Read the length of the ciphertext.
|
||||
|
@ -216,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
|
||||
# Read cipher_len bytes from the file and decrypt.
|
||||
cipher = f.read(cipher_len)
|
||||
plain = _remove_pad(decryptor.update(cipher))
|
||||
plain = _remove_pad(
|
||||
b''.join([aes.decrypt(cipher[i: i + 16])
|
||||
for i in range(0, cipher_len, 16)])
|
||||
)
|
||||
if plain is False:
|
||||
continue
|
||||
plaintext.write(plain)
|
||||
|
@ -244,7 +309,7 @@ def str_to_bool(s):
|
|||
elif s.lower() in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError('not a recognized boolean value: %s'.format(s))
|
||||
raise ValueError('not a recognized boolean value: {0}'.format(s))
|
||||
|
||||
|
||||
def strip_matching_quotes(s):
|
||||
|
@ -260,15 +325,8 @@ def strip_matching_quotes(s):
|
|||
return s
|
||||
|
||||
|
||||
def _get_decryptor(key):
|
||||
"""Get the AES decryptor."""
|
||||
c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
|
||||
return c.decryptor()
|
||||
|
||||
|
||||
def _remove_pad(line):
|
||||
"""Remove the pad from the *line*."""
|
||||
pad_length = ord(line[-1:])
|
||||
try:
|
||||
# Determine pad length.
|
||||
pad_length = ord(line[-1:])
|
||||
|
|
|
@ -78,8 +78,12 @@ def mycli_bindings(mycli):
|
|||
|
||||
@kb.add('escape', 'enter')
|
||||
def _(event):
|
||||
"""Introduces a line break regardless of multi-line mode or not."""
|
||||
"""Introduces a line break in multi-line mode, or dispatches the
|
||||
command in single-line mode."""
|
||||
_logger.debug('Detected alt-enter key.')
|
||||
event.app.current_buffer.insert_text('\n')
|
||||
if mycli.multi_line:
|
||||
event.app.current_buffer.validate_and_handle()
|
||||
else:
|
||||
event.app.current_buffer.insert_text('\n')
|
||||
|
||||
return kb
|
||||
|
|
140
mycli/main.py
140
mycli/main.py
|
@ -1,9 +1,12 @@
|
|||
from collections import defaultdict
|
||||
from io import open
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import logging
|
||||
import threading
|
||||
import re
|
||||
import stat
|
||||
import fileinput
|
||||
from collections import namedtuple
|
||||
try:
|
||||
|
@ -13,7 +16,6 @@ except ImportError:
|
|||
from time import time
|
||||
from datetime import datetime
|
||||
from random import choice
|
||||
from io import open
|
||||
|
||||
from pymysql import OperationalError
|
||||
from cli_helpers.tabular_output import TabularOutputFormatter
|
||||
|
@ -43,7 +45,7 @@ from .packages.special.favoritequeries import FavoriteQueries
|
|||
from .sqlcompleter import SQLCompleter
|
||||
from .clitoolbar import create_toolbar_tokens_func
|
||||
from .clistyle import style_factory, style_factory_output
|
||||
from .sqlexecute import FIELD_TYPES, SQLExecute
|
||||
from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
|
||||
from .clibuffer import cli_is_multiline
|
||||
from .completion_refresher import CompletionRefresher
|
||||
from .config import (write_default_config, get_mylogin_cnf_path,
|
||||
|
@ -51,7 +53,7 @@ from .config import (write_default_config, get_mylogin_cnf_path,
|
|||
strip_matching_quotes)
|
||||
from .key_bindings import mycli_bindings
|
||||
from .lexer import MyCliLexer
|
||||
from .__init__ import __version__
|
||||
from . import __version__
|
||||
from .compat import WIN
|
||||
from .packages.filepaths import dir_path_exists, guess_socket_location
|
||||
|
||||
|
@ -66,6 +68,11 @@ except ImportError:
|
|||
from urllib.parse import urlparse
|
||||
from urllib.parse import unquote
|
||||
|
||||
try:
|
||||
import importlib.resources as resources
|
||||
except ImportError:
|
||||
# Python < 3.7
|
||||
import importlib_resources as resources
|
||||
|
||||
try:
|
||||
import paramiko
|
||||
|
@ -75,7 +82,10 @@ except ImportError:
|
|||
# Query tuples are used for maintaining history
|
||||
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
|
||||
|
||||
PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
SUPPORT_INFO = (
|
||||
'Home: http://mycli.net\n'
|
||||
'Bug tracker: https://github.com/dbcli/mycli/issues'
|
||||
)
|
||||
|
||||
|
||||
class MyCli(object):
|
||||
|
@ -89,7 +99,7 @@ class MyCli(object):
|
|||
'/etc/my.cnf',
|
||||
'/etc/mysql/my.cnf',
|
||||
'/usr/local/etc/my.cnf',
|
||||
'~/.my.cnf'
|
||||
os.path.expanduser('~/.my.cnf'),
|
||||
]
|
||||
|
||||
# check XDG_CONFIG_HOME exists and not an empty string
|
||||
|
@ -102,7 +112,6 @@ class MyCli(object):
|
|||
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
|
||||
]
|
||||
|
||||
default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
|
||||
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
|
||||
|
||||
def __init__(self, sqlexecute=None, prompt=None,
|
||||
|
@ -122,7 +131,7 @@ class MyCli(object):
|
|||
self.cnf_files = [defaults_file]
|
||||
|
||||
# Load config.
|
||||
config_files = ([self.default_config_file] + self.system_config_files +
|
||||
config_files = (self.system_config_files +
|
||||
[myclirc] + [self.pwd_config_file])
|
||||
c = self.config = read_config_files(config_files)
|
||||
self.multi_line = c['main'].as_bool('multi_line')
|
||||
|
@ -154,7 +163,7 @@ class MyCli(object):
|
|||
|
||||
# Write user config if system config wasn't the last config loaded.
|
||||
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
|
||||
write_default_config(self.default_config_file, myclirc)
|
||||
write_default_config(myclirc)
|
||||
|
||||
# audit log
|
||||
if self.logfile is None and 'audit_log' in c['main']:
|
||||
|
@ -326,20 +335,33 @@ class MyCli(object):
|
|||
cnf = read_config_files(files, list_values=False)
|
||||
|
||||
sections = ['client', 'mysqld']
|
||||
key_transformations = {
|
||||
'mysqld': {
|
||||
'socket': 'default_socket',
|
||||
'port': 'default_port',
|
||||
},
|
||||
}
|
||||
|
||||
if self.login_path and self.login_path != 'client':
|
||||
sections.append(self.login_path)
|
||||
|
||||
if self.defaults_suffix:
|
||||
sections.extend([sect + self.defaults_suffix for sect in sections])
|
||||
|
||||
def get(key):
|
||||
result = None
|
||||
for sect in cnf:
|
||||
if sect in sections and key in cnf[sect]:
|
||||
result = strip_matching_quotes(cnf[sect][key])
|
||||
return result
|
||||
configuration = defaultdict(lambda: None)
|
||||
for key in keys:
|
||||
for section in cnf:
|
||||
if (
|
||||
section not in sections or
|
||||
key not in cnf[section]
|
||||
):
|
||||
continue
|
||||
new_key = key_transformations.get(section, {}).get(key) or key
|
||||
configuration[new_key] = strip_matching_quotes(
|
||||
cnf[section][key])
|
||||
|
||||
return configuration
|
||||
|
||||
return {x: get(x) for x in keys}
|
||||
|
||||
def merge_ssl_with_cnf(self, ssl, cnf):
|
||||
"""Merge SSL configuration dict with cnf dict"""
|
||||
|
@ -367,7 +389,7 @@ class MyCli(object):
|
|||
def connect(self, database='', user='', passwd='', host='', port='',
|
||||
socket='', charset='', local_infile='', ssl='',
|
||||
ssh_user='', ssh_host='', ssh_port='',
|
||||
ssh_password='', ssh_key_filename='', init_command=''):
|
||||
ssh_password='', ssh_key_filename='', init_command='', password_file=''):
|
||||
|
||||
cnf = {'database': None,
|
||||
'user': None,
|
||||
|
@ -375,6 +397,7 @@ class MyCli(object):
|
|||
'host': None,
|
||||
'port': None,
|
||||
'socket': None,
|
||||
'default_socket': None,
|
||||
'default-character-set': None,
|
||||
'local-infile': None,
|
||||
'loose-local-infile': None,
|
||||
|
@ -388,18 +411,23 @@ class MyCli(object):
|
|||
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
|
||||
|
||||
# Fall back to config values only if user did not specify a value.
|
||||
|
||||
database = database or cnf['database']
|
||||
# Socket interface not supported for SSH connections
|
||||
if port or (host and host != 'localhost') or (ssh_host and ssh_port):
|
||||
socket = ''
|
||||
else:
|
||||
socket = socket or cnf['socket'] or guess_socket_location()
|
||||
user = user or cnf['user'] or os.getenv('USER')
|
||||
host = host or cnf['host']
|
||||
port = int(port or cnf['port'] or 3306)
|
||||
port = port or cnf['port']
|
||||
ssl = ssl or {}
|
||||
|
||||
port = port and int(port)
|
||||
if not port:
|
||||
port = 3306
|
||||
if not host or host == 'localhost':
|
||||
socket = (
|
||||
cnf['socket'] or
|
||||
cnf['default_socket'] or
|
||||
guess_socket_location()
|
||||
)
|
||||
|
||||
|
||||
passwd = passwd if isinstance(passwd, str) else cnf['password']
|
||||
charset = charset or cnf['default-character-set'] or 'utf8'
|
||||
|
||||
|
@ -417,6 +445,10 @@ class MyCli(object):
|
|||
if not any(v for v in ssl.values()):
|
||||
ssl = None
|
||||
|
||||
# if the passwd is not specfied try to set it using the password_file option
|
||||
password_from_file = self.get_password_from_file(password_file)
|
||||
passwd = passwd or password_from_file
|
||||
|
||||
# Connect to the database.
|
||||
|
||||
def _connect():
|
||||
|
@ -427,9 +459,12 @@ class MyCli(object):
|
|||
ssh_password, ssh_key_filename, init_command
|
||||
)
|
||||
except OperationalError as e:
|
||||
if ('Access denied for user' in e.args[1]):
|
||||
new_passwd = click.prompt('Password', hide_input=True,
|
||||
show_default=False, type=str, err=True)
|
||||
if e.args[0] == ERROR_CODE_ACCESS_DENIED:
|
||||
if password_from_file:
|
||||
new_passwd = password_from_file
|
||||
else:
|
||||
new_passwd = click.prompt('Password', hide_input=True,
|
||||
show_default=False, type=str, err=True)
|
||||
self.sqlexecute = SQLExecute(
|
||||
database, user, new_passwd, host, port, socket,
|
||||
charset, local_infile, ssl, ssh_user, ssh_host,
|
||||
|
@ -484,6 +519,17 @@ class MyCli(object):
|
|||
self.echo(str(e), err=True, fg='red')
|
||||
exit(1)
|
||||
|
||||
def get_password_from_file(self, password_file):
|
||||
password_from_file = None
|
||||
if password_file:
|
||||
if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
|
||||
and os.access(password_file, os.R_OK):
|
||||
with open(password_file) as fp:
|
||||
password_from_file = fp.readline()
|
||||
password_from_file = password_from_file.rstrip().lstrip()
|
||||
|
||||
return password_from_file
|
||||
|
||||
def handle_editor_command(self, text):
|
||||
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
|
||||
The reason for a while loop is because a user might edit a query
|
||||
|
@ -542,9 +588,6 @@ class MyCli(object):
|
|||
if self.smart_completion:
|
||||
self.refresh_completions()
|
||||
|
||||
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
|
||||
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
|
||||
|
||||
history_file = os.path.expanduser(
|
||||
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
|
||||
if dir_path_exists(history_file):
|
||||
|
@ -559,12 +602,10 @@ class MyCli(object):
|
|||
key_bindings = mycli_bindings(self)
|
||||
|
||||
if not self.less_chatty:
|
||||
print(' '.join(sqlexecute.server_type()))
|
||||
print(sqlexecute.server_info)
|
||||
print('mycli', __version__)
|
||||
print('Chat: https://gitter.im/dbcli/mycli')
|
||||
print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
|
||||
print('Home: http://mycli.net')
|
||||
print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
|
||||
print(SUPPORT_INFO)
|
||||
print('Thanks to the contributor -', thanks_picker())
|
||||
|
||||
def get_message():
|
||||
prompt = self.get_prompt(self.prompt_format)
|
||||
|
@ -862,8 +903,8 @@ class MyCli(object):
|
|||
|
||||
if not output_via_pager:
|
||||
# doesn't fit, flush buffer
|
||||
for line in buf:
|
||||
click.secho(line)
|
||||
for buf_line in buf:
|
||||
click.secho(buf_line)
|
||||
buf = []
|
||||
else:
|
||||
click.secho(line)
|
||||
|
@ -933,7 +974,7 @@ class MyCli(object):
|
|||
string = string.replace('\\u', sqlexecute.user or '(none)')
|
||||
string = string.replace('\\h', host or '(none)')
|
||||
string = string.replace('\\d', sqlexecute.dbname or '(none)')
|
||||
string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
|
||||
string = string.replace('\\t', sqlexecute.server_info.species.name)
|
||||
string = string.replace('\\n', "\n")
|
||||
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
|
||||
string = string.replace('\\m', now.strftime('%M'))
|
||||
|
@ -1083,7 +1124,7 @@ class MyCli(object):
|
|||
help='Warn before running a destructive query.')
|
||||
@click.option('--local-infile', type=bool,
|
||||
help='Enable/disable LOAD DATA LOCAL INFILE.')
|
||||
@click.option('--login-path', type=str,
|
||||
@click.option('-g', '--login-path', type=str,
|
||||
help='Read this path from the login file.')
|
||||
@click.option('-e', '--execute', type=str,
|
||||
help='Execute command and quit.')
|
||||
|
@ -1091,6 +1132,8 @@ class MyCli(object):
|
|||
help='SQL statement to execute after connecting.')
|
||||
@click.option('--charset', type=str,
|
||||
help='Character set for MySQL session.')
|
||||
@click.option('--password-file', type=click.Path(),
|
||||
help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
|
||||
@click.argument('database', default='', nargs=1)
|
||||
def cli(database, user, host, port, socket, password, dbname,
|
||||
version, verbose, prompt, logfile, defaults_group_suffix,
|
||||
|
@ -1099,7 +1142,7 @@ def cli(database, user, host, port, socket, password, dbname,
|
|||
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
|
||||
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
|
||||
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
|
||||
init_command, charset):
|
||||
init_command, charset, password_file):
|
||||
"""A MySQL terminal client with auto-completion and syntax highlighting.
|
||||
|
||||
\b
|
||||
|
@ -1225,7 +1268,8 @@ def cli(database, user, host, port, socket, password, dbname,
|
|||
ssh_password=ssh_password,
|
||||
ssh_key_filename=ssh_key_filename,
|
||||
init_command=init_command,
|
||||
charset=charset
|
||||
charset=charset,
|
||||
password_file=password_file
|
||||
)
|
||||
|
||||
mycli.logger.debug('Launch Params: \n'
|
||||
|
@ -1328,9 +1372,15 @@ def is_select(status):
|
|||
return status.split(None, 1)[0].lower() == 'select'
|
||||
|
||||
|
||||
def thanks_picker(files=()):
|
||||
def thanks_picker():
|
||||
import mycli
|
||||
lines = (
|
||||
resources.read_text(mycli, 'AUTHORS') +
|
||||
resources.read_text(mycli, 'SPONSORS')
|
||||
).split('\n')
|
||||
|
||||
contents = []
|
||||
for line in fileinput.input(files=files):
|
||||
for line in lines:
|
||||
m = re.match(r'^ *\* (.*)', line)
|
||||
if m:
|
||||
contents.append(m.group(1))
|
||||
|
@ -1350,6 +1400,9 @@ def read_ssh_config(ssh_config_path):
|
|||
try:
|
||||
with open(ssh_config_path) as f:
|
||||
ssh_config.parse(f)
|
||||
except FileNotFoundError as e:
|
||||
click.secho(str(e), err=True, fg='red')
|
||||
sys.exit(1)
|
||||
# Paramiko prior to version 2.7 raises Exception on parse errors.
|
||||
# In 2.7 it has become paramiko.ssh_exception.SSHException,
|
||||
# but let's catch everything for compatibility
|
||||
|
@ -1359,9 +1412,6 @@ def read_ssh_config(ssh_config_path):
|
|||
err=True, fg='red'
|
||||
)
|
||||
sys.exit(1)
|
||||
except FileNotFoundError as e:
|
||||
click.secho(str(e), err=True, fg='red')
|
||||
sys.exit(1)
|
||||
else:
|
||||
return ssh_config
|
||||
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
import sys
|
||||
import sqlparse
|
||||
from sqlparse.sql import Comparison, Identifier, Where
|
||||
from .parseutils import last_word, extract_tables, find_prev_keyword
|
||||
|
|
|
@ -12,7 +12,8 @@ cleanup_regex = {
|
|||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||
# This matches everything except a space.
|
||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def last_word(text, include='alphanum_underscore'):
|
||||
r"""
|
||||
|
@ -226,14 +227,6 @@ def is_destructive(queries):
|
|||
return False
|
||||
|
||||
|
||||
def is_open_quote(sql):
|
||||
"""Returns true if the query contains an unclosed quote."""
|
||||
|
||||
# parsed can contain one or more semi-colon separated commands
|
||||
parsed = sqlparse.parse(sql)
|
||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sql = 'select * from (select t. from tabl t'
|
||||
print (extract_tables(sql))
|
||||
|
@ -263,5 +256,4 @@ def is_dropping_database(queries, dbname):
|
|||
)
|
||||
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
|
||||
result = keywords[0].normalized == "DROP"
|
||||
else:
|
||||
return result
|
||||
return result
|
||||
|
|
|
@ -302,7 +302,7 @@ def execute_system_command(arg, **_):
|
|||
usage = "Syntax: system [command].\n"
|
||||
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
try:
|
||||
command = arg.strip()
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
"""Format adapter for sql."""
|
||||
|
||||
from cli_helpers.utils import filter_dict_by_key
|
||||
from mycli.packages.parseutils import extract_tables
|
||||
|
||||
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
|
||||
|
|
|
@ -72,7 +72,7 @@ class SQLCompleter(Completer):
|
|||
if name and ((not self.name_pattern.match(name))
|
||||
or (name.upper() in self.reserved_words)
|
||||
or (name.upper() in self.functions)):
|
||||
name = '`%s`' % name
|
||||
name = '`%s`' % name
|
||||
|
||||
return name
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import enum
|
||||
import logging
|
||||
import re
|
||||
|
||||
import pymysql
|
||||
import sqlparse
|
||||
from .packages import special
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
from pymysql.converters import (convert_datetime,
|
||||
|
@ -18,17 +20,71 @@ FIELD_TYPES.update({
|
|||
FIELD_TYPE.NULL: type(None)
|
||||
})
|
||||
|
||||
|
||||
ERROR_CODE_ACCESS_DENIED = 1045
|
||||
|
||||
|
||||
class ServerSpecies(enum.Enum):
|
||||
MySQL = 'MySQL'
|
||||
MariaDB = 'MariaDB'
|
||||
Percona = 'Percona'
|
||||
Unknown = 'MySQL'
|
||||
|
||||
|
||||
class ServerInfo:
|
||||
def __init__(self, species, version_str):
|
||||
self.species = species
|
||||
self.version_str = version_str
|
||||
self.version = self.calc_mysql_version_value(version_str)
|
||||
|
||||
@staticmethod
|
||||
def calc_mysql_version_value(version_str) -> int:
|
||||
if not version_str or not isinstance(version_str, str):
|
||||
return 0
|
||||
try:
|
||||
major, minor, patch = version_str.split('.')
|
||||
except ValueError:
|
||||
return 0
|
||||
else:
|
||||
return int(major) * 10_000 + int(minor) * 100 + int(patch)
|
||||
|
||||
@classmethod
|
||||
def from_version_string(cls, version_string):
|
||||
if not version_string:
|
||||
return cls(ServerSpecies.Unknown, '')
|
||||
|
||||
re_species = (
|
||||
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
|
||||
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
|
||||
ServerSpecies.Percona),
|
||||
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
|
||||
ServerSpecies.MySQL),
|
||||
)
|
||||
for regexp, species in re_species:
|
||||
match = re.search(regexp, version_string)
|
||||
if match is not None:
|
||||
parsed_version = match.group('version')
|
||||
detected_species = species
|
||||
break
|
||||
else:
|
||||
detected_species = ServerSpecies.Unknown
|
||||
parsed_version = ''
|
||||
|
||||
return cls(detected_species, parsed_version)
|
||||
|
||||
def __str__(self):
|
||||
if self.species:
|
||||
return f'{self.species.value} {self.version_str}'
|
||||
else:
|
||||
return self.version_str
|
||||
|
||||
|
||||
class SQLExecute(object):
|
||||
|
||||
databases_query = '''SHOW DATABASES'''
|
||||
|
||||
tables_query = '''SHOW TABLES'''
|
||||
|
||||
version_query = '''SELECT @@VERSION'''
|
||||
|
||||
version_comment_query = '''SELECT @@VERSION_COMMENT'''
|
||||
version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
|
||||
|
||||
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
|
||||
|
||||
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
|
||||
|
@ -52,7 +108,7 @@ class SQLExecute(object):
|
|||
self.charset = charset
|
||||
self.local_infile = local_infile
|
||||
self.ssl = ssl
|
||||
self._server_type = None
|
||||
self.server_info = None
|
||||
self.connection_id = None
|
||||
self.ssh_user = ssh_user
|
||||
self.ssh_host = ssh_host
|
||||
|
@ -157,6 +213,7 @@ class SQLExecute(object):
|
|||
self.init_command = init_command
|
||||
# retrieve connection id
|
||||
self.reset_connection_id()
|
||||
self.server_info = ServerInfo.from_version_string(conn.server_version)
|
||||
|
||||
def run(self, statement):
|
||||
"""Execute the sql in the database and return the results. The results
|
||||
|
@ -273,37 +330,6 @@ class SQLExecute(object):
|
|||
for row in cur:
|
||||
yield row
|
||||
|
||||
def server_type(self):
|
||||
if self._server_type:
|
||||
return self._server_type
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Version Query. sql: %r', self.version_query)
|
||||
cur.execute(self.version_query)
|
||||
version = cur.fetchone()[0]
|
||||
if version[0] == '4':
|
||||
_logger.debug('Version Comment. sql: %r',
|
||||
self.version_comment_query_mysql4)
|
||||
cur.execute(self.version_comment_query_mysql4)
|
||||
version_comment = cur.fetchone()[1].lower()
|
||||
if isinstance(version_comment, bytes):
|
||||
# with python3 this query returns bytes
|
||||
version_comment = version_comment.decode('utf-8')
|
||||
else:
|
||||
_logger.debug('Version Comment. sql: %r',
|
||||
self.version_comment_query)
|
||||
cur.execute(self.version_comment_query)
|
||||
version_comment = cur.fetchone()[0].lower()
|
||||
|
||||
if 'mariadb' in version_comment:
|
||||
product_type = 'mariadb'
|
||||
elif 'percona' in version_comment:
|
||||
product_type = 'percona'
|
||||
else:
|
||||
product_type = 'mysql'
|
||||
|
||||
self._server_type = (product_type, version)
|
||||
return self._server_type
|
||||
|
||||
def get_connection_id(self):
|
||||
if not self.connection_id:
|
||||
self.reset_connection_id()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue