Merging upstream version 1.24.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
570aa52ec2
commit
06dd2aeb28
26 changed files with 565 additions and 169 deletions
14
.github/workflows/ci.yml
vendored
14
.github/workflows/ci.yml
vendored
|
@ -7,13 +7,20 @@ on:
|
|||
|
||||
jobs:
|
||||
linux:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
include:
|
||||
- python-version: 3.6
|
||||
os: ubuntu-16.04 # MySQL 5.7.32
|
||||
- python-version: 3.7
|
||||
os: ubuntu-18.04 # MySQL 5.7.32
|
||||
- python-version: 3.8
|
||||
os: ubuntu-18.04 # MySQL 5.7.32
|
||||
- python-version: 3.9
|
||||
os: ubuntu-20.04 # MySQL 8.0.22
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
|
@ -42,6 +49,7 @@ jobs:
|
|||
- name: Pytest / behave
|
||||
env:
|
||||
PYTEST_PASSWORD: root
|
||||
PYTEST_HOST: 127.0.0.1
|
||||
run: |
|
||||
./setup.py test --pytest-args="--cov-report= --cov=mycli"
|
||||
|
||||
|
|
26
README.md
26
README.md
|
@ -1,8 +1,8 @@
|
|||
# mycli
|
||||
|
||||
[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli)
|
||||
[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli)
|
||||
[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli)
|
||||
[![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli)
|
||||
[![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python)
|
||||
|
||||
A command line client for MySQL that can do auto-completion and syntax highlighting.
|
||||
|
||||
|
@ -53,6 +53,7 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
|
|||
-h, --host TEXT Host address of the database.
|
||||
-P, --port INTEGER Port number to use for connection. Honors
|
||||
$MYSQL_TCP_PORT.
|
||||
|
||||
-u, --user TEXT User name to connect to the database.
|
||||
-S, --socket TEXT The socket file to use for connection.
|
||||
-p, --password TEXT Password to connect to the database.
|
||||
|
@ -63,8 +64,11 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
|
|||
--ssh-password TEXT Password to connect to ssh server.
|
||||
--ssh-key-filename TEXT Private key filename (identify file) for the
|
||||
ssh connection.
|
||||
|
||||
--ssh-config-path TEXT Path to ssh configuration.
|
||||
--ssh-config-host TEXT Host for ssh server in ssh configurations (requires paramiko).
|
||||
--ssh-config-host TEXT Host to connect to ssh server reading from ssh
|
||||
configuration.
|
||||
|
||||
--ssl-ca PATH CA file in PEM format.
|
||||
--ssl-capath TEXT CA directory.
|
||||
--ssl-cert PATH X509 cert in PEM format.
|
||||
|
@ -73,33 +77,43 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
|
|||
--ssl-verify-server-cert Verify server's "Common Name" in its cert
|
||||
against hostname used when connecting. This
|
||||
option is disabled by default.
|
||||
|
||||
-V, --version Output mycli's version.
|
||||
-v, --verbose Verbose output.
|
||||
-D, --database TEXT Database to use.
|
||||
-d, --dsn TEXT Use DSN configured into the [alias_dsn]
|
||||
section of myclirc file.
|
||||
|
||||
--list-dsn list of DSN configured into the [alias_dsn]
|
||||
section of myclirc file.
|
||||
--list-ssh-config list ssh configurations in the ssh config (requires paramiko).
|
||||
|
||||
--list-ssh-config list ssh configurations in the ssh config
|
||||
(requires paramiko).
|
||||
|
||||
-R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ").
|
||||
-l, --logfile FILENAME Log every query and its results to a file.
|
||||
--defaults-group-suffix TEXT Read MySQL config groups with the specified
|
||||
suffix.
|
||||
|
||||
--defaults-file PATH Only read MySQL options from the given file.
|
||||
--myclirc PATH Location of myclirc file.
|
||||
--auto-vertical-output Automatically switch to vertical output mode
|
||||
if the result is wider than the terminal
|
||||
width.
|
||||
|
||||
-t, --table Display batch output in table format.
|
||||
--csv Display batch output in CSV format.
|
||||
--warn / --no-warn Warn before running a destructive query.
|
||||
--local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE.
|
||||
--login-path TEXT Read this path from the login file.
|
||||
-g, --login-path TEXT Read this path from the login file.
|
||||
-e, --execute TEXT Execute command and quit.
|
||||
--init-command TEXT SQL statement to execute after connecting.
|
||||
--charset TEXT Character set for MySQL session.
|
||||
--password-file PATH File or FIFO path containing the password
|
||||
to connect to the db if not specified otherwise
|
||||
--help Show this message and exit.
|
||||
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
|
|
47
changelog.md
47
changelog.md
|
@ -1,19 +1,60 @@
|
|||
TBD:
|
||||
====
|
||||
|
||||
*
|
||||
|
||||
1.24.1:
|
||||
=======
|
||||
|
||||
Bug Fixes:
|
||||
---------
|
||||
* Restore dependency on cryptography for the interactive password prompt
|
||||
|
||||
|
||||
1.24.0
|
||||
======
|
||||
|
||||
Bug Fixes:
|
||||
----------
|
||||
* Allow `FileNotFound` exception for SSH config files.
|
||||
* Fix startup error on MySQL < 5.0.22
|
||||
* Check error code rather than message for Access Denied error
|
||||
* Fix login with ~/.my.cnf files
|
||||
|
||||
Features:
|
||||
---------
|
||||
* Add `-g` shortcut to option `--login-path`.
|
||||
* Alt-Enter dispatches the command in multi-line mode.
|
||||
* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html)
|
||||
|
||||
Internal:
|
||||
---------
|
||||
* Remove unused function is_open_quote()
|
||||
* Use importlib, instead of file links, to locate resources
|
||||
* Test various host-port combinations in command line arguments
|
||||
* Switched from Cryptography to pyaes for decrypting mylogin.cnf
|
||||
|
||||
|
||||
1.23.2
|
||||
===
|
||||
======
|
||||
|
||||
Bug Fixes:
|
||||
----------
|
||||
* Ensure `--port` is always an int.
|
||||
|
||||
1.23.1
|
||||
===
|
||||
======
|
||||
|
||||
Bug Fixes:
|
||||
----------
|
||||
* Allow `--host` without `--port` to make a TCP connection.
|
||||
|
||||
1.23.0
|
||||
===
|
||||
======
|
||||
|
||||
Bug Fixes:
|
||||
----------
|
||||
* Fix config file include logic
|
||||
|
||||
Features:
|
||||
---------
|
||||
|
|
|
@ -75,6 +75,8 @@ Contributors:
|
|||
* Zach DeCook
|
||||
* kevinhwang91
|
||||
* KITAGAWA Yasutaka
|
||||
* Nicolas Palumbo
|
||||
* Andy Teijelo Pérez
|
||||
* bitkeen
|
||||
* Morgan Mitchell
|
||||
* Massimiliano Torromeo
|
||||
|
@ -82,6 +84,7 @@ Contributors:
|
|||
* xeron
|
||||
* 0xflotus
|
||||
* Seamile
|
||||
* Jerome Provensal
|
||||
|
||||
Creator:
|
||||
--------
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '1.23.2'
|
||||
__version__ = '1.24.1'
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from prompt_toolkit.enums import DEFAULT_BUFFER
|
||||
from prompt_toolkit.filters import Condition
|
||||
from prompt_toolkit.application import get_app
|
||||
from .packages.parseutils import is_open_quote
|
||||
from .packages import special
|
||||
|
||||
|
||||
|
|
104
mycli/config.py
104
mycli/config.py
|
@ -1,5 +1,3 @@
|
|||
import io
|
||||
import shutil
|
||||
from copy import copy
|
||||
from io import BytesIO, TextIOWrapper
|
||||
import logging
|
||||
|
@ -7,11 +5,16 @@ import os
|
|||
from os.path import exists
|
||||
import struct
|
||||
import sys
|
||||
from typing import Union
|
||||
from typing import Union, IO
|
||||
|
||||
from configobj import ConfigObj, ConfigObjError
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
import pyaes
|
||||
|
||||
try:
|
||||
import importlib.resources as resources
|
||||
except ImportError:
|
||||
# Python < 3.7
|
||||
import importlib_resources as resources
|
||||
|
||||
try:
|
||||
basestring
|
||||
|
@ -49,9 +52,9 @@ def read_config_file(f, list_values=True):
|
|||
config = ConfigObj(f, interpolation=False, encoding='utf8',
|
||||
list_values=list_values)
|
||||
except ConfigObjError as e:
|
||||
log(logger, logging.ERROR, "Unable to parse line {0} of config file "
|
||||
log(logger, logging.WARNING, "Unable to parse line {0} of config file "
|
||||
"'{1}'.".format(e.line_number, f))
|
||||
log(logger, logging.ERROR, "Using successfully parsed config values.")
|
||||
log(logger, logging.WARNING, "Using successfully parsed config values.")
|
||||
return e.config
|
||||
except (IOError, OSError) as e:
|
||||
log(logger, logging.WARNING, "You don't have permission to read "
|
||||
|
@ -61,7 +64,7 @@ def read_config_file(f, list_values=True):
|
|||
return config
|
||||
|
||||
|
||||
def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
||||
def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
|
||||
"""Get a list of configuration files that are included into config_path
|
||||
with !includedir directive.
|
||||
|
||||
|
@ -95,7 +98,7 @@ def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
|||
def read_config_files(files, list_values=True):
|
||||
"""Read and merge a list of config files."""
|
||||
|
||||
config = ConfigObj(list_values=list_values)
|
||||
config = create_default_config(list_values=list_values)
|
||||
_files = copy(files)
|
||||
while _files:
|
||||
_file = _files.pop(0)
|
||||
|
@ -112,12 +115,21 @@ def read_config_files(files, list_values=True):
|
|||
return config
|
||||
|
||||
|
||||
def write_default_config(source, destination, overwrite=False):
|
||||
def create_default_config(list_values=True):
|
||||
import mycli
|
||||
default_config_file = resources.open_text(mycli, 'myclirc')
|
||||
return read_config_file(default_config_file, list_values=list_values)
|
||||
|
||||
|
||||
def write_default_config(destination, overwrite=False):
|
||||
import mycli
|
||||
default_config = resources.read_text(mycli, 'myclirc')
|
||||
destination = os.path.expanduser(destination)
|
||||
if not overwrite and exists(destination):
|
||||
return
|
||||
|
||||
shutil.copyfile(source, destination)
|
||||
with open(destination, 'w') as f:
|
||||
f.write(default_config)
|
||||
|
||||
|
||||
def get_mylogin_cnf_path():
|
||||
|
@ -160,6 +172,58 @@ def open_mylogin_cnf(name):
|
|||
return TextIOWrapper(plaintext)
|
||||
|
||||
|
||||
# TODO reuse code between encryption an decryption
|
||||
def encrypt_mylogin_cnf(plaintext: IO[str]):
|
||||
"""Encryption of .mylogin.cnf file, analogous to calling
|
||||
mysql_config_editor.
|
||||
|
||||
Code is based on the python implementation by Kristian Koehntopp
|
||||
https://github.com/isotopp/mysql-config-coder
|
||||
|
||||
"""
|
||||
def realkey(key):
|
||||
"""Create the AES key from the login key."""
|
||||
rkey = bytearray(16)
|
||||
for i in range(len(key)):
|
||||
rkey[i % 16] ^= key[i]
|
||||
return bytes(rkey)
|
||||
|
||||
def encode_line(plaintext, real_key, buf_len):
|
||||
aes = pyaes.AESModeOfOperationECB(real_key)
|
||||
text_len = len(plaintext)
|
||||
pad_len = buf_len - text_len
|
||||
pad_chr = bytes(chr(pad_len), "utf8")
|
||||
plaintext = plaintext.encode() + pad_chr * pad_len
|
||||
encrypted_text = b''.join(
|
||||
[aes.encrypt(plaintext[i: i + 16])
|
||||
for i in range(0, len(plaintext), 16)]
|
||||
)
|
||||
return encrypted_text
|
||||
|
||||
LOGIN_KEY_LENGTH = 20
|
||||
key = os.urandom(LOGIN_KEY_LENGTH)
|
||||
real_key = realkey(key)
|
||||
|
||||
outfile = BytesIO()
|
||||
|
||||
outfile.write(struct.pack("i", 0))
|
||||
outfile.write(key)
|
||||
|
||||
while True:
|
||||
line = plaintext.readline()
|
||||
if not line:
|
||||
break
|
||||
real_len = len(line)
|
||||
pad_len = (int(real_len / 16) + 1) * 16
|
||||
|
||||
outfile.write(struct.pack("i", pad_len))
|
||||
x = encode_line(line, real_key, pad_len)
|
||||
outfile.write(x)
|
||||
|
||||
outfile.seek(0)
|
||||
return outfile
|
||||
|
||||
|
||||
def read_and_decrypt_mylogin_cnf(f):
|
||||
"""Read and decrypt the contents of .mylogin.cnf.
|
||||
|
||||
|
@ -201,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
return None
|
||||
rkey = struct.pack('16B', *rkey)
|
||||
|
||||
# Create a decryptor object using the key.
|
||||
decryptor = _get_decryptor(rkey)
|
||||
|
||||
# Create a bytes buffer to hold the plaintext.
|
||||
plaintext = BytesIO()
|
||||
aes = pyaes.AESModeOfOperationECB(rkey)
|
||||
|
||||
while True:
|
||||
# Read the length of the ciphertext.
|
||||
|
@ -216,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
|
||||
# Read cipher_len bytes from the file and decrypt.
|
||||
cipher = f.read(cipher_len)
|
||||
plain = _remove_pad(decryptor.update(cipher))
|
||||
plain = _remove_pad(
|
||||
b''.join([aes.decrypt(cipher[i: i + 16])
|
||||
for i in range(0, cipher_len, 16)])
|
||||
)
|
||||
if plain is False:
|
||||
continue
|
||||
plaintext.write(plain)
|
||||
|
@ -244,7 +309,7 @@ def str_to_bool(s):
|
|||
elif s.lower() in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError('not a recognized boolean value: %s'.format(s))
|
||||
raise ValueError('not a recognized boolean value: {0}'.format(s))
|
||||
|
||||
|
||||
def strip_matching_quotes(s):
|
||||
|
@ -260,15 +325,8 @@ def strip_matching_quotes(s):
|
|||
return s
|
||||
|
||||
|
||||
def _get_decryptor(key):
|
||||
"""Get the AES decryptor."""
|
||||
c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
|
||||
return c.decryptor()
|
||||
|
||||
|
||||
def _remove_pad(line):
|
||||
"""Remove the pad from the *line*."""
|
||||
pad_length = ord(line[-1:])
|
||||
try:
|
||||
# Determine pad length.
|
||||
pad_length = ord(line[-1:])
|
||||
|
|
|
@ -78,8 +78,12 @@ def mycli_bindings(mycli):
|
|||
|
||||
@kb.add('escape', 'enter')
|
||||
def _(event):
|
||||
"""Introduces a line break regardless of multi-line mode or not."""
|
||||
"""Introduces a line break in multi-line mode, or dispatches the
|
||||
command in single-line mode."""
|
||||
_logger.debug('Detected alt-enter key.')
|
||||
event.app.current_buffer.insert_text('\n')
|
||||
if mycli.multi_line:
|
||||
event.app.current_buffer.validate_and_handle()
|
||||
else:
|
||||
event.app.current_buffer.insert_text('\n')
|
||||
|
||||
return kb
|
||||
|
|
140
mycli/main.py
140
mycli/main.py
|
@ -1,9 +1,12 @@
|
|||
from collections import defaultdict
|
||||
from io import open
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import logging
|
||||
import threading
|
||||
import re
|
||||
import stat
|
||||
import fileinput
|
||||
from collections import namedtuple
|
||||
try:
|
||||
|
@ -13,7 +16,6 @@ except ImportError:
|
|||
from time import time
|
||||
from datetime import datetime
|
||||
from random import choice
|
||||
from io import open
|
||||
|
||||
from pymysql import OperationalError
|
||||
from cli_helpers.tabular_output import TabularOutputFormatter
|
||||
|
@ -43,7 +45,7 @@ from .packages.special.favoritequeries import FavoriteQueries
|
|||
from .sqlcompleter import SQLCompleter
|
||||
from .clitoolbar import create_toolbar_tokens_func
|
||||
from .clistyle import style_factory, style_factory_output
|
||||
from .sqlexecute import FIELD_TYPES, SQLExecute
|
||||
from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
|
||||
from .clibuffer import cli_is_multiline
|
||||
from .completion_refresher import CompletionRefresher
|
||||
from .config import (write_default_config, get_mylogin_cnf_path,
|
||||
|
@ -51,7 +53,7 @@ from .config import (write_default_config, get_mylogin_cnf_path,
|
|||
strip_matching_quotes)
|
||||
from .key_bindings import mycli_bindings
|
||||
from .lexer import MyCliLexer
|
||||
from .__init__ import __version__
|
||||
from . import __version__
|
||||
from .compat import WIN
|
||||
from .packages.filepaths import dir_path_exists, guess_socket_location
|
||||
|
||||
|
@ -66,6 +68,11 @@ except ImportError:
|
|||
from urllib.parse import urlparse
|
||||
from urllib.parse import unquote
|
||||
|
||||
try:
|
||||
import importlib.resources as resources
|
||||
except ImportError:
|
||||
# Python < 3.7
|
||||
import importlib_resources as resources
|
||||
|
||||
try:
|
||||
import paramiko
|
||||
|
@ -75,7 +82,10 @@ except ImportError:
|
|||
# Query tuples are used for maintaining history
|
||||
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
|
||||
|
||||
PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
SUPPORT_INFO = (
|
||||
'Home: http://mycli.net\n'
|
||||
'Bug tracker: https://github.com/dbcli/mycli/issues'
|
||||
)
|
||||
|
||||
|
||||
class MyCli(object):
|
||||
|
@ -89,7 +99,7 @@ class MyCli(object):
|
|||
'/etc/my.cnf',
|
||||
'/etc/mysql/my.cnf',
|
||||
'/usr/local/etc/my.cnf',
|
||||
'~/.my.cnf'
|
||||
os.path.expanduser('~/.my.cnf'),
|
||||
]
|
||||
|
||||
# check XDG_CONFIG_HOME exists and not an empty string
|
||||
|
@ -102,7 +112,6 @@ class MyCli(object):
|
|||
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
|
||||
]
|
||||
|
||||
default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
|
||||
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
|
||||
|
||||
def __init__(self, sqlexecute=None, prompt=None,
|
||||
|
@ -122,7 +131,7 @@ class MyCli(object):
|
|||
self.cnf_files = [defaults_file]
|
||||
|
||||
# Load config.
|
||||
config_files = ([self.default_config_file] + self.system_config_files +
|
||||
config_files = (self.system_config_files +
|
||||
[myclirc] + [self.pwd_config_file])
|
||||
c = self.config = read_config_files(config_files)
|
||||
self.multi_line = c['main'].as_bool('multi_line')
|
||||
|
@ -154,7 +163,7 @@ class MyCli(object):
|
|||
|
||||
# Write user config if system config wasn't the last config loaded.
|
||||
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
|
||||
write_default_config(self.default_config_file, myclirc)
|
||||
write_default_config(myclirc)
|
||||
|
||||
# audit log
|
||||
if self.logfile is None and 'audit_log' in c['main']:
|
||||
|
@ -326,20 +335,33 @@ class MyCli(object):
|
|||
cnf = read_config_files(files, list_values=False)
|
||||
|
||||
sections = ['client', 'mysqld']
|
||||
key_transformations = {
|
||||
'mysqld': {
|
||||
'socket': 'default_socket',
|
||||
'port': 'default_port',
|
||||
},
|
||||
}
|
||||
|
||||
if self.login_path and self.login_path != 'client':
|
||||
sections.append(self.login_path)
|
||||
|
||||
if self.defaults_suffix:
|
||||
sections.extend([sect + self.defaults_suffix for sect in sections])
|
||||
|
||||
def get(key):
|
||||
result = None
|
||||
for sect in cnf:
|
||||
if sect in sections and key in cnf[sect]:
|
||||
result = strip_matching_quotes(cnf[sect][key])
|
||||
return result
|
||||
configuration = defaultdict(lambda: None)
|
||||
for key in keys:
|
||||
for section in cnf:
|
||||
if (
|
||||
section not in sections or
|
||||
key not in cnf[section]
|
||||
):
|
||||
continue
|
||||
new_key = key_transformations.get(section, {}).get(key) or key
|
||||
configuration[new_key] = strip_matching_quotes(
|
||||
cnf[section][key])
|
||||
|
||||
return configuration
|
||||
|
||||
return {x: get(x) for x in keys}
|
||||
|
||||
def merge_ssl_with_cnf(self, ssl, cnf):
|
||||
"""Merge SSL configuration dict with cnf dict"""
|
||||
|
@ -367,7 +389,7 @@ class MyCli(object):
|
|||
def connect(self, database='', user='', passwd='', host='', port='',
|
||||
socket='', charset='', local_infile='', ssl='',
|
||||
ssh_user='', ssh_host='', ssh_port='',
|
||||
ssh_password='', ssh_key_filename='', init_command=''):
|
||||
ssh_password='', ssh_key_filename='', init_command='', password_file=''):
|
||||
|
||||
cnf = {'database': None,
|
||||
'user': None,
|
||||
|
@ -375,6 +397,7 @@ class MyCli(object):
|
|||
'host': None,
|
||||
'port': None,
|
||||
'socket': None,
|
||||
'default_socket': None,
|
||||
'default-character-set': None,
|
||||
'local-infile': None,
|
||||
'loose-local-infile': None,
|
||||
|
@ -388,18 +411,23 @@ class MyCli(object):
|
|||
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
|
||||
|
||||
# Fall back to config values only if user did not specify a value.
|
||||
|
||||
database = database or cnf['database']
|
||||
# Socket interface not supported for SSH connections
|
||||
if port or (host and host != 'localhost') or (ssh_host and ssh_port):
|
||||
socket = ''
|
||||
else:
|
||||
socket = socket or cnf['socket'] or guess_socket_location()
|
||||
user = user or cnf['user'] or os.getenv('USER')
|
||||
host = host or cnf['host']
|
||||
port = int(port or cnf['port'] or 3306)
|
||||
port = port or cnf['port']
|
||||
ssl = ssl or {}
|
||||
|
||||
port = port and int(port)
|
||||
if not port:
|
||||
port = 3306
|
||||
if not host or host == 'localhost':
|
||||
socket = (
|
||||
cnf['socket'] or
|
||||
cnf['default_socket'] or
|
||||
guess_socket_location()
|
||||
)
|
||||
|
||||
|
||||
passwd = passwd if isinstance(passwd, str) else cnf['password']
|
||||
charset = charset or cnf['default-character-set'] or 'utf8'
|
||||
|
||||
|
@ -417,6 +445,10 @@ class MyCli(object):
|
|||
if not any(v for v in ssl.values()):
|
||||
ssl = None
|
||||
|
||||
# if the passwd is not specfied try to set it using the password_file option
|
||||
password_from_file = self.get_password_from_file(password_file)
|
||||
passwd = passwd or password_from_file
|
||||
|
||||
# Connect to the database.
|
||||
|
||||
def _connect():
|
||||
|
@ -427,9 +459,12 @@ class MyCli(object):
|
|||
ssh_password, ssh_key_filename, init_command
|
||||
)
|
||||
except OperationalError as e:
|
||||
if ('Access denied for user' in e.args[1]):
|
||||
new_passwd = click.prompt('Password', hide_input=True,
|
||||
show_default=False, type=str, err=True)
|
||||
if e.args[0] == ERROR_CODE_ACCESS_DENIED:
|
||||
if password_from_file:
|
||||
new_passwd = password_from_file
|
||||
else:
|
||||
new_passwd = click.prompt('Password', hide_input=True,
|
||||
show_default=False, type=str, err=True)
|
||||
self.sqlexecute = SQLExecute(
|
||||
database, user, new_passwd, host, port, socket,
|
||||
charset, local_infile, ssl, ssh_user, ssh_host,
|
||||
|
@ -484,6 +519,17 @@ class MyCli(object):
|
|||
self.echo(str(e), err=True, fg='red')
|
||||
exit(1)
|
||||
|
||||
def get_password_from_file(self, password_file):
|
||||
password_from_file = None
|
||||
if password_file:
|
||||
if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
|
||||
and os.access(password_file, os.R_OK):
|
||||
with open(password_file) as fp:
|
||||
password_from_file = fp.readline()
|
||||
password_from_file = password_from_file.rstrip().lstrip()
|
||||
|
||||
return password_from_file
|
||||
|
||||
def handle_editor_command(self, text):
|
||||
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
|
||||
The reason for a while loop is because a user might edit a query
|
||||
|
@ -542,9 +588,6 @@ class MyCli(object):
|
|||
if self.smart_completion:
|
||||
self.refresh_completions()
|
||||
|
||||
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
|
||||
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
|
||||
|
||||
history_file = os.path.expanduser(
|
||||
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
|
||||
if dir_path_exists(history_file):
|
||||
|
@ -559,12 +602,10 @@ class MyCli(object):
|
|||
key_bindings = mycli_bindings(self)
|
||||
|
||||
if not self.less_chatty:
|
||||
print(' '.join(sqlexecute.server_type()))
|
||||
print(sqlexecute.server_info)
|
||||
print('mycli', __version__)
|
||||
print('Chat: https://gitter.im/dbcli/mycli')
|
||||
print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
|
||||
print('Home: http://mycli.net')
|
||||
print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
|
||||
print(SUPPORT_INFO)
|
||||
print('Thanks to the contributor -', thanks_picker())
|
||||
|
||||
def get_message():
|
||||
prompt = self.get_prompt(self.prompt_format)
|
||||
|
@ -862,8 +903,8 @@ class MyCli(object):
|
|||
|
||||
if not output_via_pager:
|
||||
# doesn't fit, flush buffer
|
||||
for line in buf:
|
||||
click.secho(line)
|
||||
for buf_line in buf:
|
||||
click.secho(buf_line)
|
||||
buf = []
|
||||
else:
|
||||
click.secho(line)
|
||||
|
@ -933,7 +974,7 @@ class MyCli(object):
|
|||
string = string.replace('\\u', sqlexecute.user or '(none)')
|
||||
string = string.replace('\\h', host or '(none)')
|
||||
string = string.replace('\\d', sqlexecute.dbname or '(none)')
|
||||
string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
|
||||
string = string.replace('\\t', sqlexecute.server_info.species.name)
|
||||
string = string.replace('\\n', "\n")
|
||||
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
|
||||
string = string.replace('\\m', now.strftime('%M'))
|
||||
|
@ -1083,7 +1124,7 @@ class MyCli(object):
|
|||
help='Warn before running a destructive query.')
|
||||
@click.option('--local-infile', type=bool,
|
||||
help='Enable/disable LOAD DATA LOCAL INFILE.')
|
||||
@click.option('--login-path', type=str,
|
||||
@click.option('-g', '--login-path', type=str,
|
||||
help='Read this path from the login file.')
|
||||
@click.option('-e', '--execute', type=str,
|
||||
help='Execute command and quit.')
|
||||
|
@ -1091,6 +1132,8 @@ class MyCli(object):
|
|||
help='SQL statement to execute after connecting.')
|
||||
@click.option('--charset', type=str,
|
||||
help='Character set for MySQL session.')
|
||||
@click.option('--password-file', type=click.Path(),
|
||||
help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
|
||||
@click.argument('database', default='', nargs=1)
|
||||
def cli(database, user, host, port, socket, password, dbname,
|
||||
version, verbose, prompt, logfile, defaults_group_suffix,
|
||||
|
@ -1099,7 +1142,7 @@ def cli(database, user, host, port, socket, password, dbname,
|
|||
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
|
||||
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
|
||||
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
|
||||
init_command, charset):
|
||||
init_command, charset, password_file):
|
||||
"""A MySQL terminal client with auto-completion and syntax highlighting.
|
||||
|
||||
\b
|
||||
|
@ -1225,7 +1268,8 @@ def cli(database, user, host, port, socket, password, dbname,
|
|||
ssh_password=ssh_password,
|
||||
ssh_key_filename=ssh_key_filename,
|
||||
init_command=init_command,
|
||||
charset=charset
|
||||
charset=charset,
|
||||
password_file=password_file
|
||||
)
|
||||
|
||||
mycli.logger.debug('Launch Params: \n'
|
||||
|
@ -1328,9 +1372,15 @@ def is_select(status):
|
|||
return status.split(None, 1)[0].lower() == 'select'
|
||||
|
||||
|
||||
def thanks_picker(files=()):
|
||||
def thanks_picker():
|
||||
import mycli
|
||||
lines = (
|
||||
resources.read_text(mycli, 'AUTHORS') +
|
||||
resources.read_text(mycli, 'SPONSORS')
|
||||
).split('\n')
|
||||
|
||||
contents = []
|
||||
for line in fileinput.input(files=files):
|
||||
for line in lines:
|
||||
m = re.match(r'^ *\* (.*)', line)
|
||||
if m:
|
||||
contents.append(m.group(1))
|
||||
|
@ -1350,6 +1400,9 @@ def read_ssh_config(ssh_config_path):
|
|||
try:
|
||||
with open(ssh_config_path) as f:
|
||||
ssh_config.parse(f)
|
||||
except FileNotFoundError as e:
|
||||
click.secho(str(e), err=True, fg='red')
|
||||
sys.exit(1)
|
||||
# Paramiko prior to version 2.7 raises Exception on parse errors.
|
||||
# In 2.7 it has become paramiko.ssh_exception.SSHException,
|
||||
# but let's catch everything for compatibility
|
||||
|
@ -1359,9 +1412,6 @@ def read_ssh_config(ssh_config_path):
|
|||
err=True, fg='red'
|
||||
)
|
||||
sys.exit(1)
|
||||
except FileNotFoundError as e:
|
||||
click.secho(str(e), err=True, fg='red')
|
||||
sys.exit(1)
|
||||
else:
|
||||
return ssh_config
|
||||
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
import sys
|
||||
import sqlparse
|
||||
from sqlparse.sql import Comparison, Identifier, Where
|
||||
from .parseutils import last_word, extract_tables, find_prev_keyword
|
||||
|
|
|
@ -12,7 +12,8 @@ cleanup_regex = {
|
|||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||
# This matches everything except a space.
|
||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def last_word(text, include='alphanum_underscore'):
|
||||
r"""
|
||||
|
@ -226,14 +227,6 @@ def is_destructive(queries):
|
|||
return False
|
||||
|
||||
|
||||
def is_open_quote(sql):
|
||||
"""Returns true if the query contains an unclosed quote."""
|
||||
|
||||
# parsed can contain one or more semi-colon separated commands
|
||||
parsed = sqlparse.parse(sql)
|
||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sql = 'select * from (select t. from tabl t'
|
||||
print (extract_tables(sql))
|
||||
|
@ -263,5 +256,4 @@ def is_dropping_database(queries, dbname):
|
|||
)
|
||||
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
|
||||
result = keywords[0].normalized == "DROP"
|
||||
else:
|
||||
return result
|
||||
return result
|
||||
|
|
|
@ -302,7 +302,7 @@ def execute_system_command(arg, **_):
|
|||
usage = "Syntax: system [command].\n"
|
||||
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
try:
|
||||
command = arg.strip()
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
"""Format adapter for sql."""
|
||||
|
||||
from cli_helpers.utils import filter_dict_by_key
|
||||
from mycli.packages.parseutils import extract_tables
|
||||
|
||||
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
|
||||
|
|
|
@ -72,7 +72,7 @@ class SQLCompleter(Completer):
|
|||
if name and ((not self.name_pattern.match(name))
|
||||
or (name.upper() in self.reserved_words)
|
||||
or (name.upper() in self.functions)):
|
||||
name = '`%s`' % name
|
||||
name = '`%s`' % name
|
||||
|
||||
return name
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import enum
|
||||
import logging
|
||||
import re
|
||||
|
||||
import pymysql
|
||||
import sqlparse
|
||||
from .packages import special
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
from pymysql.converters import (convert_datetime,
|
||||
|
@ -18,17 +20,71 @@ FIELD_TYPES.update({
|
|||
FIELD_TYPE.NULL: type(None)
|
||||
})
|
||||
|
||||
|
||||
ERROR_CODE_ACCESS_DENIED = 1045
|
||||
|
||||
|
||||
class ServerSpecies(enum.Enum):
|
||||
MySQL = 'MySQL'
|
||||
MariaDB = 'MariaDB'
|
||||
Percona = 'Percona'
|
||||
Unknown = 'MySQL'
|
||||
|
||||
|
||||
class ServerInfo:
|
||||
def __init__(self, species, version_str):
|
||||
self.species = species
|
||||
self.version_str = version_str
|
||||
self.version = self.calc_mysql_version_value(version_str)
|
||||
|
||||
@staticmethod
|
||||
def calc_mysql_version_value(version_str) -> int:
|
||||
if not version_str or not isinstance(version_str, str):
|
||||
return 0
|
||||
try:
|
||||
major, minor, patch = version_str.split('.')
|
||||
except ValueError:
|
||||
return 0
|
||||
else:
|
||||
return int(major) * 10_000 + int(minor) * 100 + int(patch)
|
||||
|
||||
@classmethod
|
||||
def from_version_string(cls, version_string):
|
||||
if not version_string:
|
||||
return cls(ServerSpecies.Unknown, '')
|
||||
|
||||
re_species = (
|
||||
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
|
||||
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
|
||||
ServerSpecies.Percona),
|
||||
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
|
||||
ServerSpecies.MySQL),
|
||||
)
|
||||
for regexp, species in re_species:
|
||||
match = re.search(regexp, version_string)
|
||||
if match is not None:
|
||||
parsed_version = match.group('version')
|
||||
detected_species = species
|
||||
break
|
||||
else:
|
||||
detected_species = ServerSpecies.Unknown
|
||||
parsed_version = ''
|
||||
|
||||
return cls(detected_species, parsed_version)
|
||||
|
||||
def __str__(self):
|
||||
if self.species:
|
||||
return f'{self.species.value} {self.version_str}'
|
||||
else:
|
||||
return self.version_str
|
||||
|
||||
|
||||
class SQLExecute(object):
|
||||
|
||||
databases_query = '''SHOW DATABASES'''
|
||||
|
||||
tables_query = '''SHOW TABLES'''
|
||||
|
||||
version_query = '''SELECT @@VERSION'''
|
||||
|
||||
version_comment_query = '''SELECT @@VERSION_COMMENT'''
|
||||
version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
|
||||
|
||||
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
|
||||
|
||||
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
|
||||
|
@ -52,7 +108,7 @@ class SQLExecute(object):
|
|||
self.charset = charset
|
||||
self.local_infile = local_infile
|
||||
self.ssl = ssl
|
||||
self._server_type = None
|
||||
self.server_info = None
|
||||
self.connection_id = None
|
||||
self.ssh_user = ssh_user
|
||||
self.ssh_host = ssh_host
|
||||
|
@ -157,6 +213,7 @@ class SQLExecute(object):
|
|||
self.init_command = init_command
|
||||
# retrieve connection id
|
||||
self.reset_connection_id()
|
||||
self.server_info = ServerInfo.from_version_string(conn.server_version)
|
||||
|
||||
def run(self, statement):
|
||||
"""Execute the sql in the database and return the results. The results
|
||||
|
@ -273,37 +330,6 @@ class SQLExecute(object):
|
|||
for row in cur:
|
||||
yield row
|
||||
|
||||
def server_type(self):
|
||||
if self._server_type:
|
||||
return self._server_type
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Version Query. sql: %r', self.version_query)
|
||||
cur.execute(self.version_query)
|
||||
version = cur.fetchone()[0]
|
||||
if version[0] == '4':
|
||||
_logger.debug('Version Comment. sql: %r',
|
||||
self.version_comment_query_mysql4)
|
||||
cur.execute(self.version_comment_query_mysql4)
|
||||
version_comment = cur.fetchone()[1].lower()
|
||||
if isinstance(version_comment, bytes):
|
||||
# with python3 this query returns bytes
|
||||
version_comment = version_comment.decode('utf-8')
|
||||
else:
|
||||
_logger.debug('Version Comment. sql: %r',
|
||||
self.version_comment_query)
|
||||
cur.execute(self.version_comment_query)
|
||||
version_comment = cur.fetchone()[0].lower()
|
||||
|
||||
if 'mariadb' in version_comment:
|
||||
product_type = 'mariadb'
|
||||
elif 'percona' in version_comment:
|
||||
product_type = 'percona'
|
||||
else:
|
||||
product_type = 'mysql'
|
||||
|
||||
self._server_type = (product_type, version)
|
||||
return self._server_type
|
||||
|
||||
def get_connection_id(self):
|
||||
if not self.connection_id:
|
||||
self.reset_connection_id()
|
||||
|
|
2
pytest.ini
Normal file
2
pytest.ini
Normal file
|
@ -0,0 +1,2 @@
|
|||
[pytest]
|
||||
addopts = --ignore=mycli/packages/paramiko_stub/__init__.py
|
|
@ -1,6 +1,5 @@
|
|||
"""A script to publish a release of mycli to PyPI."""
|
||||
|
||||
import io
|
||||
from optparse import OptionParser
|
||||
import re
|
||||
import subprocess
|
||||
|
|
8
setup.py
8
setup.py
|
@ -18,16 +18,20 @@ description = 'CLI for MySQL Database. With auto-completion and syntax highlight
|
|||
|
||||
install_requirements = [
|
||||
'click >= 7.0',
|
||||
'cryptography >= 1.0.0',
|
||||
'Pygments >= 1.6',
|
||||
'prompt_toolkit>=3.0.6,<4.0.0',
|
||||
'PyMySQL >= 0.9.2',
|
||||
'sqlparse>=0.3.0,<0.4.0',
|
||||
'configobj >= 5.0.5',
|
||||
'cryptography >= 1.0.0',
|
||||
'cli_helpers[styles] >= 2.0.1',
|
||||
'pyperclip >= 1.8.1'
|
||||
'pyperclip >= 1.8.1',
|
||||
'pyaes >= 1.6.1'
|
||||
]
|
||||
|
||||
if sys.version_info.minor < 9:
|
||||
install_requirements.append('importlib_resources >= 5.0.0')
|
||||
|
||||
|
||||
class lint(Command):
|
||||
description = 'check code against PEP 8 (and fix violations)'
|
||||
|
|
35
test/features/connection.feature
Normal file
35
test/features/connection.feature
Normal file
|
@ -0,0 +1,35 @@
|
|||
Feature: connect to a database:
|
||||
|
||||
@requires_local_db
|
||||
Scenario: run mycli on localhost without port
|
||||
When we run mycli with arguments "host=localhost" without arguments "port"
|
||||
When we query "status"
|
||||
Then status contains "via UNIX socket"
|
||||
|
||||
Scenario: run mycli on TCP host without port
|
||||
When we run mycli without arguments "port"
|
||||
When we query "status"
|
||||
Then status contains "via TCP/IP"
|
||||
|
||||
Scenario: run mycli with port but without host
|
||||
When we run mycli without arguments "host"
|
||||
When we query "status"
|
||||
Then status contains "via TCP/IP"
|
||||
|
||||
@requires_local_db
|
||||
Scenario: run mycli without host and port
|
||||
When we run mycli without arguments "host port"
|
||||
When we query "status"
|
||||
Then status contains "via UNIX socket"
|
||||
|
||||
Scenario: run mycli with my.cnf configuration
|
||||
When we create my.cnf file
|
||||
When we run mycli without arguments "host port user pass defaults_file"
|
||||
Then we are logged in
|
||||
|
||||
Scenario: run mycli with mylogin.cnf configuration
|
||||
When we create mylogin.cnf file
|
||||
When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
|
||||
Then we are logged in
|
||||
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from tempfile import mkstemp
|
||||
|
||||
|
@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt
|
|||
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
|
||||
|
||||
|
||||
SELF_CONNECTING_FEATURES = (
|
||||
'test/features/connection.feature',
|
||||
)
|
||||
|
||||
|
||||
MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
|
||||
MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
|
||||
MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
|
||||
MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
|
||||
|
||||
|
||||
def get_db_name_from_context(context):
|
||||
return context.config.userdata.get(
|
||||
'my_test_db', None
|
||||
) or "mycli_behave_tests"
|
||||
|
||||
|
||||
|
||||
def before_all(context):
|
||||
"""Set env parameters."""
|
||||
os.environ['LINES'] = "100"
|
||||
|
@ -22,7 +41,7 @@ def before_all(context):
|
|||
|
||||
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
|
||||
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
||||
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
||||
|
||||
context.package_root = os.path.abspath(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
@ -33,8 +52,7 @@ def before_all(context):
|
|||
context.exit_sent = False
|
||||
|
||||
vi = '_'.join([str(x) for x in sys.version_info[:3]])
|
||||
db_name = context.config.userdata.get(
|
||||
'my_test_db', None) or "mycli_behave_tests"
|
||||
db_name = get_db_name_from_context(context)
|
||||
db_name_full = '{0}_{1}'.format(db_name, vi)
|
||||
|
||||
# Store get params from config/environment variables
|
||||
|
@ -104,11 +122,18 @@ def before_step(context, _):
|
|||
context.atprompt = False
|
||||
|
||||
|
||||
def before_scenario(context, _):
|
||||
def before_scenario(context, arg):
|
||||
with open(test_log_file, 'w') as f:
|
||||
f.write('')
|
||||
run_cli(context)
|
||||
wait_prompt(context)
|
||||
if arg.location.filename not in SELF_CONNECTING_FEATURES:
|
||||
run_cli(context)
|
||||
wait_prompt(context)
|
||||
|
||||
if os.path.exists(MY_CNF_PATH):
|
||||
shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
|
||||
|
||||
if os.path.exists(MYLOGIN_CNF_PATH):
|
||||
shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
|
||||
|
||||
|
||||
def after_scenario(context, _):
|
||||
|
@ -134,6 +159,17 @@ def after_scenario(context, _):
|
|||
context.cli.sendcontrol('d')
|
||||
context.cli.expect_exact(pexpect.EOF, timeout=5)
|
||||
|
||||
if os.path.exists(MY_CNF_BACKUP_PATH):
|
||||
shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
|
||||
|
||||
if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
|
||||
shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
|
||||
elif os.path.exists(MYLOGIN_CNF_PATH):
|
||||
# This file was moved in `before_scenario`.
|
||||
# If it exists now, it has been created during a test
|
||||
os.remove(MYLOGIN_CNF_PATH)
|
||||
|
||||
|
||||
# TODO: uncomment to debug a failure
|
||||
# def after_step(context, step):
|
||||
# if step.status == "failed":
|
||||
|
|
|
@ -3,11 +3,12 @@ from textwrap import dedent
|
|||
from behave import then, when
|
||||
|
||||
import wrappers
|
||||
from utils import parse_cli_args_to_dict
|
||||
|
||||
|
||||
@when('we run dbcli with {arg}')
|
||||
def step_run_cli_with_arg(context, arg):
|
||||
wrappers.run_cli(context, run_args=arg.split('='))
|
||||
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
|
||||
|
||||
|
||||
@when('we execute a small query')
|
||||
|
|
71
test/features/steps/connection.py
Normal file
71
test/features/steps/connection.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import io
|
||||
import os
|
||||
import shlex
|
||||
|
||||
from behave import when, then
|
||||
import pexpect
|
||||
|
||||
import wrappers
|
||||
from test.features.steps.utils import parse_cli_args_to_dict
|
||||
from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
|
||||
from test.utils import HOST, PORT, USER, PASSWORD
|
||||
from mycli.config import encrypt_mylogin_cnf
|
||||
|
||||
|
||||
TEST_LOGIN_PATH = 'test_login_path'
|
||||
|
||||
|
||||
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
|
||||
@when('we run mycli without arguments "{excluded_args}"')
|
||||
def step_run_cli_without_args(context, excluded_args, exact_args=''):
|
||||
wrappers.run_cli(
|
||||
context,
|
||||
run_args=parse_cli_args_to_dict(exact_args),
|
||||
exclude_args=parse_cli_args_to_dict(excluded_args).keys()
|
||||
)
|
||||
|
||||
|
||||
@then('status contains "{expression}"')
|
||||
def status_contains(context, expression):
|
||||
wrappers.expect_exact(context, f'{expression}', timeout=5)
|
||||
|
||||
# Normally, the shutdown after scenario waits for the prompt.
|
||||
# But we may have changed the prompt, depending on parameters,
|
||||
# so let's wait for its last character
|
||||
context.cli.expect_exact('>')
|
||||
context.atprompt = True
|
||||
|
||||
|
||||
@when('we create my.cnf file')
|
||||
def step_create_my_cnf_file(context):
|
||||
my_cnf = (
|
||||
'[client]\n'
|
||||
f'host = {HOST}\n'
|
||||
f'port = {PORT}\n'
|
||||
f'user = {USER}\n'
|
||||
f'password = {PASSWORD}\n'
|
||||
)
|
||||
with open(MY_CNF_PATH, 'w') as f:
|
||||
f.write(my_cnf)
|
||||
|
||||
|
||||
@when('we create mylogin.cnf file')
|
||||
def step_create_mylogin_cnf_file(context):
|
||||
os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
|
||||
mylogin_cnf = (
|
||||
f'[{TEST_LOGIN_PATH}]\n'
|
||||
f'host = {HOST}\n'
|
||||
f'port = {PORT}\n'
|
||||
f'user = {USER}\n'
|
||||
f'password = {PASSWORD}\n'
|
||||
)
|
||||
with open(MYLOGIN_CNF_PATH, 'wb') as f:
|
||||
input_file = io.StringIO(mylogin_cnf)
|
||||
f.write(encrypt_mylogin_cnf(input_file).read())
|
||||
|
||||
|
||||
@then('we are logged in')
|
||||
def we_are_logged_in(context):
|
||||
db_name = get_db_name_from_context(context)
|
||||
context.cli.expect_exact(f'{db_name}>', timeout=5)
|
||||
context.atprompt = True
|
12
test/features/steps/utils.py
Normal file
12
test/features/steps/utils.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import shlex
|
||||
|
||||
|
||||
def parse_cli_args_to_dict(cli_args: str):
|
||||
args_dict = {}
|
||||
for arg in shlex.split(cli_args):
|
||||
if '=' in arg:
|
||||
key, value = arg.split('=')
|
||||
args_dict[key] = value
|
||||
else:
|
||||
args_dict[arg] = None
|
||||
return args_dict
|
|
@ -3,6 +3,7 @@ import pexpect
|
|||
import sys
|
||||
import textwrap
|
||||
|
||||
|
||||
try:
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
|
@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout):
|
|||
timedout = False
|
||||
try:
|
||||
context.cli.expect_exact(expected, timeout=timeout)
|
||||
except pexpect.exceptions.TIMEOUT:
|
||||
except pexpect.TIMEOUT:
|
||||
timedout = True
|
||||
if timedout:
|
||||
# Strip color codes out of the output.
|
||||
|
@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout):
|
|||
context.conf['pager_boundary'], expected), timeout=timeout)
|
||||
|
||||
|
||||
def run_cli(context, run_args=None):
|
||||
def run_cli(context, run_args=None, exclude_args=None):
|
||||
"""Run the process using pexpect."""
|
||||
run_args = run_args or []
|
||||
if context.conf.get('host', None):
|
||||
run_args.extend(('-h', context.conf['host']))
|
||||
if context.conf.get('user', None):
|
||||
run_args.extend(('-u', context.conf['user']))
|
||||
if context.conf.get('pass', None):
|
||||
run_args.extend(('-p', context.conf['pass']))
|
||||
if context.conf.get('dbname', None):
|
||||
run_args.extend(('-D', context.conf['dbname']))
|
||||
if context.conf.get('defaults-file', None):
|
||||
run_args.extend(('--defaults-file', context.conf['defaults-file']))
|
||||
if context.conf.get('myclirc', None):
|
||||
run_args.extend(('--myclirc', context.conf['myclirc']))
|
||||
run_args = run_args or {}
|
||||
rendered_args = []
|
||||
exclude_args = set(exclude_args) if exclude_args else set()
|
||||
|
||||
conf = dict(**context.conf)
|
||||
conf.update(run_args)
|
||||
|
||||
def add_arg(name, key, value):
|
||||
if name not in exclude_args:
|
||||
if value is not None:
|
||||
rendered_args.extend((key, value))
|
||||
else:
|
||||
rendered_args.append(key)
|
||||
|
||||
if conf.get('host', None):
|
||||
add_arg('host', '-h', conf['host'])
|
||||
if conf.get('user', None):
|
||||
add_arg('user', '-u', conf['user'])
|
||||
if conf.get('pass', None):
|
||||
add_arg('pass', '-p', conf['pass'])
|
||||
if conf.get('port', None):
|
||||
add_arg('port', '-P', str(conf['port']))
|
||||
if conf.get('dbname', None):
|
||||
add_arg('dbname', '-D', conf['dbname'])
|
||||
if conf.get('defaults-file', None):
|
||||
add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
|
||||
if conf.get('myclirc', None):
|
||||
add_arg('myclirc', '--myclirc', conf['myclirc'])
|
||||
if conf.get('login_path'):
|
||||
add_arg('login_path', '--login-path', conf['login_path'])
|
||||
|
||||
for arg_name, arg_value in conf.items():
|
||||
if arg_name.startswith('-'):
|
||||
add_arg(arg_name, arg_name, arg_value)
|
||||
|
||||
try:
|
||||
cli_cmd = context.conf['cli_command']
|
||||
except KeyError:
|
||||
|
@ -73,7 +96,7 @@ def run_cli(context, run_args=None):
|
|||
'"'
|
||||
).format(sys.executable)
|
||||
|
||||
cmd_parts = [cli_cmd] + run_args
|
||||
cmd_parts = [cli_cmd] + rendered_args
|
||||
cmd = ' '.join(cmd_parts)
|
||||
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
|
||||
context.logfile = StringIO()
|
||||
|
|
|
@ -3,8 +3,9 @@ import os
|
|||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT
|
||||
from mycli.main import MyCli, cli, thanks_picker
|
||||
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
|
||||
from mycli.sqlexecute import ServerInfo
|
||||
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
|
||||
|
||||
from textwrap import dedent
|
||||
|
@ -140,10 +141,7 @@ def test_batch_mode_csv(executor):
|
|||
|
||||
|
||||
def test_thanks_picker_utf8():
|
||||
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
|
||||
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
|
||||
|
||||
name = thanks_picker((author_file, sponsor_file))
|
||||
name = thanks_picker()
|
||||
assert name and isinstance(name, str)
|
||||
|
||||
|
||||
|
@ -177,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
|
|||
host = 'test'
|
||||
user = 'test'
|
||||
dbname = 'test'
|
||||
server_info = ServerInfo.from_version_string('unknown')
|
||||
port = 0
|
||||
|
||||
def server_type(self):
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import pytest
|
||||
import pymysql
|
||||
|
||||
from mycli.sqlexecute import ServerInfo, ServerSpecies
|
||||
from .utils import run, dbtest, set_expanded_output, is_expanded_output
|
||||
|
||||
|
||||
|
@ -270,3 +271,24 @@ def test_multiple_results(executor):
|
|||
'status': '1 row in set'}
|
||||
]
|
||||
assert results == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'version_string, species, parsed_version_string, version',
|
||||
(
|
||||
('5.7.32-35', 'Percona', '5.7.32', 50732),
|
||||
('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
|
||||
('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
|
||||
('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
|
||||
('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
|
||||
('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
|
||||
('unexpected version string', None, '', 0),
|
||||
('', None, '', 0),
|
||||
(None, None, '', 0),
|
||||
)
|
||||
)
|
||||
def test_version_parsing(version_string, species, parsed_version_string, version):
|
||||
server_info = ServerInfo.from_version_string(version_string)
|
||||
assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
|
||||
assert server_info.version_str == parsed_version_string
|
||||
assert server_info.version == version
|
||||
|
|
Loading…
Add table
Reference in a new issue