1
0
Fork 0

Merging upstream version 1.24.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 18:56:59 +01:00
parent 570aa52ec2
commit 06dd2aeb28
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
26 changed files with 565 additions and 169 deletions

View file

@ -7,13 +7,20 @@ on:
jobs: jobs:
linux: linux:
runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]
include:
- python-version: 3.6
os: ubuntu-16.04 # MySQL 5.7.32
- python-version: 3.7
os: ubuntu-18.04 # MySQL 5.7.32
- python-version: 3.8
os: ubuntu-18.04 # MySQL 5.7.32
- python-version: 3.9
os: ubuntu-20.04 # MySQL 8.0.22
runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
@ -42,6 +49,7 @@ jobs:
- name: Pytest / behave - name: Pytest / behave
env: env:
PYTEST_PASSWORD: root PYTEST_PASSWORD: root
PYTEST_HOST: 127.0.0.1
run: | run: |
./setup.py test --pytest-args="--cov-report= --cov=mycli" ./setup.py test --pytest-args="--cov-report= --cov=mycli"

View file

@ -1,8 +1,8 @@
# mycli # mycli
[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli) [![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli)
[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli) [![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli)
[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python)
A command line client for MySQL that can do auto-completion and syntax highlighting. A command line client for MySQL that can do auto-completion and syntax highlighting.
@ -53,6 +53,7 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
-h, --host TEXT Host address of the database. -h, --host TEXT Host address of the database.
-P, --port INTEGER Port number to use for connection. Honors -P, --port INTEGER Port number to use for connection. Honors
$MYSQL_TCP_PORT. $MYSQL_TCP_PORT.
-u, --user TEXT User name to connect to the database. -u, --user TEXT User name to connect to the database.
-S, --socket TEXT The socket file to use for connection. -S, --socket TEXT The socket file to use for connection.
-p, --password TEXT Password to connect to the database. -p, --password TEXT Password to connect to the database.
@ -63,8 +64,11 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
--ssh-password TEXT Password to connect to ssh server. --ssh-password TEXT Password to connect to ssh server.
--ssh-key-filename TEXT Private key filename (identify file) for the --ssh-key-filename TEXT Private key filename (identify file) for the
ssh connection. ssh connection.
--ssh-config-path TEXT Path to ssh configuration. --ssh-config-path TEXT Path to ssh configuration.
--ssh-config-host TEXT Host for ssh server in ssh configurations (requires paramiko). --ssh-config-host TEXT Host to connect to ssh server reading from ssh
configuration.
--ssl-ca PATH CA file in PEM format. --ssl-ca PATH CA file in PEM format.
--ssl-capath TEXT CA directory. --ssl-capath TEXT CA directory.
--ssl-cert PATH X509 cert in PEM format. --ssl-cert PATH X509 cert in PEM format.
@ -73,33 +77,43 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
--ssl-verify-server-cert Verify server's "Common Name" in its cert --ssl-verify-server-cert Verify server's "Common Name" in its cert
against hostname used when connecting. This against hostname used when connecting. This
option is disabled by default. option is disabled by default.
-V, --version Output mycli's version. -V, --version Output mycli's version.
-v, --verbose Verbose output. -v, --verbose Verbose output.
-D, --database TEXT Database to use. -D, --database TEXT Database to use.
-d, --dsn TEXT Use DSN configured into the [alias_dsn] -d, --dsn TEXT Use DSN configured into the [alias_dsn]
section of myclirc file. section of myclirc file.
--list-dsn list of DSN configured into the [alias_dsn] --list-dsn list of DSN configured into the [alias_dsn]
section of myclirc file. section of myclirc file.
--list-ssh-config list ssh configurations in the ssh config (requires paramiko).
--list-ssh-config list ssh configurations in the ssh config
(requires paramiko).
-R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> "). -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ").
-l, --logfile FILENAME Log every query and its results to a file. -l, --logfile FILENAME Log every query and its results to a file.
--defaults-group-suffix TEXT Read MySQL config groups with the specified --defaults-group-suffix TEXT Read MySQL config groups with the specified
suffix. suffix.
--defaults-file PATH Only read MySQL options from the given file. --defaults-file PATH Only read MySQL options from the given file.
--myclirc PATH Location of myclirc file. --myclirc PATH Location of myclirc file.
--auto-vertical-output Automatically switch to vertical output mode --auto-vertical-output Automatically switch to vertical output mode
if the result is wider than the terminal if the result is wider than the terminal
width. width.
-t, --table Display batch output in table format. -t, --table Display batch output in table format.
--csv Display batch output in CSV format. --csv Display batch output in CSV format.
--warn / --no-warn Warn before running a destructive query. --warn / --no-warn Warn before running a destructive query.
--local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE. --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE.
--login-path TEXT Read this path from the login file. -g, --login-path TEXT Read this path from the login file.
-e, --execute TEXT Execute command and quit. -e, --execute TEXT Execute command and quit.
--init-command TEXT SQL statement to execute after connecting. --init-command TEXT SQL statement to execute after connecting.
--charset TEXT Character set for MySQL session. --charset TEXT Character set for MySQL session.
--password-file PATH File or FIFO path containing the password
to connect to the db if not specified otherwise
--help Show this message and exit. --help Show this message and exit.
Features Features
-------- --------

View file

@ -1,19 +1,60 @@
TBD:
====
*
1.24.1:
=======
Bug Fixes:
---------
* Restore dependency on cryptography for the interactive password prompt
1.24.0
======
Bug Fixes:
----------
* Allow `FileNotFound` exception for SSH config files.
* Fix startup error on MySQL < 5.0.22
* Check error code rather than message for Access Denied error
* Fix login with ~/.my.cnf files
Features:
---------
* Add `-g` shortcut to option `--login-path`.
* Alt-Enter dispatches the command in multi-line mode.
* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html)
Internal:
---------
* Remove unused function is_open_quote()
* Use importlib, instead of file links, to locate resources
* Test various host-port combinations in command line arguments
* Switched from Cryptography to pyaes for decrypting mylogin.cnf
1.23.2 1.23.2
=== ======
Bug Fixes: Bug Fixes:
---------- ----------
* Ensure `--port` is always an int. * Ensure `--port` is always an int.
1.23.1 1.23.1
=== ======
Bug Fixes: Bug Fixes:
---------- ----------
* Allow `--host` without `--port` to make a TCP connection. * Allow `--host` without `--port` to make a TCP connection.
1.23.0 1.23.0
=== ======
Bug Fixes:
----------
* Fix config file include logic
Features: Features:
--------- ---------

View file

@ -75,6 +75,8 @@ Contributors:
* Zach DeCook * Zach DeCook
* kevinhwang91 * kevinhwang91
* KITAGAWA Yasutaka * KITAGAWA Yasutaka
* Nicolas Palumbo
* Andy Teijelo Pérez
* bitkeen * bitkeen
* Morgan Mitchell * Morgan Mitchell
* Massimiliano Torromeo * Massimiliano Torromeo
@ -82,6 +84,7 @@ Contributors:
* xeron * xeron
* 0xflotus * 0xflotus
* Seamile * Seamile
* Jerome Provensal
Creator: Creator:
-------- --------

View file

@ -1 +1 @@
__version__ = '1.23.2' __version__ = '1.24.1'

View file

@ -1,7 +1,6 @@
from prompt_toolkit.enums import DEFAULT_BUFFER from prompt_toolkit.enums import DEFAULT_BUFFER
from prompt_toolkit.filters import Condition from prompt_toolkit.filters import Condition
from prompt_toolkit.application import get_app from prompt_toolkit.application import get_app
from .packages.parseutils import is_open_quote
from .packages import special from .packages import special

View file

@ -1,5 +1,3 @@
import io
import shutil
from copy import copy from copy import copy
from io import BytesIO, TextIOWrapper from io import BytesIO, TextIOWrapper
import logging import logging
@ -7,11 +5,16 @@ import os
from os.path import exists from os.path import exists
import struct import struct
import sys import sys
from typing import Union from typing import Union, IO
from configobj import ConfigObj, ConfigObjError from configobj import ConfigObj, ConfigObjError
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes import pyaes
from cryptography.hazmat.backends import default_backend
try:
import importlib.resources as resources
except ImportError:
# Python < 3.7
import importlib_resources as resources
try: try:
basestring basestring
@ -49,9 +52,9 @@ def read_config_file(f, list_values=True):
config = ConfigObj(f, interpolation=False, encoding='utf8', config = ConfigObj(f, interpolation=False, encoding='utf8',
list_values=list_values) list_values=list_values)
except ConfigObjError as e: except ConfigObjError as e:
log(logger, logging.ERROR, "Unable to parse line {0} of config file " log(logger, logging.WARNING, "Unable to parse line {0} of config file "
"'{1}'.".format(e.line_number, f)) "'{1}'.".format(e.line_number, f))
log(logger, logging.ERROR, "Using successfully parsed config values.") log(logger, logging.WARNING, "Using successfully parsed config values.")
return e.config return e.config
except (IOError, OSError) as e: except (IOError, OSError) as e:
log(logger, logging.WARNING, "You don't have permission to read " log(logger, logging.WARNING, "You don't have permission to read "
@ -61,7 +64,7 @@ def read_config_file(f, list_values=True):
return config return config
def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list: def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
"""Get a list of configuration files that are included into config_path """Get a list of configuration files that are included into config_path
with !includedir directive. with !includedir directive.
@ -95,7 +98,7 @@ def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
def read_config_files(files, list_values=True): def read_config_files(files, list_values=True):
"""Read and merge a list of config files.""" """Read and merge a list of config files."""
config = ConfigObj(list_values=list_values) config = create_default_config(list_values=list_values)
_files = copy(files) _files = copy(files)
while _files: while _files:
_file = _files.pop(0) _file = _files.pop(0)
@ -112,12 +115,21 @@ def read_config_files(files, list_values=True):
return config return config
def write_default_config(source, destination, overwrite=False): def create_default_config(list_values=True):
import mycli
default_config_file = resources.open_text(mycli, 'myclirc')
return read_config_file(default_config_file, list_values=list_values)
def write_default_config(destination, overwrite=False):
import mycli
default_config = resources.read_text(mycli, 'myclirc')
destination = os.path.expanduser(destination) destination = os.path.expanduser(destination)
if not overwrite and exists(destination): if not overwrite and exists(destination):
return return
shutil.copyfile(source, destination) with open(destination, 'w') as f:
f.write(default_config)
def get_mylogin_cnf_path(): def get_mylogin_cnf_path():
@ -160,6 +172,58 @@ def open_mylogin_cnf(name):
return TextIOWrapper(plaintext) return TextIOWrapper(plaintext)
# TODO reuse code between encryption an decryption
def encrypt_mylogin_cnf(plaintext: IO[str]):
"""Encryption of .mylogin.cnf file, analogous to calling
mysql_config_editor.
Code is based on the python implementation by Kristian Koehntopp
https://github.com/isotopp/mysql-config-coder
"""
def realkey(key):
"""Create the AES key from the login key."""
rkey = bytearray(16)
for i in range(len(key)):
rkey[i % 16] ^= key[i]
return bytes(rkey)
def encode_line(plaintext, real_key, buf_len):
aes = pyaes.AESModeOfOperationECB(real_key)
text_len = len(plaintext)
pad_len = buf_len - text_len
pad_chr = bytes(chr(pad_len), "utf8")
plaintext = plaintext.encode() + pad_chr * pad_len
encrypted_text = b''.join(
[aes.encrypt(plaintext[i: i + 16])
for i in range(0, len(plaintext), 16)]
)
return encrypted_text
LOGIN_KEY_LENGTH = 20
key = os.urandom(LOGIN_KEY_LENGTH)
real_key = realkey(key)
outfile = BytesIO()
outfile.write(struct.pack("i", 0))
outfile.write(key)
while True:
line = plaintext.readline()
if not line:
break
real_len = len(line)
pad_len = (int(real_len / 16) + 1) * 16
outfile.write(struct.pack("i", pad_len))
x = encode_line(line, real_key, pad_len)
outfile.write(x)
outfile.seek(0)
return outfile
def read_and_decrypt_mylogin_cnf(f): def read_and_decrypt_mylogin_cnf(f):
"""Read and decrypt the contents of .mylogin.cnf. """Read and decrypt the contents of .mylogin.cnf.
@ -201,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f):
return None return None
rkey = struct.pack('16B', *rkey) rkey = struct.pack('16B', *rkey)
# Create a decryptor object using the key.
decryptor = _get_decryptor(rkey)
# Create a bytes buffer to hold the plaintext. # Create a bytes buffer to hold the plaintext.
plaintext = BytesIO() plaintext = BytesIO()
aes = pyaes.AESModeOfOperationECB(rkey)
while True: while True:
# Read the length of the ciphertext. # Read the length of the ciphertext.
@ -216,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f):
# Read cipher_len bytes from the file and decrypt. # Read cipher_len bytes from the file and decrypt.
cipher = f.read(cipher_len) cipher = f.read(cipher_len)
plain = _remove_pad(decryptor.update(cipher)) plain = _remove_pad(
b''.join([aes.decrypt(cipher[i: i + 16])
for i in range(0, cipher_len, 16)])
)
if plain is False: if plain is False:
continue continue
plaintext.write(plain) plaintext.write(plain)
@ -244,7 +309,7 @@ def str_to_bool(s):
elif s.lower() in false_values: elif s.lower() in false_values:
return False return False
else: else:
raise ValueError('not a recognized boolean value: %s'.format(s)) raise ValueError('not a recognized boolean value: {0}'.format(s))
def strip_matching_quotes(s): def strip_matching_quotes(s):
@ -260,15 +325,8 @@ def strip_matching_quotes(s):
return s return s
def _get_decryptor(key):
"""Get the AES decryptor."""
c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
return c.decryptor()
def _remove_pad(line): def _remove_pad(line):
"""Remove the pad from the *line*.""" """Remove the pad from the *line*."""
pad_length = ord(line[-1:])
try: try:
# Determine pad length. # Determine pad length.
pad_length = ord(line[-1:]) pad_length = ord(line[-1:])

View file

@ -78,8 +78,12 @@ def mycli_bindings(mycli):
@kb.add('escape', 'enter') @kb.add('escape', 'enter')
def _(event): def _(event):
"""Introduces a line break regardless of multi-line mode or not.""" """Introduces a line break in multi-line mode, or dispatches the
command in single-line mode."""
_logger.debug('Detected alt-enter key.') _logger.debug('Detected alt-enter key.')
event.app.current_buffer.insert_text('\n') if mycli.multi_line:
event.app.current_buffer.validate_and_handle()
else:
event.app.current_buffer.insert_text('\n')
return kb return kb

View file

@ -1,9 +1,12 @@
from collections import defaultdict
from io import open
import os import os
import sys import sys
import traceback import traceback
import logging import logging
import threading import threading
import re import re
import stat
import fileinput import fileinput
from collections import namedtuple from collections import namedtuple
try: try:
@ -13,7 +16,6 @@ except ImportError:
from time import time from time import time
from datetime import datetime from datetime import datetime
from random import choice from random import choice
from io import open
from pymysql import OperationalError from pymysql import OperationalError
from cli_helpers.tabular_output import TabularOutputFormatter from cli_helpers.tabular_output import TabularOutputFormatter
@ -43,7 +45,7 @@ from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func from .clitoolbar import create_toolbar_tokens_func
from .clistyle import style_factory, style_factory_output from .clistyle import style_factory, style_factory_output
from .sqlexecute import FIELD_TYPES, SQLExecute from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
from .clibuffer import cli_is_multiline from .clibuffer import cli_is_multiline
from .completion_refresher import CompletionRefresher from .completion_refresher import CompletionRefresher
from .config import (write_default_config, get_mylogin_cnf_path, from .config import (write_default_config, get_mylogin_cnf_path,
@ -51,7 +53,7 @@ from .config import (write_default_config, get_mylogin_cnf_path,
strip_matching_quotes) strip_matching_quotes)
from .key_bindings import mycli_bindings from .key_bindings import mycli_bindings
from .lexer import MyCliLexer from .lexer import MyCliLexer
from .__init__ import __version__ from . import __version__
from .compat import WIN from .compat import WIN
from .packages.filepaths import dir_path_exists, guess_socket_location from .packages.filepaths import dir_path_exists, guess_socket_location
@ -66,6 +68,11 @@ except ImportError:
from urllib.parse import urlparse from urllib.parse import urlparse
from urllib.parse import unquote from urllib.parse import unquote
try:
import importlib.resources as resources
except ImportError:
# Python < 3.7
import importlib_resources as resources
try: try:
import paramiko import paramiko
@ -75,7 +82,10 @@ except ImportError:
# Query tuples are used for maintaining history # Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating']) Query = namedtuple('Query', ['query', 'successful', 'mutating'])
PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) SUPPORT_INFO = (
'Home: http://mycli.net\n'
'Bug tracker: https://github.com/dbcli/mycli/issues'
)
class MyCli(object): class MyCli(object):
@ -89,7 +99,7 @@ class MyCli(object):
'/etc/my.cnf', '/etc/my.cnf',
'/etc/mysql/my.cnf', '/etc/mysql/my.cnf',
'/usr/local/etc/my.cnf', '/usr/local/etc/my.cnf',
'~/.my.cnf' os.path.expanduser('~/.my.cnf'),
] ]
# check XDG_CONFIG_HOME exists and not an empty string # check XDG_CONFIG_HOME exists and not an empty string
@ -102,7 +112,6 @@ class MyCli(object):
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
] ]
default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
pwd_config_file = os.path.join(os.getcwd(), ".myclirc") pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
def __init__(self, sqlexecute=None, prompt=None, def __init__(self, sqlexecute=None, prompt=None,
@ -122,7 +131,7 @@ class MyCli(object):
self.cnf_files = [defaults_file] self.cnf_files = [defaults_file]
# Load config. # Load config.
config_files = ([self.default_config_file] + self.system_config_files + config_files = (self.system_config_files +
[myclirc] + [self.pwd_config_file]) [myclirc] + [self.pwd_config_file])
c = self.config = read_config_files(config_files) c = self.config = read_config_files(config_files)
self.multi_line = c['main'].as_bool('multi_line') self.multi_line = c['main'].as_bool('multi_line')
@ -154,7 +163,7 @@ class MyCli(object):
# Write user config if system config wasn't the last config loaded. # Write user config if system config wasn't the last config loaded.
if c.filename not in self.system_config_files and not os.path.exists(myclirc): if c.filename not in self.system_config_files and not os.path.exists(myclirc):
write_default_config(self.default_config_file, myclirc) write_default_config(myclirc)
# audit log # audit log
if self.logfile is None and 'audit_log' in c['main']: if self.logfile is None and 'audit_log' in c['main']:
@ -326,20 +335,33 @@ class MyCli(object):
cnf = read_config_files(files, list_values=False) cnf = read_config_files(files, list_values=False)
sections = ['client', 'mysqld'] sections = ['client', 'mysqld']
key_transformations = {
'mysqld': {
'socket': 'default_socket',
'port': 'default_port',
},
}
if self.login_path and self.login_path != 'client': if self.login_path and self.login_path != 'client':
sections.append(self.login_path) sections.append(self.login_path)
if self.defaults_suffix: if self.defaults_suffix:
sections.extend([sect + self.defaults_suffix for sect in sections]) sections.extend([sect + self.defaults_suffix for sect in sections])
def get(key): configuration = defaultdict(lambda: None)
result = None for key in keys:
for sect in cnf: for section in cnf:
if sect in sections and key in cnf[sect]: if (
result = strip_matching_quotes(cnf[sect][key]) section not in sections or
return result key not in cnf[section]
):
continue
new_key = key_transformations.get(section, {}).get(key) or key
configuration[new_key] = strip_matching_quotes(
cnf[section][key])
return configuration
return {x: get(x) for x in keys}
def merge_ssl_with_cnf(self, ssl, cnf): def merge_ssl_with_cnf(self, ssl, cnf):
"""Merge SSL configuration dict with cnf dict""" """Merge SSL configuration dict with cnf dict"""
@ -367,7 +389,7 @@ class MyCli(object):
def connect(self, database='', user='', passwd='', host='', port='', def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='', socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='', ssh_user='', ssh_host='', ssh_port='',
ssh_password='', ssh_key_filename='', init_command=''): ssh_password='', ssh_key_filename='', init_command='', password_file=''):
cnf = {'database': None, cnf = {'database': None,
'user': None, 'user': None,
@ -375,6 +397,7 @@ class MyCli(object):
'host': None, 'host': None,
'port': None, 'port': None,
'socket': None, 'socket': None,
'default_socket': None,
'default-character-set': None, 'default-character-set': None,
'local-infile': None, 'local-infile': None,
'loose-local-infile': None, 'loose-local-infile': None,
@ -388,18 +411,23 @@ class MyCli(object):
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
# Fall back to config values only if user did not specify a value. # Fall back to config values only if user did not specify a value.
database = database or cnf['database'] database = database or cnf['database']
# Socket interface not supported for SSH connections
if port or (host and host != 'localhost') or (ssh_host and ssh_port):
socket = ''
else:
socket = socket or cnf['socket'] or guess_socket_location()
user = user or cnf['user'] or os.getenv('USER') user = user or cnf['user'] or os.getenv('USER')
host = host or cnf['host'] host = host or cnf['host']
port = int(port or cnf['port'] or 3306) port = port or cnf['port']
ssl = ssl or {} ssl = ssl or {}
port = port and int(port)
if not port:
port = 3306
if not host or host == 'localhost':
socket = (
cnf['socket'] or
cnf['default_socket'] or
guess_socket_location()
)
passwd = passwd if isinstance(passwd, str) else cnf['password'] passwd = passwd if isinstance(passwd, str) else cnf['password']
charset = charset or cnf['default-character-set'] or 'utf8' charset = charset or cnf['default-character-set'] or 'utf8'
@ -417,6 +445,10 @@ class MyCli(object):
if not any(v for v in ssl.values()): if not any(v for v in ssl.values()):
ssl = None ssl = None
# if the passwd is not specfied try to set it using the password_file option
password_from_file = self.get_password_from_file(password_file)
passwd = passwd or password_from_file
# Connect to the database. # Connect to the database.
def _connect(): def _connect():
@ -427,9 +459,12 @@ class MyCli(object):
ssh_password, ssh_key_filename, init_command ssh_password, ssh_key_filename, init_command
) )
except OperationalError as e: except OperationalError as e:
if ('Access denied for user' in e.args[1]): if e.args[0] == ERROR_CODE_ACCESS_DENIED:
new_passwd = click.prompt('Password', hide_input=True, if password_from_file:
show_default=False, type=str, err=True) new_passwd = password_from_file
else:
new_passwd = click.prompt('Password', hide_input=True,
show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute( self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket, database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host, charset, local_infile, ssl, ssh_user, ssh_host,
@ -484,6 +519,17 @@ class MyCli(object):
self.echo(str(e), err=True, fg='red') self.echo(str(e), err=True, fg='red')
exit(1) exit(1)
def get_password_from_file(self, password_file):
password_from_file = None
if password_file:
if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
and os.access(password_file, os.R_OK):
with open(password_file) as fp:
password_from_file = fp.readline()
password_from_file = password_from_file.rstrip().lstrip()
return password_from_file
def handle_editor_command(self, text): def handle_editor_command(self, text):
r"""Editor command is any query that is prefixed or suffixed by a '\e'. r"""Editor command is any query that is prefixed or suffixed by a '\e'.
The reason for a while loop is because a user might edit a query The reason for a while loop is because a user might edit a query
@ -542,9 +588,6 @@ class MyCli(object):
if self.smart_completion: if self.smart_completion:
self.refresh_completions() self.refresh_completions()
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
history_file = os.path.expanduser( history_file = os.path.expanduser(
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
if dir_path_exists(history_file): if dir_path_exists(history_file):
@ -559,12 +602,10 @@ class MyCli(object):
key_bindings = mycli_bindings(self) key_bindings = mycli_bindings(self)
if not self.less_chatty: if not self.less_chatty:
print(' '.join(sqlexecute.server_type())) print(sqlexecute.server_info)
print('mycli', __version__) print('mycli', __version__)
print('Chat: https://gitter.im/dbcli/mycli') print(SUPPORT_INFO)
print('Mail: https://groups.google.com/forum/#!forum/mycli-users') print('Thanks to the contributor -', thanks_picker())
print('Home: http://mycli.net')
print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
def get_message(): def get_message():
prompt = self.get_prompt(self.prompt_format) prompt = self.get_prompt(self.prompt_format)
@ -862,8 +903,8 @@ class MyCli(object):
if not output_via_pager: if not output_via_pager:
# doesn't fit, flush buffer # doesn't fit, flush buffer
for line in buf: for buf_line in buf:
click.secho(line) click.secho(buf_line)
buf = [] buf = []
else: else:
click.secho(line) click.secho(line)
@ -933,7 +974,7 @@ class MyCli(object):
string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\u', sqlexecute.user or '(none)')
string = string.replace('\\h', host or '(none)') string = string.replace('\\h', host or '(none)')
string = string.replace('\\d', sqlexecute.dbname or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)')
string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli') string = string.replace('\\t', sqlexecute.server_info.species.name)
string = string.replace('\\n', "\n") string = string.replace('\\n', "\n")
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
string = string.replace('\\m', now.strftime('%M')) string = string.replace('\\m', now.strftime('%M'))
@ -1083,7 +1124,7 @@ class MyCli(object):
help='Warn before running a destructive query.') help='Warn before running a destructive query.')
@click.option('--local-infile', type=bool, @click.option('--local-infile', type=bool,
help='Enable/disable LOAD DATA LOCAL INFILE.') help='Enable/disable LOAD DATA LOCAL INFILE.')
@click.option('--login-path', type=str, @click.option('-g', '--login-path', type=str,
help='Read this path from the login file.') help='Read this path from the login file.')
@click.option('-e', '--execute', type=str, @click.option('-e', '--execute', type=str,
help='Execute command and quit.') help='Execute command and quit.')
@ -1091,6 +1132,8 @@ class MyCli(object):
help='SQL statement to execute after connecting.') help='SQL statement to execute after connecting.')
@click.option('--charset', type=str, @click.option('--charset', type=str,
help='Character set for MySQL session.') help='Character set for MySQL session.')
@click.option('--password-file', type=click.Path(),
help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
@click.argument('database', default='', nargs=1) @click.argument('database', default='', nargs=1)
def cli(database, user, host, port, socket, password, dbname, def cli(database, user, host, port, socket, password, dbname,
version, verbose, prompt, logfile, defaults_group_suffix, version, verbose, prompt, logfile, defaults_group_suffix,
@ -1099,7 +1142,7 @@ def cli(database, user, host, port, socket, password, dbname,
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
init_command, charset): init_command, charset, password_file):
"""A MySQL terminal client with auto-completion and syntax highlighting. """A MySQL terminal client with auto-completion and syntax highlighting.
\b \b
@ -1225,7 +1268,8 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_password=ssh_password, ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename, ssh_key_filename=ssh_key_filename,
init_command=init_command, init_command=init_command,
charset=charset charset=charset,
password_file=password_file
) )
mycli.logger.debug('Launch Params: \n' mycli.logger.debug('Launch Params: \n'
@ -1328,9 +1372,15 @@ def is_select(status):
return status.split(None, 1)[0].lower() == 'select' return status.split(None, 1)[0].lower() == 'select'
def thanks_picker(files=()): def thanks_picker():
import mycli
lines = (
resources.read_text(mycli, 'AUTHORS') +
resources.read_text(mycli, 'SPONSORS')
).split('\n')
contents = [] contents = []
for line in fileinput.input(files=files): for line in lines:
m = re.match(r'^ *\* (.*)', line) m = re.match(r'^ *\* (.*)', line)
if m: if m:
contents.append(m.group(1)) contents.append(m.group(1))
@ -1350,6 +1400,9 @@ def read_ssh_config(ssh_config_path):
try: try:
with open(ssh_config_path) as f: with open(ssh_config_path) as f:
ssh_config.parse(f) ssh_config.parse(f)
except FileNotFoundError as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)
# Paramiko prior to version 2.7 raises Exception on parse errors. # Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException, # In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility # but let's catch everything for compatibility
@ -1359,9 +1412,6 @@ def read_ssh_config(ssh_config_path):
err=True, fg='red' err=True, fg='red'
) )
sys.exit(1) sys.exit(1)
except FileNotFoundError as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)
else: else:
return ssh_config return ssh_config

View file

@ -1,5 +1,3 @@
import os
import sys
import sqlparse import sqlparse
from sqlparse.sql import Comparison, Identifier, Where from sqlparse.sql import Comparison, Identifier, Where
from .parseutils import last_word, extract_tables, find_prev_keyword from .parseutils import last_word, extract_tables, find_prev_keyword

View file

@ -12,7 +12,8 @@ cleanup_regex = {
'most_punctuations': re.compile(r'([^\.():,\s]+)$'), 'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
# This matches everything except a space. # This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'), 'all_punctuations': re.compile(r'([^\s]+)$'),
} }
def last_word(text, include='alphanum_underscore'): def last_word(text, include='alphanum_underscore'):
r""" r"""
@ -226,14 +227,6 @@ def is_destructive(queries):
return False return False
def is_open_quote(sql):
"""Returns true if the query contains an unclosed quote."""
# parsed can contain one or more semi-colon separated commands
parsed = sqlparse.parse(sql)
return any(_parsed_is_open_quote(p) for p in parsed)
if __name__ == '__main__': if __name__ == '__main__':
sql = 'select * from (select t. from tabl t' sql = 'select * from (select t. from tabl t'
print (extract_tables(sql)) print (extract_tables(sql))
@ -263,5 +256,4 @@ def is_dropping_database(queries, dbname):
) )
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
result = keywords[0].normalized == "DROP" result = keywords[0].normalized == "DROP"
else: return result
return result

View file

@ -302,7 +302,7 @@ def execute_system_command(arg, **_):
usage = "Syntax: system [command].\n" usage = "Syntax: system [command].\n"
if not arg: if not arg:
return [(None, None, None, usage)] return [(None, None, None, usage)]
try: try:
command = arg.strip() command = arg.strip()

View file

@ -1,6 +1,5 @@
"""Format adapter for sql.""" """Format adapter for sql."""
from cli_helpers.utils import filter_dict_by_key
from mycli.packages.parseutils import extract_tables from mycli.packages.parseutils import extract_tables
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',

View file

@ -72,7 +72,7 @@ class SQLCompleter(Completer):
if name and ((not self.name_pattern.match(name)) if name and ((not self.name_pattern.match(name))
or (name.upper() in self.reserved_words) or (name.upper() in self.reserved_words)
or (name.upper() in self.functions)): or (name.upper() in self.functions)):
name = '`%s`' % name name = '`%s`' % name
return name return name

View file

@ -1,6 +1,8 @@
import enum
import logging import logging
import re
import pymysql import pymysql
import sqlparse
from .packages import special from .packages import special
from pymysql.constants import FIELD_TYPE from pymysql.constants import FIELD_TYPE
from pymysql.converters import (convert_datetime, from pymysql.converters import (convert_datetime,
@ -18,17 +20,71 @@ FIELD_TYPES.update({
FIELD_TYPE.NULL: type(None) FIELD_TYPE.NULL: type(None)
}) })
ERROR_CODE_ACCESS_DENIED = 1045
class ServerSpecies(enum.Enum):
MySQL = 'MySQL'
MariaDB = 'MariaDB'
Percona = 'Percona'
Unknown = 'MySQL'
class ServerInfo:
def __init__(self, species, version_str):
self.species = species
self.version_str = version_str
self.version = self.calc_mysql_version_value(version_str)
@staticmethod
def calc_mysql_version_value(version_str) -> int:
if not version_str or not isinstance(version_str, str):
return 0
try:
major, minor, patch = version_str.split('.')
except ValueError:
return 0
else:
return int(major) * 10_000 + int(minor) * 100 + int(patch)
@classmethod
def from_version_string(cls, version_string):
if not version_string:
return cls(ServerSpecies.Unknown, '')
re_species = (
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
ServerSpecies.Percona),
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
ServerSpecies.MySQL),
)
for regexp, species in re_species:
match = re.search(regexp, version_string)
if match is not None:
parsed_version = match.group('version')
detected_species = species
break
else:
detected_species = ServerSpecies.Unknown
parsed_version = ''
return cls(detected_species, parsed_version)
def __str__(self):
if self.species:
return f'{self.species.value} {self.version_str}'
else:
return self.version_str
class SQLExecute(object): class SQLExecute(object):
databases_query = '''SHOW DATABASES''' databases_query = '''SHOW DATABASES'''
tables_query = '''SHOW TABLES''' tables_query = '''SHOW TABLES'''
version_query = '''SELECT @@VERSION'''
version_comment_query = '''SELECT @@VERSION_COMMENT'''
version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
@ -52,7 +108,7 @@ class SQLExecute(object):
self.charset = charset self.charset = charset
self.local_infile = local_infile self.local_infile = local_infile
self.ssl = ssl self.ssl = ssl
self._server_type = None self.server_info = None
self.connection_id = None self.connection_id = None
self.ssh_user = ssh_user self.ssh_user = ssh_user
self.ssh_host = ssh_host self.ssh_host = ssh_host
@ -157,6 +213,7 @@ class SQLExecute(object):
self.init_command = init_command self.init_command = init_command
# retrieve connection id # retrieve connection id
self.reset_connection_id() self.reset_connection_id()
self.server_info = ServerInfo.from_version_string(conn.server_version)
def run(self, statement): def run(self, statement):
"""Execute the sql in the database and return the results. The results """Execute the sql in the database and return the results. The results
@ -273,37 +330,6 @@ class SQLExecute(object):
for row in cur: for row in cur:
yield row yield row
def server_type(self):
if self._server_type:
return self._server_type
with self.conn.cursor() as cur:
_logger.debug('Version Query. sql: %r', self.version_query)
cur.execute(self.version_query)
version = cur.fetchone()[0]
if version[0] == '4':
_logger.debug('Version Comment. sql: %r',
self.version_comment_query_mysql4)
cur.execute(self.version_comment_query_mysql4)
version_comment = cur.fetchone()[1].lower()
if isinstance(version_comment, bytes):
# with python3 this query returns bytes
version_comment = version_comment.decode('utf-8')
else:
_logger.debug('Version Comment. sql: %r',
self.version_comment_query)
cur.execute(self.version_comment_query)
version_comment = cur.fetchone()[0].lower()
if 'mariadb' in version_comment:
product_type = 'mariadb'
elif 'percona' in version_comment:
product_type = 'percona'
else:
product_type = 'mysql'
self._server_type = (product_type, version)
return self._server_type
def get_connection_id(self): def get_connection_id(self):
if not self.connection_id: if not self.connection_id:
self.reset_connection_id() self.reset_connection_id()

2
pytest.ini Normal file
View file

@ -0,0 +1,2 @@
[pytest]
addopts = --ignore=mycli/packages/paramiko_stub/__init__.py

View file

@ -1,6 +1,5 @@
"""A script to publish a release of mycli to PyPI.""" """A script to publish a release of mycli to PyPI."""
import io
from optparse import OptionParser from optparse import OptionParser
import re import re
import subprocess import subprocess

View file

@ -18,16 +18,20 @@ description = 'CLI for MySQL Database. With auto-completion and syntax highlight
install_requirements = [ install_requirements = [
'click >= 7.0', 'click >= 7.0',
'cryptography >= 1.0.0',
'Pygments >= 1.6', 'Pygments >= 1.6',
'prompt_toolkit>=3.0.6,<4.0.0', 'prompt_toolkit>=3.0.6,<4.0.0',
'PyMySQL >= 0.9.2', 'PyMySQL >= 0.9.2',
'sqlparse>=0.3.0,<0.4.0', 'sqlparse>=0.3.0,<0.4.0',
'configobj >= 5.0.5', 'configobj >= 5.0.5',
'cryptography >= 1.0.0',
'cli_helpers[styles] >= 2.0.1', 'cli_helpers[styles] >= 2.0.1',
'pyperclip >= 1.8.1' 'pyperclip >= 1.8.1',
'pyaes >= 1.6.1'
] ]
if sys.version_info.minor < 9:
install_requirements.append('importlib_resources >= 5.0.0')
class lint(Command): class lint(Command):
description = 'check code against PEP 8 (and fix violations)' description = 'check code against PEP 8 (and fix violations)'

View file

@ -0,0 +1,35 @@
Feature: connect to a database:
@requires_local_db
Scenario: run mycli on localhost without port
When we run mycli with arguments "host=localhost" without arguments "port"
When we query "status"
Then status contains "via UNIX socket"
Scenario: run mycli on TCP host without port
When we run mycli without arguments "port"
When we query "status"
Then status contains "via TCP/IP"
Scenario: run mycli with port but without host
When we run mycli without arguments "host"
When we query "status"
Then status contains "via TCP/IP"
@requires_local_db
Scenario: run mycli without host and port
When we run mycli without arguments "host port"
When we query "status"
Then status contains "via UNIX socket"
Scenario: run mycli with my.cnf configuration
When we create my.cnf file
When we run mycli without arguments "host port user pass defaults_file"
Then we are logged in
Scenario: run mycli with mylogin.cnf configuration
When we create mylogin.cnf file
When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
Then we are logged in

View file

@ -1,4 +1,5 @@
import os import os
import shutil
import sys import sys
from tempfile import mkstemp from tempfile import mkstemp
@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
SELF_CONNECTING_FEATURES = (
'test/features/connection.feature',
)
MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
def get_db_name_from_context(context):
return context.config.userdata.get(
'my_test_db', None
) or "mycli_behave_tests"
def before_all(context): def before_all(context):
"""Set env parameters.""" """Set env parameters."""
os.environ['LINES'] = "100" os.environ['LINES'] = "100"
@ -22,7 +41,7 @@ def before_all(context):
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
login_path_file = os.path.join(test_dir, 'mylogin.cnf') login_path_file = os.path.join(test_dir, 'mylogin.cnf')
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file # os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
context.package_root = os.path.abspath( context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
@ -33,8 +52,7 @@ def before_all(context):
context.exit_sent = False context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]]) vi = '_'.join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get( db_name = get_db_name_from_context(context)
'my_test_db', None) or "mycli_behave_tests"
db_name_full = '{0}_{1}'.format(db_name, vi) db_name_full = '{0}_{1}'.format(db_name, vi)
# Store get params from config/environment variables # Store get params from config/environment variables
@ -104,11 +122,18 @@ def before_step(context, _):
context.atprompt = False context.atprompt = False
def before_scenario(context, _): def before_scenario(context, arg):
with open(test_log_file, 'w') as f: with open(test_log_file, 'w') as f:
f.write('') f.write('')
run_cli(context) if arg.location.filename not in SELF_CONNECTING_FEATURES:
wait_prompt(context) run_cli(context)
wait_prompt(context)
if os.path.exists(MY_CNF_PATH):
shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
if os.path.exists(MYLOGIN_CNF_PATH):
shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
def after_scenario(context, _): def after_scenario(context, _):
@ -134,6 +159,17 @@ def after_scenario(context, _):
context.cli.sendcontrol('d') context.cli.sendcontrol('d')
context.cli.expect_exact(pexpect.EOF, timeout=5) context.cli.expect_exact(pexpect.EOF, timeout=5)
if os.path.exists(MY_CNF_BACKUP_PATH):
shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
elif os.path.exists(MYLOGIN_CNF_PATH):
# This file was moved in `before_scenario`.
# If it exists now, it has been created during a test
os.remove(MYLOGIN_CNF_PATH)
# TODO: uncomment to debug a failure # TODO: uncomment to debug a failure
# def after_step(context, step): # def after_step(context, step):
# if step.status == "failed": # if step.status == "failed":

View file

@ -3,11 +3,12 @@ from textwrap import dedent
from behave import then, when from behave import then, when
import wrappers import wrappers
from utils import parse_cli_args_to_dict
@when('we run dbcli with {arg}') @when('we run dbcli with {arg}')
def step_run_cli_with_arg(context, arg): def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=arg.split('=')) wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
@when('we execute a small query') @when('we execute a small query')

View file

@ -0,0 +1,71 @@
import io
import os
import shlex
from behave import when, then
import pexpect
import wrappers
from test.features.steps.utils import parse_cli_args_to_dict
from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
from test.utils import HOST, PORT, USER, PASSWORD
from mycli.config import encrypt_mylogin_cnf
TEST_LOGIN_PATH = 'test_login_path'
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
@when('we run mycli without arguments "{excluded_args}"')
def step_run_cli_without_args(context, excluded_args, exact_args=''):
wrappers.run_cli(
context,
run_args=parse_cli_args_to_dict(exact_args),
exclude_args=parse_cli_args_to_dict(excluded_args).keys()
)
@then('status contains "{expression}"')
def status_contains(context, expression):
wrappers.expect_exact(context, f'{expression}', timeout=5)
# Normally, the shutdown after scenario waits for the prompt.
# But we may have changed the prompt, depending on parameters,
# so let's wait for its last character
context.cli.expect_exact('>')
context.atprompt = True
@when('we create my.cnf file')
def step_create_my_cnf_file(context):
my_cnf = (
'[client]\n'
f'host = {HOST}\n'
f'port = {PORT}\n'
f'user = {USER}\n'
f'password = {PASSWORD}\n'
)
with open(MY_CNF_PATH, 'w') as f:
f.write(my_cnf)
@when('we create mylogin.cnf file')
def step_create_mylogin_cnf_file(context):
os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
mylogin_cnf = (
f'[{TEST_LOGIN_PATH}]\n'
f'host = {HOST}\n'
f'port = {PORT}\n'
f'user = {USER}\n'
f'password = {PASSWORD}\n'
)
with open(MYLOGIN_CNF_PATH, 'wb') as f:
input_file = io.StringIO(mylogin_cnf)
f.write(encrypt_mylogin_cnf(input_file).read())
@then('we are logged in')
def we_are_logged_in(context):
db_name = get_db_name_from_context(context)
context.cli.expect_exact(f'{db_name}>', timeout=5)
context.atprompt = True

View file

@ -0,0 +1,12 @@
import shlex
def parse_cli_args_to_dict(cli_args: str):
args_dict = {}
for arg in shlex.split(cli_args):
if '=' in arg:
key, value = arg.split('=')
args_dict[key] = value
else:
args_dict[arg] = None
return args_dict

View file

@ -3,6 +3,7 @@ import pexpect
import sys import sys
import textwrap import textwrap
try: try:
from StringIO import StringIO from StringIO import StringIO
except ImportError: except ImportError:
@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout):
timedout = False timedout = False
try: try:
context.cli.expect_exact(expected, timeout=timeout) context.cli.expect_exact(expected, timeout=timeout)
except pexpect.exceptions.TIMEOUT: except pexpect.TIMEOUT:
timedout = True timedout = True
if timedout: if timedout:
# Strip color codes out of the output. # Strip color codes out of the output.
@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout):
context.conf['pager_boundary'], expected), timeout=timeout) context.conf['pager_boundary'], expected), timeout=timeout)
def run_cli(context, run_args=None): def run_cli(context, run_args=None, exclude_args=None):
"""Run the process using pexpect.""" """Run the process using pexpect."""
run_args = run_args or [] run_args = run_args or {}
if context.conf.get('host', None): rendered_args = []
run_args.extend(('-h', context.conf['host'])) exclude_args = set(exclude_args) if exclude_args else set()
if context.conf.get('user', None):
run_args.extend(('-u', context.conf['user'])) conf = dict(**context.conf)
if context.conf.get('pass', None): conf.update(run_args)
run_args.extend(('-p', context.conf['pass']))
if context.conf.get('dbname', None): def add_arg(name, key, value):
run_args.extend(('-D', context.conf['dbname'])) if name not in exclude_args:
if context.conf.get('defaults-file', None): if value is not None:
run_args.extend(('--defaults-file', context.conf['defaults-file'])) rendered_args.extend((key, value))
if context.conf.get('myclirc', None): else:
run_args.extend(('--myclirc', context.conf['myclirc'])) rendered_args.append(key)
if conf.get('host', None):
add_arg('host', '-h', conf['host'])
if conf.get('user', None):
add_arg('user', '-u', conf['user'])
if conf.get('pass', None):
add_arg('pass', '-p', conf['pass'])
if conf.get('port', None):
add_arg('port', '-P', str(conf['port']))
if conf.get('dbname', None):
add_arg('dbname', '-D', conf['dbname'])
if conf.get('defaults-file', None):
add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
if conf.get('myclirc', None):
add_arg('myclirc', '--myclirc', conf['myclirc'])
if conf.get('login_path'):
add_arg('login_path', '--login-path', conf['login_path'])
for arg_name, arg_value in conf.items():
if arg_name.startswith('-'):
add_arg(arg_name, arg_name, arg_value)
try: try:
cli_cmd = context.conf['cli_command'] cli_cmd = context.conf['cli_command']
except KeyError: except KeyError:
@ -73,7 +96,7 @@ def run_cli(context, run_args=None):
'"' '"'
).format(sys.executable) ).format(sys.executable)
cmd_parts = [cli_cmd] + run_args cmd_parts = [cli_cmd] + rendered_args
cmd = ' '.join(cmd_parts) cmd = ' '.join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO() context.logfile = StringIO()

View file

@ -3,8 +3,9 @@ import os
import click import click
from click.testing import CliRunner from click.testing import CliRunner
from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT from mycli.main import MyCli, cli, thanks_picker
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.sqlexecute import ServerInfo
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
from textwrap import dedent from textwrap import dedent
@ -140,10 +141,7 @@ def test_batch_mode_csv(executor):
def test_thanks_picker_utf8(): def test_thanks_picker_utf8():
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') name = thanks_picker()
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
name = thanks_picker((author_file, sponsor_file))
assert name and isinstance(name, str) assert name and isinstance(name, str)
@ -177,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
host = 'test' host = 'test'
user = 'test' user = 'test'
dbname = 'test' dbname = 'test'
server_info = ServerInfo.from_version_string('unknown')
port = 0 port = 0
def server_type(self): def server_type(self):

View file

@ -3,6 +3,7 @@ import os
import pytest import pytest
import pymysql import pymysql
from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output from .utils import run, dbtest, set_expanded_output, is_expanded_output
@ -270,3 +271,24 @@ def test_multiple_results(executor):
'status': '1 row in set'} 'status': '1 row in set'}
] ]
assert results == expected assert results == expected
@pytest.mark.parametrize(
'version_string, species, parsed_version_string, version',
(
('5.7.32-35', 'Percona', '5.7.32', 50732),
('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
('unexpected version string', None, '', 0),
('', None, '', 0),
(None, None, '', 0),
)
)
def test_version_parsing(version_string, species, parsed_version_string, version):
server_info = ServerInfo.from_version_string(version_string)
assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
assert server_info.version_str == parsed_version_string
assert server_info.version == version