Merging upstream version 1.24.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
570aa52ec2
commit
06dd2aeb28
26 changed files with 565 additions and 169 deletions
14
.github/workflows/ci.yml
vendored
14
.github/workflows/ci.yml
vendored
|
@ -7,13 +7,20 @@ on:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
linux:
|
linux:
|
||||||
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||||
|
include:
|
||||||
|
- python-version: 3.6
|
||||||
|
os: ubuntu-16.04 # MySQL 5.7.32
|
||||||
|
- python-version: 3.7
|
||||||
|
os: ubuntu-18.04 # MySQL 5.7.32
|
||||||
|
- python-version: 3.8
|
||||||
|
os: ubuntu-18.04 # MySQL 5.7.32
|
||||||
|
- python-version: 3.9
|
||||||
|
os: ubuntu-20.04 # MySQL 8.0.22
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
@ -42,6 +49,7 @@ jobs:
|
||||||
- name: Pytest / behave
|
- name: Pytest / behave
|
||||||
env:
|
env:
|
||||||
PYTEST_PASSWORD: root
|
PYTEST_PASSWORD: root
|
||||||
|
PYTEST_HOST: 127.0.0.1
|
||||||
run: |
|
run: |
|
||||||
./setup.py test --pytest-args="--cov-report= --cov=mycli"
|
./setup.py test --pytest-args="--cov-report= --cov=mycli"
|
||||||
|
|
||||||
|
|
26
README.md
26
README.md
|
@ -1,8 +1,8 @@
|
||||||
# mycli
|
# mycli
|
||||||
|
|
||||||
[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli)
|
[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli)
|
||||||
[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli)
|
[![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli)
|
||||||
[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
[![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python)
|
||||||
|
|
||||||
A command line client for MySQL that can do auto-completion and syntax highlighting.
|
A command line client for MySQL that can do auto-completion and syntax highlighting.
|
||||||
|
|
||||||
|
@ -53,6 +53,7 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
|
||||||
-h, --host TEXT Host address of the database.
|
-h, --host TEXT Host address of the database.
|
||||||
-P, --port INTEGER Port number to use for connection. Honors
|
-P, --port INTEGER Port number to use for connection. Honors
|
||||||
$MYSQL_TCP_PORT.
|
$MYSQL_TCP_PORT.
|
||||||
|
|
||||||
-u, --user TEXT User name to connect to the database.
|
-u, --user TEXT User name to connect to the database.
|
||||||
-S, --socket TEXT The socket file to use for connection.
|
-S, --socket TEXT The socket file to use for connection.
|
||||||
-p, --password TEXT Password to connect to the database.
|
-p, --password TEXT Password to connect to the database.
|
||||||
|
@ -63,8 +64,11 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
|
||||||
--ssh-password TEXT Password to connect to ssh server.
|
--ssh-password TEXT Password to connect to ssh server.
|
||||||
--ssh-key-filename TEXT Private key filename (identify file) for the
|
--ssh-key-filename TEXT Private key filename (identify file) for the
|
||||||
ssh connection.
|
ssh connection.
|
||||||
|
|
||||||
--ssh-config-path TEXT Path to ssh configuration.
|
--ssh-config-path TEXT Path to ssh configuration.
|
||||||
--ssh-config-host TEXT Host for ssh server in ssh configurations (requires paramiko).
|
--ssh-config-host TEXT Host to connect to ssh server reading from ssh
|
||||||
|
configuration.
|
||||||
|
|
||||||
--ssl-ca PATH CA file in PEM format.
|
--ssl-ca PATH CA file in PEM format.
|
||||||
--ssl-capath TEXT CA directory.
|
--ssl-capath TEXT CA directory.
|
||||||
--ssl-cert PATH X509 cert in PEM format.
|
--ssl-cert PATH X509 cert in PEM format.
|
||||||
|
@ -73,33 +77,43 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
|
||||||
--ssl-verify-server-cert Verify server's "Common Name" in its cert
|
--ssl-verify-server-cert Verify server's "Common Name" in its cert
|
||||||
against hostname used when connecting. This
|
against hostname used when connecting. This
|
||||||
option is disabled by default.
|
option is disabled by default.
|
||||||
|
|
||||||
-V, --version Output mycli's version.
|
-V, --version Output mycli's version.
|
||||||
-v, --verbose Verbose output.
|
-v, --verbose Verbose output.
|
||||||
-D, --database TEXT Database to use.
|
-D, --database TEXT Database to use.
|
||||||
-d, --dsn TEXT Use DSN configured into the [alias_dsn]
|
-d, --dsn TEXT Use DSN configured into the [alias_dsn]
|
||||||
section of myclirc file.
|
section of myclirc file.
|
||||||
|
|
||||||
--list-dsn list of DSN configured into the [alias_dsn]
|
--list-dsn list of DSN configured into the [alias_dsn]
|
||||||
section of myclirc file.
|
section of myclirc file.
|
||||||
--list-ssh-config list ssh configurations in the ssh config (requires paramiko).
|
|
||||||
|
--list-ssh-config list ssh configurations in the ssh config
|
||||||
|
(requires paramiko).
|
||||||
|
|
||||||
-R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ").
|
-R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ").
|
||||||
-l, --logfile FILENAME Log every query and its results to a file.
|
-l, --logfile FILENAME Log every query and its results to a file.
|
||||||
--defaults-group-suffix TEXT Read MySQL config groups with the specified
|
--defaults-group-suffix TEXT Read MySQL config groups with the specified
|
||||||
suffix.
|
suffix.
|
||||||
|
|
||||||
--defaults-file PATH Only read MySQL options from the given file.
|
--defaults-file PATH Only read MySQL options from the given file.
|
||||||
--myclirc PATH Location of myclirc file.
|
--myclirc PATH Location of myclirc file.
|
||||||
--auto-vertical-output Automatically switch to vertical output mode
|
--auto-vertical-output Automatically switch to vertical output mode
|
||||||
if the result is wider than the terminal
|
if the result is wider than the terminal
|
||||||
width.
|
width.
|
||||||
|
|
||||||
-t, --table Display batch output in table format.
|
-t, --table Display batch output in table format.
|
||||||
--csv Display batch output in CSV format.
|
--csv Display batch output in CSV format.
|
||||||
--warn / --no-warn Warn before running a destructive query.
|
--warn / --no-warn Warn before running a destructive query.
|
||||||
--local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE.
|
--local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE.
|
||||||
--login-path TEXT Read this path from the login file.
|
-g, --login-path TEXT Read this path from the login file.
|
||||||
-e, --execute TEXT Execute command and quit.
|
-e, --execute TEXT Execute command and quit.
|
||||||
--init-command TEXT SQL statement to execute after connecting.
|
--init-command TEXT SQL statement to execute after connecting.
|
||||||
--charset TEXT Character set for MySQL session.
|
--charset TEXT Character set for MySQL session.
|
||||||
|
--password-file PATH File or FIFO path containing the password
|
||||||
|
to connect to the db if not specified otherwise
|
||||||
--help Show this message and exit.
|
--help Show this message and exit.
|
||||||
|
|
||||||
|
|
||||||
Features
|
Features
|
||||||
--------
|
--------
|
||||||
|
|
||||||
|
|
47
changelog.md
47
changelog.md
|
@ -1,19 +1,60 @@
|
||||||
|
TBD:
|
||||||
|
====
|
||||||
|
|
||||||
|
*
|
||||||
|
|
||||||
|
1.24.1:
|
||||||
|
=======
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
---------
|
||||||
|
* Restore dependency on cryptography for the interactive password prompt
|
||||||
|
|
||||||
|
|
||||||
|
1.24.0
|
||||||
|
======
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
----------
|
||||||
|
* Allow `FileNotFound` exception for SSH config files.
|
||||||
|
* Fix startup error on MySQL < 5.0.22
|
||||||
|
* Check error code rather than message for Access Denied error
|
||||||
|
* Fix login with ~/.my.cnf files
|
||||||
|
|
||||||
|
Features:
|
||||||
|
---------
|
||||||
|
* Add `-g` shortcut to option `--login-path`.
|
||||||
|
* Alt-Enter dispatches the command in multi-line mode.
|
||||||
|
* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html)
|
||||||
|
|
||||||
|
Internal:
|
||||||
|
---------
|
||||||
|
* Remove unused function is_open_quote()
|
||||||
|
* Use importlib, instead of file links, to locate resources
|
||||||
|
* Test various host-port combinations in command line arguments
|
||||||
|
* Switched from Cryptography to pyaes for decrypting mylogin.cnf
|
||||||
|
|
||||||
|
|
||||||
1.23.2
|
1.23.2
|
||||||
===
|
======
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
----------
|
----------
|
||||||
* Ensure `--port` is always an int.
|
* Ensure `--port` is always an int.
|
||||||
|
|
||||||
1.23.1
|
1.23.1
|
||||||
===
|
======
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
----------
|
----------
|
||||||
* Allow `--host` without `--port` to make a TCP connection.
|
* Allow `--host` without `--port` to make a TCP connection.
|
||||||
|
|
||||||
1.23.0
|
1.23.0
|
||||||
===
|
======
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
----------
|
||||||
|
* Fix config file include logic
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
---------
|
---------
|
||||||
|
|
|
@ -75,6 +75,8 @@ Contributors:
|
||||||
* Zach DeCook
|
* Zach DeCook
|
||||||
* kevinhwang91
|
* kevinhwang91
|
||||||
* KITAGAWA Yasutaka
|
* KITAGAWA Yasutaka
|
||||||
|
* Nicolas Palumbo
|
||||||
|
* Andy Teijelo Pérez
|
||||||
* bitkeen
|
* bitkeen
|
||||||
* Morgan Mitchell
|
* Morgan Mitchell
|
||||||
* Massimiliano Torromeo
|
* Massimiliano Torromeo
|
||||||
|
@ -82,6 +84,7 @@ Contributors:
|
||||||
* xeron
|
* xeron
|
||||||
* 0xflotus
|
* 0xflotus
|
||||||
* Seamile
|
* Seamile
|
||||||
|
* Jerome Provensal
|
||||||
|
|
||||||
Creator:
|
Creator:
|
||||||
--------
|
--------
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = '1.23.2'
|
__version__ = '1.24.1'
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from prompt_toolkit.enums import DEFAULT_BUFFER
|
from prompt_toolkit.enums import DEFAULT_BUFFER
|
||||||
from prompt_toolkit.filters import Condition
|
from prompt_toolkit.filters import Condition
|
||||||
from prompt_toolkit.application import get_app
|
from prompt_toolkit.application import get_app
|
||||||
from .packages.parseutils import is_open_quote
|
|
||||||
from .packages import special
|
from .packages import special
|
||||||
|
|
||||||
|
|
||||||
|
|
104
mycli/config.py
104
mycli/config.py
|
@ -1,5 +1,3 @@
|
||||||
import io
|
|
||||||
import shutil
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from io import BytesIO, TextIOWrapper
|
from io import BytesIO, TextIOWrapper
|
||||||
import logging
|
import logging
|
||||||
|
@ -7,11 +5,16 @@ import os
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from typing import Union
|
from typing import Union, IO
|
||||||
|
|
||||||
from configobj import ConfigObj, ConfigObjError
|
from configobj import ConfigObj, ConfigObjError
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
import pyaes
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
|
try:
|
||||||
|
import importlib.resources as resources
|
||||||
|
except ImportError:
|
||||||
|
# Python < 3.7
|
||||||
|
import importlib_resources as resources
|
||||||
|
|
||||||
try:
|
try:
|
||||||
basestring
|
basestring
|
||||||
|
@ -49,9 +52,9 @@ def read_config_file(f, list_values=True):
|
||||||
config = ConfigObj(f, interpolation=False, encoding='utf8',
|
config = ConfigObj(f, interpolation=False, encoding='utf8',
|
||||||
list_values=list_values)
|
list_values=list_values)
|
||||||
except ConfigObjError as e:
|
except ConfigObjError as e:
|
||||||
log(logger, logging.ERROR, "Unable to parse line {0} of config file "
|
log(logger, logging.WARNING, "Unable to parse line {0} of config file "
|
||||||
"'{1}'.".format(e.line_number, f))
|
"'{1}'.".format(e.line_number, f))
|
||||||
log(logger, logging.ERROR, "Using successfully parsed config values.")
|
log(logger, logging.WARNING, "Using successfully parsed config values.")
|
||||||
return e.config
|
return e.config
|
||||||
except (IOError, OSError) as e:
|
except (IOError, OSError) as e:
|
||||||
log(logger, logging.WARNING, "You don't have permission to read "
|
log(logger, logging.WARNING, "You don't have permission to read "
|
||||||
|
@ -61,7 +64,7 @@ def read_config_file(f, list_values=True):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
|
||||||
"""Get a list of configuration files that are included into config_path
|
"""Get a list of configuration files that are included into config_path
|
||||||
with !includedir directive.
|
with !includedir directive.
|
||||||
|
|
||||||
|
@ -95,7 +98,7 @@ def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
||||||
def read_config_files(files, list_values=True):
|
def read_config_files(files, list_values=True):
|
||||||
"""Read and merge a list of config files."""
|
"""Read and merge a list of config files."""
|
||||||
|
|
||||||
config = ConfigObj(list_values=list_values)
|
config = create_default_config(list_values=list_values)
|
||||||
_files = copy(files)
|
_files = copy(files)
|
||||||
while _files:
|
while _files:
|
||||||
_file = _files.pop(0)
|
_file = _files.pop(0)
|
||||||
|
@ -112,12 +115,21 @@ def read_config_files(files, list_values=True):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def write_default_config(source, destination, overwrite=False):
|
def create_default_config(list_values=True):
|
||||||
|
import mycli
|
||||||
|
default_config_file = resources.open_text(mycli, 'myclirc')
|
||||||
|
return read_config_file(default_config_file, list_values=list_values)
|
||||||
|
|
||||||
|
|
||||||
|
def write_default_config(destination, overwrite=False):
|
||||||
|
import mycli
|
||||||
|
default_config = resources.read_text(mycli, 'myclirc')
|
||||||
destination = os.path.expanduser(destination)
|
destination = os.path.expanduser(destination)
|
||||||
if not overwrite and exists(destination):
|
if not overwrite and exists(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
shutil.copyfile(source, destination)
|
with open(destination, 'w') as f:
|
||||||
|
f.write(default_config)
|
||||||
|
|
||||||
|
|
||||||
def get_mylogin_cnf_path():
|
def get_mylogin_cnf_path():
|
||||||
|
@ -160,6 +172,58 @@ def open_mylogin_cnf(name):
|
||||||
return TextIOWrapper(plaintext)
|
return TextIOWrapper(plaintext)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO reuse code between encryption an decryption
|
||||||
|
def encrypt_mylogin_cnf(plaintext: IO[str]):
|
||||||
|
"""Encryption of .mylogin.cnf file, analogous to calling
|
||||||
|
mysql_config_editor.
|
||||||
|
|
||||||
|
Code is based on the python implementation by Kristian Koehntopp
|
||||||
|
https://github.com/isotopp/mysql-config-coder
|
||||||
|
|
||||||
|
"""
|
||||||
|
def realkey(key):
|
||||||
|
"""Create the AES key from the login key."""
|
||||||
|
rkey = bytearray(16)
|
||||||
|
for i in range(len(key)):
|
||||||
|
rkey[i % 16] ^= key[i]
|
||||||
|
return bytes(rkey)
|
||||||
|
|
||||||
|
def encode_line(plaintext, real_key, buf_len):
|
||||||
|
aes = pyaes.AESModeOfOperationECB(real_key)
|
||||||
|
text_len = len(plaintext)
|
||||||
|
pad_len = buf_len - text_len
|
||||||
|
pad_chr = bytes(chr(pad_len), "utf8")
|
||||||
|
plaintext = plaintext.encode() + pad_chr * pad_len
|
||||||
|
encrypted_text = b''.join(
|
||||||
|
[aes.encrypt(plaintext[i: i + 16])
|
||||||
|
for i in range(0, len(plaintext), 16)]
|
||||||
|
)
|
||||||
|
return encrypted_text
|
||||||
|
|
||||||
|
LOGIN_KEY_LENGTH = 20
|
||||||
|
key = os.urandom(LOGIN_KEY_LENGTH)
|
||||||
|
real_key = realkey(key)
|
||||||
|
|
||||||
|
outfile = BytesIO()
|
||||||
|
|
||||||
|
outfile.write(struct.pack("i", 0))
|
||||||
|
outfile.write(key)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
line = plaintext.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
real_len = len(line)
|
||||||
|
pad_len = (int(real_len / 16) + 1) * 16
|
||||||
|
|
||||||
|
outfile.write(struct.pack("i", pad_len))
|
||||||
|
x = encode_line(line, real_key, pad_len)
|
||||||
|
outfile.write(x)
|
||||||
|
|
||||||
|
outfile.seek(0)
|
||||||
|
return outfile
|
||||||
|
|
||||||
|
|
||||||
def read_and_decrypt_mylogin_cnf(f):
|
def read_and_decrypt_mylogin_cnf(f):
|
||||||
"""Read and decrypt the contents of .mylogin.cnf.
|
"""Read and decrypt the contents of .mylogin.cnf.
|
||||||
|
|
||||||
|
@ -201,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f):
|
||||||
return None
|
return None
|
||||||
rkey = struct.pack('16B', *rkey)
|
rkey = struct.pack('16B', *rkey)
|
||||||
|
|
||||||
# Create a decryptor object using the key.
|
|
||||||
decryptor = _get_decryptor(rkey)
|
|
||||||
|
|
||||||
# Create a bytes buffer to hold the plaintext.
|
# Create a bytes buffer to hold the plaintext.
|
||||||
plaintext = BytesIO()
|
plaintext = BytesIO()
|
||||||
|
aes = pyaes.AESModeOfOperationECB(rkey)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# Read the length of the ciphertext.
|
# Read the length of the ciphertext.
|
||||||
|
@ -216,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f):
|
||||||
|
|
||||||
# Read cipher_len bytes from the file and decrypt.
|
# Read cipher_len bytes from the file and decrypt.
|
||||||
cipher = f.read(cipher_len)
|
cipher = f.read(cipher_len)
|
||||||
plain = _remove_pad(decryptor.update(cipher))
|
plain = _remove_pad(
|
||||||
|
b''.join([aes.decrypt(cipher[i: i + 16])
|
||||||
|
for i in range(0, cipher_len, 16)])
|
||||||
|
)
|
||||||
if plain is False:
|
if plain is False:
|
||||||
continue
|
continue
|
||||||
plaintext.write(plain)
|
plaintext.write(plain)
|
||||||
|
@ -244,7 +309,7 @@ def str_to_bool(s):
|
||||||
elif s.lower() in false_values:
|
elif s.lower() in false_values:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
raise ValueError('not a recognized boolean value: %s'.format(s))
|
raise ValueError('not a recognized boolean value: {0}'.format(s))
|
||||||
|
|
||||||
|
|
||||||
def strip_matching_quotes(s):
|
def strip_matching_quotes(s):
|
||||||
|
@ -260,15 +325,8 @@ def strip_matching_quotes(s):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def _get_decryptor(key):
|
|
||||||
"""Get the AES decryptor."""
|
|
||||||
c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
|
|
||||||
return c.decryptor()
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_pad(line):
|
def _remove_pad(line):
|
||||||
"""Remove the pad from the *line*."""
|
"""Remove the pad from the *line*."""
|
||||||
pad_length = ord(line[-1:])
|
|
||||||
try:
|
try:
|
||||||
# Determine pad length.
|
# Determine pad length.
|
||||||
pad_length = ord(line[-1:])
|
pad_length = ord(line[-1:])
|
||||||
|
|
|
@ -78,8 +78,12 @@ def mycli_bindings(mycli):
|
||||||
|
|
||||||
@kb.add('escape', 'enter')
|
@kb.add('escape', 'enter')
|
||||||
def _(event):
|
def _(event):
|
||||||
"""Introduces a line break regardless of multi-line mode or not."""
|
"""Introduces a line break in multi-line mode, or dispatches the
|
||||||
|
command in single-line mode."""
|
||||||
_logger.debug('Detected alt-enter key.')
|
_logger.debug('Detected alt-enter key.')
|
||||||
event.app.current_buffer.insert_text('\n')
|
if mycli.multi_line:
|
||||||
|
event.app.current_buffer.validate_and_handle()
|
||||||
|
else:
|
||||||
|
event.app.current_buffer.insert_text('\n')
|
||||||
|
|
||||||
return kb
|
return kb
|
||||||
|
|
140
mycli/main.py
140
mycli/main.py
|
@ -1,9 +1,12 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from io import open
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import re
|
import re
|
||||||
|
import stat
|
||||||
import fileinput
|
import fileinput
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
try:
|
try:
|
||||||
|
@ -13,7 +16,6 @@ except ImportError:
|
||||||
from time import time
|
from time import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from random import choice
|
from random import choice
|
||||||
from io import open
|
|
||||||
|
|
||||||
from pymysql import OperationalError
|
from pymysql import OperationalError
|
||||||
from cli_helpers.tabular_output import TabularOutputFormatter
|
from cli_helpers.tabular_output import TabularOutputFormatter
|
||||||
|
@ -43,7 +45,7 @@ from .packages.special.favoritequeries import FavoriteQueries
|
||||||
from .sqlcompleter import SQLCompleter
|
from .sqlcompleter import SQLCompleter
|
||||||
from .clitoolbar import create_toolbar_tokens_func
|
from .clitoolbar import create_toolbar_tokens_func
|
||||||
from .clistyle import style_factory, style_factory_output
|
from .clistyle import style_factory, style_factory_output
|
||||||
from .sqlexecute import FIELD_TYPES, SQLExecute
|
from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
|
||||||
from .clibuffer import cli_is_multiline
|
from .clibuffer import cli_is_multiline
|
||||||
from .completion_refresher import CompletionRefresher
|
from .completion_refresher import CompletionRefresher
|
||||||
from .config import (write_default_config, get_mylogin_cnf_path,
|
from .config import (write_default_config, get_mylogin_cnf_path,
|
||||||
|
@ -51,7 +53,7 @@ from .config import (write_default_config, get_mylogin_cnf_path,
|
||||||
strip_matching_quotes)
|
strip_matching_quotes)
|
||||||
from .key_bindings import mycli_bindings
|
from .key_bindings import mycli_bindings
|
||||||
from .lexer import MyCliLexer
|
from .lexer import MyCliLexer
|
||||||
from .__init__ import __version__
|
from . import __version__
|
||||||
from .compat import WIN
|
from .compat import WIN
|
||||||
from .packages.filepaths import dir_path_exists, guess_socket_location
|
from .packages.filepaths import dir_path_exists, guess_socket_location
|
||||||
|
|
||||||
|
@ -66,6 +68,11 @@ except ImportError:
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
|
try:
|
||||||
|
import importlib.resources as resources
|
||||||
|
except ImportError:
|
||||||
|
# Python < 3.7
|
||||||
|
import importlib_resources as resources
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import paramiko
|
import paramiko
|
||||||
|
@ -75,7 +82,10 @@ except ImportError:
|
||||||
# Query tuples are used for maintaining history
|
# Query tuples are used for maintaining history
|
||||||
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
|
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
|
||||||
|
|
||||||
PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
|
SUPPORT_INFO = (
|
||||||
|
'Home: http://mycli.net\n'
|
||||||
|
'Bug tracker: https://github.com/dbcli/mycli/issues'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MyCli(object):
|
class MyCli(object):
|
||||||
|
@ -89,7 +99,7 @@ class MyCli(object):
|
||||||
'/etc/my.cnf',
|
'/etc/my.cnf',
|
||||||
'/etc/mysql/my.cnf',
|
'/etc/mysql/my.cnf',
|
||||||
'/usr/local/etc/my.cnf',
|
'/usr/local/etc/my.cnf',
|
||||||
'~/.my.cnf'
|
os.path.expanduser('~/.my.cnf'),
|
||||||
]
|
]
|
||||||
|
|
||||||
# check XDG_CONFIG_HOME exists and not an empty string
|
# check XDG_CONFIG_HOME exists and not an empty string
|
||||||
|
@ -102,7 +112,6 @@ class MyCli(object):
|
||||||
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
|
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
|
||||||
]
|
]
|
||||||
|
|
||||||
default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
|
|
||||||
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
|
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
|
||||||
|
|
||||||
def __init__(self, sqlexecute=None, prompt=None,
|
def __init__(self, sqlexecute=None, prompt=None,
|
||||||
|
@ -122,7 +131,7 @@ class MyCli(object):
|
||||||
self.cnf_files = [defaults_file]
|
self.cnf_files = [defaults_file]
|
||||||
|
|
||||||
# Load config.
|
# Load config.
|
||||||
config_files = ([self.default_config_file] + self.system_config_files +
|
config_files = (self.system_config_files +
|
||||||
[myclirc] + [self.pwd_config_file])
|
[myclirc] + [self.pwd_config_file])
|
||||||
c = self.config = read_config_files(config_files)
|
c = self.config = read_config_files(config_files)
|
||||||
self.multi_line = c['main'].as_bool('multi_line')
|
self.multi_line = c['main'].as_bool('multi_line')
|
||||||
|
@ -154,7 +163,7 @@ class MyCli(object):
|
||||||
|
|
||||||
# Write user config if system config wasn't the last config loaded.
|
# Write user config if system config wasn't the last config loaded.
|
||||||
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
|
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
|
||||||
write_default_config(self.default_config_file, myclirc)
|
write_default_config(myclirc)
|
||||||
|
|
||||||
# audit log
|
# audit log
|
||||||
if self.logfile is None and 'audit_log' in c['main']:
|
if self.logfile is None and 'audit_log' in c['main']:
|
||||||
|
@ -326,20 +335,33 @@ class MyCli(object):
|
||||||
cnf = read_config_files(files, list_values=False)
|
cnf = read_config_files(files, list_values=False)
|
||||||
|
|
||||||
sections = ['client', 'mysqld']
|
sections = ['client', 'mysqld']
|
||||||
|
key_transformations = {
|
||||||
|
'mysqld': {
|
||||||
|
'socket': 'default_socket',
|
||||||
|
'port': 'default_port',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if self.login_path and self.login_path != 'client':
|
if self.login_path and self.login_path != 'client':
|
||||||
sections.append(self.login_path)
|
sections.append(self.login_path)
|
||||||
|
|
||||||
if self.defaults_suffix:
|
if self.defaults_suffix:
|
||||||
sections.extend([sect + self.defaults_suffix for sect in sections])
|
sections.extend([sect + self.defaults_suffix for sect in sections])
|
||||||
|
|
||||||
def get(key):
|
configuration = defaultdict(lambda: None)
|
||||||
result = None
|
for key in keys:
|
||||||
for sect in cnf:
|
for section in cnf:
|
||||||
if sect in sections and key in cnf[sect]:
|
if (
|
||||||
result = strip_matching_quotes(cnf[sect][key])
|
section not in sections or
|
||||||
return result
|
key not in cnf[section]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
new_key = key_transformations.get(section, {}).get(key) or key
|
||||||
|
configuration[new_key] = strip_matching_quotes(
|
||||||
|
cnf[section][key])
|
||||||
|
|
||||||
|
return configuration
|
||||||
|
|
||||||
return {x: get(x) for x in keys}
|
|
||||||
|
|
||||||
def merge_ssl_with_cnf(self, ssl, cnf):
|
def merge_ssl_with_cnf(self, ssl, cnf):
|
||||||
"""Merge SSL configuration dict with cnf dict"""
|
"""Merge SSL configuration dict with cnf dict"""
|
||||||
|
@ -367,7 +389,7 @@ class MyCli(object):
|
||||||
def connect(self, database='', user='', passwd='', host='', port='',
|
def connect(self, database='', user='', passwd='', host='', port='',
|
||||||
socket='', charset='', local_infile='', ssl='',
|
socket='', charset='', local_infile='', ssl='',
|
||||||
ssh_user='', ssh_host='', ssh_port='',
|
ssh_user='', ssh_host='', ssh_port='',
|
||||||
ssh_password='', ssh_key_filename='', init_command=''):
|
ssh_password='', ssh_key_filename='', init_command='', password_file=''):
|
||||||
|
|
||||||
cnf = {'database': None,
|
cnf = {'database': None,
|
||||||
'user': None,
|
'user': None,
|
||||||
|
@ -375,6 +397,7 @@ class MyCli(object):
|
||||||
'host': None,
|
'host': None,
|
||||||
'port': None,
|
'port': None,
|
||||||
'socket': None,
|
'socket': None,
|
||||||
|
'default_socket': None,
|
||||||
'default-character-set': None,
|
'default-character-set': None,
|
||||||
'local-infile': None,
|
'local-infile': None,
|
||||||
'loose-local-infile': None,
|
'loose-local-infile': None,
|
||||||
|
@ -388,18 +411,23 @@ class MyCli(object):
|
||||||
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
|
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
|
||||||
|
|
||||||
# Fall back to config values only if user did not specify a value.
|
# Fall back to config values only if user did not specify a value.
|
||||||
|
|
||||||
database = database or cnf['database']
|
database = database or cnf['database']
|
||||||
# Socket interface not supported for SSH connections
|
|
||||||
if port or (host and host != 'localhost') or (ssh_host and ssh_port):
|
|
||||||
socket = ''
|
|
||||||
else:
|
|
||||||
socket = socket or cnf['socket'] or guess_socket_location()
|
|
||||||
user = user or cnf['user'] or os.getenv('USER')
|
user = user or cnf['user'] or os.getenv('USER')
|
||||||
host = host or cnf['host']
|
host = host or cnf['host']
|
||||||
port = int(port or cnf['port'] or 3306)
|
port = port or cnf['port']
|
||||||
ssl = ssl or {}
|
ssl = ssl or {}
|
||||||
|
|
||||||
|
port = port and int(port)
|
||||||
|
if not port:
|
||||||
|
port = 3306
|
||||||
|
if not host or host == 'localhost':
|
||||||
|
socket = (
|
||||||
|
cnf['socket'] or
|
||||||
|
cnf['default_socket'] or
|
||||||
|
guess_socket_location()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
passwd = passwd if isinstance(passwd, str) else cnf['password']
|
passwd = passwd if isinstance(passwd, str) else cnf['password']
|
||||||
charset = charset or cnf['default-character-set'] or 'utf8'
|
charset = charset or cnf['default-character-set'] or 'utf8'
|
||||||
|
|
||||||
|
@ -417,6 +445,10 @@ class MyCli(object):
|
||||||
if not any(v for v in ssl.values()):
|
if not any(v for v in ssl.values()):
|
||||||
ssl = None
|
ssl = None
|
||||||
|
|
||||||
|
# if the passwd is not specfied try to set it using the password_file option
|
||||||
|
password_from_file = self.get_password_from_file(password_file)
|
||||||
|
passwd = passwd or password_from_file
|
||||||
|
|
||||||
# Connect to the database.
|
# Connect to the database.
|
||||||
|
|
||||||
def _connect():
|
def _connect():
|
||||||
|
@ -427,9 +459,12 @@ class MyCli(object):
|
||||||
ssh_password, ssh_key_filename, init_command
|
ssh_password, ssh_key_filename, init_command
|
||||||
)
|
)
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
if ('Access denied for user' in e.args[1]):
|
if e.args[0] == ERROR_CODE_ACCESS_DENIED:
|
||||||
new_passwd = click.prompt('Password', hide_input=True,
|
if password_from_file:
|
||||||
show_default=False, type=str, err=True)
|
new_passwd = password_from_file
|
||||||
|
else:
|
||||||
|
new_passwd = click.prompt('Password', hide_input=True,
|
||||||
|
show_default=False, type=str, err=True)
|
||||||
self.sqlexecute = SQLExecute(
|
self.sqlexecute = SQLExecute(
|
||||||
database, user, new_passwd, host, port, socket,
|
database, user, new_passwd, host, port, socket,
|
||||||
charset, local_infile, ssl, ssh_user, ssh_host,
|
charset, local_infile, ssl, ssh_user, ssh_host,
|
||||||
|
@ -484,6 +519,17 @@ class MyCli(object):
|
||||||
self.echo(str(e), err=True, fg='red')
|
self.echo(str(e), err=True, fg='red')
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
def get_password_from_file(self, password_file):
|
||||||
|
password_from_file = None
|
||||||
|
if password_file:
|
||||||
|
if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
|
||||||
|
and os.access(password_file, os.R_OK):
|
||||||
|
with open(password_file) as fp:
|
||||||
|
password_from_file = fp.readline()
|
||||||
|
password_from_file = password_from_file.rstrip().lstrip()
|
||||||
|
|
||||||
|
return password_from_file
|
||||||
|
|
||||||
def handle_editor_command(self, text):
|
def handle_editor_command(self, text):
|
||||||
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
|
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
|
||||||
The reason for a while loop is because a user might edit a query
|
The reason for a while loop is because a user might edit a query
|
||||||
|
@ -542,9 +588,6 @@ class MyCli(object):
|
||||||
if self.smart_completion:
|
if self.smart_completion:
|
||||||
self.refresh_completions()
|
self.refresh_completions()
|
||||||
|
|
||||||
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
|
|
||||||
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
|
|
||||||
|
|
||||||
history_file = os.path.expanduser(
|
history_file = os.path.expanduser(
|
||||||
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
|
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
|
||||||
if dir_path_exists(history_file):
|
if dir_path_exists(history_file):
|
||||||
|
@ -559,12 +602,10 @@ class MyCli(object):
|
||||||
key_bindings = mycli_bindings(self)
|
key_bindings = mycli_bindings(self)
|
||||||
|
|
||||||
if not self.less_chatty:
|
if not self.less_chatty:
|
||||||
print(' '.join(sqlexecute.server_type()))
|
print(sqlexecute.server_info)
|
||||||
print('mycli', __version__)
|
print('mycli', __version__)
|
||||||
print('Chat: https://gitter.im/dbcli/mycli')
|
print(SUPPORT_INFO)
|
||||||
print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
|
print('Thanks to the contributor -', thanks_picker())
|
||||||
print('Home: http://mycli.net')
|
|
||||||
print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
|
|
||||||
|
|
||||||
def get_message():
|
def get_message():
|
||||||
prompt = self.get_prompt(self.prompt_format)
|
prompt = self.get_prompt(self.prompt_format)
|
||||||
|
@ -862,8 +903,8 @@ class MyCli(object):
|
||||||
|
|
||||||
if not output_via_pager:
|
if not output_via_pager:
|
||||||
# doesn't fit, flush buffer
|
# doesn't fit, flush buffer
|
||||||
for line in buf:
|
for buf_line in buf:
|
||||||
click.secho(line)
|
click.secho(buf_line)
|
||||||
buf = []
|
buf = []
|
||||||
else:
|
else:
|
||||||
click.secho(line)
|
click.secho(line)
|
||||||
|
@ -933,7 +974,7 @@ class MyCli(object):
|
||||||
string = string.replace('\\u', sqlexecute.user or '(none)')
|
string = string.replace('\\u', sqlexecute.user or '(none)')
|
||||||
string = string.replace('\\h', host or '(none)')
|
string = string.replace('\\h', host or '(none)')
|
||||||
string = string.replace('\\d', sqlexecute.dbname or '(none)')
|
string = string.replace('\\d', sqlexecute.dbname or '(none)')
|
||||||
string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
|
string = string.replace('\\t', sqlexecute.server_info.species.name)
|
||||||
string = string.replace('\\n', "\n")
|
string = string.replace('\\n', "\n")
|
||||||
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
|
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
|
||||||
string = string.replace('\\m', now.strftime('%M'))
|
string = string.replace('\\m', now.strftime('%M'))
|
||||||
|
@ -1083,7 +1124,7 @@ class MyCli(object):
|
||||||
help='Warn before running a destructive query.')
|
help='Warn before running a destructive query.')
|
||||||
@click.option('--local-infile', type=bool,
|
@click.option('--local-infile', type=bool,
|
||||||
help='Enable/disable LOAD DATA LOCAL INFILE.')
|
help='Enable/disable LOAD DATA LOCAL INFILE.')
|
||||||
@click.option('--login-path', type=str,
|
@click.option('-g', '--login-path', type=str,
|
||||||
help='Read this path from the login file.')
|
help='Read this path from the login file.')
|
||||||
@click.option('-e', '--execute', type=str,
|
@click.option('-e', '--execute', type=str,
|
||||||
help='Execute command and quit.')
|
help='Execute command and quit.')
|
||||||
|
@ -1091,6 +1132,8 @@ class MyCli(object):
|
||||||
help='SQL statement to execute after connecting.')
|
help='SQL statement to execute after connecting.')
|
||||||
@click.option('--charset', type=str,
|
@click.option('--charset', type=str,
|
||||||
help='Character set for MySQL session.')
|
help='Character set for MySQL session.')
|
||||||
|
@click.option('--password-file', type=click.Path(),
|
||||||
|
help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
|
||||||
@click.argument('database', default='', nargs=1)
|
@click.argument('database', default='', nargs=1)
|
||||||
def cli(database, user, host, port, socket, password, dbname,
|
def cli(database, user, host, port, socket, password, dbname,
|
||||||
version, verbose, prompt, logfile, defaults_group_suffix,
|
version, verbose, prompt, logfile, defaults_group_suffix,
|
||||||
|
@ -1099,7 +1142,7 @@ def cli(database, user, host, port, socket, password, dbname,
|
||||||
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
|
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
|
||||||
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
|
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
|
||||||
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
|
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
|
||||||
init_command, charset):
|
init_command, charset, password_file):
|
||||||
"""A MySQL terminal client with auto-completion and syntax highlighting.
|
"""A MySQL terminal client with auto-completion and syntax highlighting.
|
||||||
|
|
||||||
\b
|
\b
|
||||||
|
@ -1225,7 +1268,8 @@ def cli(database, user, host, port, socket, password, dbname,
|
||||||
ssh_password=ssh_password,
|
ssh_password=ssh_password,
|
||||||
ssh_key_filename=ssh_key_filename,
|
ssh_key_filename=ssh_key_filename,
|
||||||
init_command=init_command,
|
init_command=init_command,
|
||||||
charset=charset
|
charset=charset,
|
||||||
|
password_file=password_file
|
||||||
)
|
)
|
||||||
|
|
||||||
mycli.logger.debug('Launch Params: \n'
|
mycli.logger.debug('Launch Params: \n'
|
||||||
|
@ -1328,9 +1372,15 @@ def is_select(status):
|
||||||
return status.split(None, 1)[0].lower() == 'select'
|
return status.split(None, 1)[0].lower() == 'select'
|
||||||
|
|
||||||
|
|
||||||
def thanks_picker(files=()):
|
def thanks_picker():
|
||||||
|
import mycli
|
||||||
|
lines = (
|
||||||
|
resources.read_text(mycli, 'AUTHORS') +
|
||||||
|
resources.read_text(mycli, 'SPONSORS')
|
||||||
|
).split('\n')
|
||||||
|
|
||||||
contents = []
|
contents = []
|
||||||
for line in fileinput.input(files=files):
|
for line in lines:
|
||||||
m = re.match(r'^ *\* (.*)', line)
|
m = re.match(r'^ *\* (.*)', line)
|
||||||
if m:
|
if m:
|
||||||
contents.append(m.group(1))
|
contents.append(m.group(1))
|
||||||
|
@ -1350,6 +1400,9 @@ def read_ssh_config(ssh_config_path):
|
||||||
try:
|
try:
|
||||||
with open(ssh_config_path) as f:
|
with open(ssh_config_path) as f:
|
||||||
ssh_config.parse(f)
|
ssh_config.parse(f)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
click.secho(str(e), err=True, fg='red')
|
||||||
|
sys.exit(1)
|
||||||
# Paramiko prior to version 2.7 raises Exception on parse errors.
|
# Paramiko prior to version 2.7 raises Exception on parse errors.
|
||||||
# In 2.7 it has become paramiko.ssh_exception.SSHException,
|
# In 2.7 it has become paramiko.ssh_exception.SSHException,
|
||||||
# but let's catch everything for compatibility
|
# but let's catch everything for compatibility
|
||||||
|
@ -1359,9 +1412,6 @@ def read_ssh_config(ssh_config_path):
|
||||||
err=True, fg='red'
|
err=True, fg='red'
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except FileNotFoundError as e:
|
|
||||||
click.secho(str(e), err=True, fg='red')
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
else:
|
||||||
return ssh_config
|
return ssh_config
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from sqlparse.sql import Comparison, Identifier, Where
|
from sqlparse.sql import Comparison, Identifier, Where
|
||||||
from .parseutils import last_word, extract_tables, find_prev_keyword
|
from .parseutils import last_word, extract_tables, find_prev_keyword
|
||||||
|
|
|
@ -12,7 +12,8 @@ cleanup_regex = {
|
||||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||||
# This matches everything except a space.
|
# This matches everything except a space.
|
||||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def last_word(text, include='alphanum_underscore'):
|
def last_word(text, include='alphanum_underscore'):
|
||||||
r"""
|
r"""
|
||||||
|
@ -226,14 +227,6 @@ def is_destructive(queries):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_open_quote(sql):
|
|
||||||
"""Returns true if the query contains an unclosed quote."""
|
|
||||||
|
|
||||||
# parsed can contain one or more semi-colon separated commands
|
|
||||||
parsed = sqlparse.parse(sql)
|
|
||||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
sql = 'select * from (select t. from tabl t'
|
sql = 'select * from (select t. from tabl t'
|
||||||
print (extract_tables(sql))
|
print (extract_tables(sql))
|
||||||
|
@ -263,5 +256,4 @@ def is_dropping_database(queries, dbname):
|
||||||
)
|
)
|
||||||
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
|
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
|
||||||
result = keywords[0].normalized == "DROP"
|
result = keywords[0].normalized == "DROP"
|
||||||
else:
|
return result
|
||||||
return result
|
|
||||||
|
|
|
@ -302,7 +302,7 @@ def execute_system_command(arg, **_):
|
||||||
usage = "Syntax: system [command].\n"
|
usage = "Syntax: system [command].\n"
|
||||||
|
|
||||||
if not arg:
|
if not arg:
|
||||||
return [(None, None, None, usage)]
|
return [(None, None, None, usage)]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
command = arg.strip()
|
command = arg.strip()
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""Format adapter for sql."""
|
"""Format adapter for sql."""
|
||||||
|
|
||||||
from cli_helpers.utils import filter_dict_by_key
|
|
||||||
from mycli.packages.parseutils import extract_tables
|
from mycli.packages.parseutils import extract_tables
|
||||||
|
|
||||||
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
|
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
|
||||||
|
|
|
@ -72,7 +72,7 @@ class SQLCompleter(Completer):
|
||||||
if name and ((not self.name_pattern.match(name))
|
if name and ((not self.name_pattern.match(name))
|
||||||
or (name.upper() in self.reserved_words)
|
or (name.upper() in self.reserved_words)
|
||||||
or (name.upper() in self.functions)):
|
or (name.upper() in self.functions)):
|
||||||
name = '`%s`' % name
|
name = '`%s`' % name
|
||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
import pymysql
|
import pymysql
|
||||||
import sqlparse
|
|
||||||
from .packages import special
|
from .packages import special
|
||||||
from pymysql.constants import FIELD_TYPE
|
from pymysql.constants import FIELD_TYPE
|
||||||
from pymysql.converters import (convert_datetime,
|
from pymysql.converters import (convert_datetime,
|
||||||
|
@ -18,17 +20,71 @@ FIELD_TYPES.update({
|
||||||
FIELD_TYPE.NULL: type(None)
|
FIELD_TYPE.NULL: type(None)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
ERROR_CODE_ACCESS_DENIED = 1045
|
||||||
|
|
||||||
|
|
||||||
|
class ServerSpecies(enum.Enum):
|
||||||
|
MySQL = 'MySQL'
|
||||||
|
MariaDB = 'MariaDB'
|
||||||
|
Percona = 'Percona'
|
||||||
|
Unknown = 'MySQL'
|
||||||
|
|
||||||
|
|
||||||
|
class ServerInfo:
|
||||||
|
def __init__(self, species, version_str):
|
||||||
|
self.species = species
|
||||||
|
self.version_str = version_str
|
||||||
|
self.version = self.calc_mysql_version_value(version_str)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calc_mysql_version_value(version_str) -> int:
|
||||||
|
if not version_str or not isinstance(version_str, str):
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
major, minor, patch = version_str.split('.')
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return int(major) * 10_000 + int(minor) * 100 + int(patch)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_version_string(cls, version_string):
|
||||||
|
if not version_string:
|
||||||
|
return cls(ServerSpecies.Unknown, '')
|
||||||
|
|
||||||
|
re_species = (
|
||||||
|
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
|
||||||
|
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
|
||||||
|
ServerSpecies.Percona),
|
||||||
|
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
|
||||||
|
ServerSpecies.MySQL),
|
||||||
|
)
|
||||||
|
for regexp, species in re_species:
|
||||||
|
match = re.search(regexp, version_string)
|
||||||
|
if match is not None:
|
||||||
|
parsed_version = match.group('version')
|
||||||
|
detected_species = species
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
detected_species = ServerSpecies.Unknown
|
||||||
|
parsed_version = ''
|
||||||
|
|
||||||
|
return cls(detected_species, parsed_version)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self.species:
|
||||||
|
return f'{self.species.value} {self.version_str}'
|
||||||
|
else:
|
||||||
|
return self.version_str
|
||||||
|
|
||||||
|
|
||||||
class SQLExecute(object):
|
class SQLExecute(object):
|
||||||
|
|
||||||
databases_query = '''SHOW DATABASES'''
|
databases_query = '''SHOW DATABASES'''
|
||||||
|
|
||||||
tables_query = '''SHOW TABLES'''
|
tables_query = '''SHOW TABLES'''
|
||||||
|
|
||||||
version_query = '''SELECT @@VERSION'''
|
|
||||||
|
|
||||||
version_comment_query = '''SELECT @@VERSION_COMMENT'''
|
|
||||||
version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
|
|
||||||
|
|
||||||
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
|
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
|
||||||
|
|
||||||
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
|
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
|
||||||
|
@ -52,7 +108,7 @@ class SQLExecute(object):
|
||||||
self.charset = charset
|
self.charset = charset
|
||||||
self.local_infile = local_infile
|
self.local_infile = local_infile
|
||||||
self.ssl = ssl
|
self.ssl = ssl
|
||||||
self._server_type = None
|
self.server_info = None
|
||||||
self.connection_id = None
|
self.connection_id = None
|
||||||
self.ssh_user = ssh_user
|
self.ssh_user = ssh_user
|
||||||
self.ssh_host = ssh_host
|
self.ssh_host = ssh_host
|
||||||
|
@ -157,6 +213,7 @@ class SQLExecute(object):
|
||||||
self.init_command = init_command
|
self.init_command = init_command
|
||||||
# retrieve connection id
|
# retrieve connection id
|
||||||
self.reset_connection_id()
|
self.reset_connection_id()
|
||||||
|
self.server_info = ServerInfo.from_version_string(conn.server_version)
|
||||||
|
|
||||||
def run(self, statement):
|
def run(self, statement):
|
||||||
"""Execute the sql in the database and return the results. The results
|
"""Execute the sql in the database and return the results. The results
|
||||||
|
@ -273,37 +330,6 @@ class SQLExecute(object):
|
||||||
for row in cur:
|
for row in cur:
|
||||||
yield row
|
yield row
|
||||||
|
|
||||||
def server_type(self):
|
|
||||||
if self._server_type:
|
|
||||||
return self._server_type
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
_logger.debug('Version Query. sql: %r', self.version_query)
|
|
||||||
cur.execute(self.version_query)
|
|
||||||
version = cur.fetchone()[0]
|
|
||||||
if version[0] == '4':
|
|
||||||
_logger.debug('Version Comment. sql: %r',
|
|
||||||
self.version_comment_query_mysql4)
|
|
||||||
cur.execute(self.version_comment_query_mysql4)
|
|
||||||
version_comment = cur.fetchone()[1].lower()
|
|
||||||
if isinstance(version_comment, bytes):
|
|
||||||
# with python3 this query returns bytes
|
|
||||||
version_comment = version_comment.decode('utf-8')
|
|
||||||
else:
|
|
||||||
_logger.debug('Version Comment. sql: %r',
|
|
||||||
self.version_comment_query)
|
|
||||||
cur.execute(self.version_comment_query)
|
|
||||||
version_comment = cur.fetchone()[0].lower()
|
|
||||||
|
|
||||||
if 'mariadb' in version_comment:
|
|
||||||
product_type = 'mariadb'
|
|
||||||
elif 'percona' in version_comment:
|
|
||||||
product_type = 'percona'
|
|
||||||
else:
|
|
||||||
product_type = 'mysql'
|
|
||||||
|
|
||||||
self._server_type = (product_type, version)
|
|
||||||
return self._server_type
|
|
||||||
|
|
||||||
def get_connection_id(self):
|
def get_connection_id(self):
|
||||||
if not self.connection_id:
|
if not self.connection_id:
|
||||||
self.reset_connection_id()
|
self.reset_connection_id()
|
||||||
|
|
2
pytest.ini
Normal file
2
pytest.ini
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
[pytest]
|
||||||
|
addopts = --ignore=mycli/packages/paramiko_stub/__init__.py
|
|
@ -1,6 +1,5 @@
|
||||||
"""A script to publish a release of mycli to PyPI."""
|
"""A script to publish a release of mycli to PyPI."""
|
||||||
|
|
||||||
import io
|
|
||||||
from optparse import OptionParser
|
from optparse import OptionParser
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
8
setup.py
8
setup.py
|
@ -18,16 +18,20 @@ description = 'CLI for MySQL Database. With auto-completion and syntax highlight
|
||||||
|
|
||||||
install_requirements = [
|
install_requirements = [
|
||||||
'click >= 7.0',
|
'click >= 7.0',
|
||||||
|
'cryptography >= 1.0.0',
|
||||||
'Pygments >= 1.6',
|
'Pygments >= 1.6',
|
||||||
'prompt_toolkit>=3.0.6,<4.0.0',
|
'prompt_toolkit>=3.0.6,<4.0.0',
|
||||||
'PyMySQL >= 0.9.2',
|
'PyMySQL >= 0.9.2',
|
||||||
'sqlparse>=0.3.0,<0.4.0',
|
'sqlparse>=0.3.0,<0.4.0',
|
||||||
'configobj >= 5.0.5',
|
'configobj >= 5.0.5',
|
||||||
'cryptography >= 1.0.0',
|
|
||||||
'cli_helpers[styles] >= 2.0.1',
|
'cli_helpers[styles] >= 2.0.1',
|
||||||
'pyperclip >= 1.8.1'
|
'pyperclip >= 1.8.1',
|
||||||
|
'pyaes >= 1.6.1'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if sys.version_info.minor < 9:
|
||||||
|
install_requirements.append('importlib_resources >= 5.0.0')
|
||||||
|
|
||||||
|
|
||||||
class lint(Command):
|
class lint(Command):
|
||||||
description = 'check code against PEP 8 (and fix violations)'
|
description = 'check code against PEP 8 (and fix violations)'
|
||||||
|
|
35
test/features/connection.feature
Normal file
35
test/features/connection.feature
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
Feature: connect to a database:
|
||||||
|
|
||||||
|
@requires_local_db
|
||||||
|
Scenario: run mycli on localhost without port
|
||||||
|
When we run mycli with arguments "host=localhost" without arguments "port"
|
||||||
|
When we query "status"
|
||||||
|
Then status contains "via UNIX socket"
|
||||||
|
|
||||||
|
Scenario: run mycli on TCP host without port
|
||||||
|
When we run mycli without arguments "port"
|
||||||
|
When we query "status"
|
||||||
|
Then status contains "via TCP/IP"
|
||||||
|
|
||||||
|
Scenario: run mycli with port but without host
|
||||||
|
When we run mycli without arguments "host"
|
||||||
|
When we query "status"
|
||||||
|
Then status contains "via TCP/IP"
|
||||||
|
|
||||||
|
@requires_local_db
|
||||||
|
Scenario: run mycli without host and port
|
||||||
|
When we run mycli without arguments "host port"
|
||||||
|
When we query "status"
|
||||||
|
Then status contains "via UNIX socket"
|
||||||
|
|
||||||
|
Scenario: run mycli with my.cnf configuration
|
||||||
|
When we create my.cnf file
|
||||||
|
When we run mycli without arguments "host port user pass defaults_file"
|
||||||
|
Then we are logged in
|
||||||
|
|
||||||
|
Scenario: run mycli with mylogin.cnf configuration
|
||||||
|
When we create mylogin.cnf file
|
||||||
|
When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
|
||||||
|
Then we are logged in
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from tempfile import mkstemp
|
from tempfile import mkstemp
|
||||||
|
|
||||||
|
@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt
|
||||||
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
|
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
|
||||||
|
|
||||||
|
|
||||||
|
SELF_CONNECTING_FEATURES = (
|
||||||
|
'test/features/connection.feature',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
|
||||||
|
MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
|
||||||
|
MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
|
||||||
|
MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_name_from_context(context):
|
||||||
|
return context.config.userdata.get(
|
||||||
|
'my_test_db', None
|
||||||
|
) or "mycli_behave_tests"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def before_all(context):
|
def before_all(context):
|
||||||
"""Set env parameters."""
|
"""Set env parameters."""
|
||||||
os.environ['LINES'] = "100"
|
os.environ['LINES'] = "100"
|
||||||
|
@ -22,7 +41,7 @@ def before_all(context):
|
||||||
|
|
||||||
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||||
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
|
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
|
||||||
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
||||||
|
|
||||||
context.package_root = os.path.abspath(
|
context.package_root = os.path.abspath(
|
||||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||||
|
@ -33,8 +52,7 @@ def before_all(context):
|
||||||
context.exit_sent = False
|
context.exit_sent = False
|
||||||
|
|
||||||
vi = '_'.join([str(x) for x in sys.version_info[:3]])
|
vi = '_'.join([str(x) for x in sys.version_info[:3]])
|
||||||
db_name = context.config.userdata.get(
|
db_name = get_db_name_from_context(context)
|
||||||
'my_test_db', None) or "mycli_behave_tests"
|
|
||||||
db_name_full = '{0}_{1}'.format(db_name, vi)
|
db_name_full = '{0}_{1}'.format(db_name, vi)
|
||||||
|
|
||||||
# Store get params from config/environment variables
|
# Store get params from config/environment variables
|
||||||
|
@ -104,11 +122,18 @@ def before_step(context, _):
|
||||||
context.atprompt = False
|
context.atprompt = False
|
||||||
|
|
||||||
|
|
||||||
def before_scenario(context, _):
|
def before_scenario(context, arg):
|
||||||
with open(test_log_file, 'w') as f:
|
with open(test_log_file, 'w') as f:
|
||||||
f.write('')
|
f.write('')
|
||||||
run_cli(context)
|
if arg.location.filename not in SELF_CONNECTING_FEATURES:
|
||||||
wait_prompt(context)
|
run_cli(context)
|
||||||
|
wait_prompt(context)
|
||||||
|
|
||||||
|
if os.path.exists(MY_CNF_PATH):
|
||||||
|
shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
|
||||||
|
|
||||||
|
if os.path.exists(MYLOGIN_CNF_PATH):
|
||||||
|
shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
|
||||||
|
|
||||||
|
|
||||||
def after_scenario(context, _):
|
def after_scenario(context, _):
|
||||||
|
@ -134,6 +159,17 @@ def after_scenario(context, _):
|
||||||
context.cli.sendcontrol('d')
|
context.cli.sendcontrol('d')
|
||||||
context.cli.expect_exact(pexpect.EOF, timeout=5)
|
context.cli.expect_exact(pexpect.EOF, timeout=5)
|
||||||
|
|
||||||
|
if os.path.exists(MY_CNF_BACKUP_PATH):
|
||||||
|
shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
|
||||||
|
|
||||||
|
if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
|
||||||
|
shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
|
||||||
|
elif os.path.exists(MYLOGIN_CNF_PATH):
|
||||||
|
# This file was moved in `before_scenario`.
|
||||||
|
# If it exists now, it has been created during a test
|
||||||
|
os.remove(MYLOGIN_CNF_PATH)
|
||||||
|
|
||||||
|
|
||||||
# TODO: uncomment to debug a failure
|
# TODO: uncomment to debug a failure
|
||||||
# def after_step(context, step):
|
# def after_step(context, step):
|
||||||
# if step.status == "failed":
|
# if step.status == "failed":
|
||||||
|
|
|
@ -3,11 +3,12 @@ from textwrap import dedent
|
||||||
from behave import then, when
|
from behave import then, when
|
||||||
|
|
||||||
import wrappers
|
import wrappers
|
||||||
|
from utils import parse_cli_args_to_dict
|
||||||
|
|
||||||
|
|
||||||
@when('we run dbcli with {arg}')
|
@when('we run dbcli with {arg}')
|
||||||
def step_run_cli_with_arg(context, arg):
|
def step_run_cli_with_arg(context, arg):
|
||||||
wrappers.run_cli(context, run_args=arg.split('='))
|
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
|
||||||
|
|
||||||
|
|
||||||
@when('we execute a small query')
|
@when('we execute a small query')
|
||||||
|
|
71
test/features/steps/connection.py
Normal file
71
test/features/steps/connection.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from behave import when, then
|
||||||
|
import pexpect
|
||||||
|
|
||||||
|
import wrappers
|
||||||
|
from test.features.steps.utils import parse_cli_args_to_dict
|
||||||
|
from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
|
||||||
|
from test.utils import HOST, PORT, USER, PASSWORD
|
||||||
|
from mycli.config import encrypt_mylogin_cnf
|
||||||
|
|
||||||
|
|
||||||
|
TEST_LOGIN_PATH = 'test_login_path'
|
||||||
|
|
||||||
|
|
||||||
|
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
|
||||||
|
@when('we run mycli without arguments "{excluded_args}"')
|
||||||
|
def step_run_cli_without_args(context, excluded_args, exact_args=''):
|
||||||
|
wrappers.run_cli(
|
||||||
|
context,
|
||||||
|
run_args=parse_cli_args_to_dict(exact_args),
|
||||||
|
exclude_args=parse_cli_args_to_dict(excluded_args).keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@then('status contains "{expression}"')
|
||||||
|
def status_contains(context, expression):
|
||||||
|
wrappers.expect_exact(context, f'{expression}', timeout=5)
|
||||||
|
|
||||||
|
# Normally, the shutdown after scenario waits for the prompt.
|
||||||
|
# But we may have changed the prompt, depending on parameters,
|
||||||
|
# so let's wait for its last character
|
||||||
|
context.cli.expect_exact('>')
|
||||||
|
context.atprompt = True
|
||||||
|
|
||||||
|
|
||||||
|
@when('we create my.cnf file')
|
||||||
|
def step_create_my_cnf_file(context):
|
||||||
|
my_cnf = (
|
||||||
|
'[client]\n'
|
||||||
|
f'host = {HOST}\n'
|
||||||
|
f'port = {PORT}\n'
|
||||||
|
f'user = {USER}\n'
|
||||||
|
f'password = {PASSWORD}\n'
|
||||||
|
)
|
||||||
|
with open(MY_CNF_PATH, 'w') as f:
|
||||||
|
f.write(my_cnf)
|
||||||
|
|
||||||
|
|
||||||
|
@when('we create mylogin.cnf file')
|
||||||
|
def step_create_mylogin_cnf_file(context):
|
||||||
|
os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
|
||||||
|
mylogin_cnf = (
|
||||||
|
f'[{TEST_LOGIN_PATH}]\n'
|
||||||
|
f'host = {HOST}\n'
|
||||||
|
f'port = {PORT}\n'
|
||||||
|
f'user = {USER}\n'
|
||||||
|
f'password = {PASSWORD}\n'
|
||||||
|
)
|
||||||
|
with open(MYLOGIN_CNF_PATH, 'wb') as f:
|
||||||
|
input_file = io.StringIO(mylogin_cnf)
|
||||||
|
f.write(encrypt_mylogin_cnf(input_file).read())
|
||||||
|
|
||||||
|
|
||||||
|
@then('we are logged in')
|
||||||
|
def we_are_logged_in(context):
|
||||||
|
db_name = get_db_name_from_context(context)
|
||||||
|
context.cli.expect_exact(f'{db_name}>', timeout=5)
|
||||||
|
context.atprompt = True
|
12
test/features/steps/utils.py
Normal file
12
test/features/steps/utils.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli_args_to_dict(cli_args: str):
|
||||||
|
args_dict = {}
|
||||||
|
for arg in shlex.split(cli_args):
|
||||||
|
if '=' in arg:
|
||||||
|
key, value = arg.split('=')
|
||||||
|
args_dict[key] = value
|
||||||
|
else:
|
||||||
|
args_dict[arg] = None
|
||||||
|
return args_dict
|
|
@ -3,6 +3,7 @@ import pexpect
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from StringIO import StringIO
|
from StringIO import StringIO
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout):
|
||||||
timedout = False
|
timedout = False
|
||||||
try:
|
try:
|
||||||
context.cli.expect_exact(expected, timeout=timeout)
|
context.cli.expect_exact(expected, timeout=timeout)
|
||||||
except pexpect.exceptions.TIMEOUT:
|
except pexpect.TIMEOUT:
|
||||||
timedout = True
|
timedout = True
|
||||||
if timedout:
|
if timedout:
|
||||||
# Strip color codes out of the output.
|
# Strip color codes out of the output.
|
||||||
|
@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout):
|
||||||
context.conf['pager_boundary'], expected), timeout=timeout)
|
context.conf['pager_boundary'], expected), timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
def run_cli(context, run_args=None):
|
def run_cli(context, run_args=None, exclude_args=None):
|
||||||
"""Run the process using pexpect."""
|
"""Run the process using pexpect."""
|
||||||
run_args = run_args or []
|
run_args = run_args or {}
|
||||||
if context.conf.get('host', None):
|
rendered_args = []
|
||||||
run_args.extend(('-h', context.conf['host']))
|
exclude_args = set(exclude_args) if exclude_args else set()
|
||||||
if context.conf.get('user', None):
|
|
||||||
run_args.extend(('-u', context.conf['user']))
|
conf = dict(**context.conf)
|
||||||
if context.conf.get('pass', None):
|
conf.update(run_args)
|
||||||
run_args.extend(('-p', context.conf['pass']))
|
|
||||||
if context.conf.get('dbname', None):
|
def add_arg(name, key, value):
|
||||||
run_args.extend(('-D', context.conf['dbname']))
|
if name not in exclude_args:
|
||||||
if context.conf.get('defaults-file', None):
|
if value is not None:
|
||||||
run_args.extend(('--defaults-file', context.conf['defaults-file']))
|
rendered_args.extend((key, value))
|
||||||
if context.conf.get('myclirc', None):
|
else:
|
||||||
run_args.extend(('--myclirc', context.conf['myclirc']))
|
rendered_args.append(key)
|
||||||
|
|
||||||
|
if conf.get('host', None):
|
||||||
|
add_arg('host', '-h', conf['host'])
|
||||||
|
if conf.get('user', None):
|
||||||
|
add_arg('user', '-u', conf['user'])
|
||||||
|
if conf.get('pass', None):
|
||||||
|
add_arg('pass', '-p', conf['pass'])
|
||||||
|
if conf.get('port', None):
|
||||||
|
add_arg('port', '-P', str(conf['port']))
|
||||||
|
if conf.get('dbname', None):
|
||||||
|
add_arg('dbname', '-D', conf['dbname'])
|
||||||
|
if conf.get('defaults-file', None):
|
||||||
|
add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
|
||||||
|
if conf.get('myclirc', None):
|
||||||
|
add_arg('myclirc', '--myclirc', conf['myclirc'])
|
||||||
|
if conf.get('login_path'):
|
||||||
|
add_arg('login_path', '--login-path', conf['login_path'])
|
||||||
|
|
||||||
|
for arg_name, arg_value in conf.items():
|
||||||
|
if arg_name.startswith('-'):
|
||||||
|
add_arg(arg_name, arg_name, arg_value)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cli_cmd = context.conf['cli_command']
|
cli_cmd = context.conf['cli_command']
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -73,7 +96,7 @@ def run_cli(context, run_args=None):
|
||||||
'"'
|
'"'
|
||||||
).format(sys.executable)
|
).format(sys.executable)
|
||||||
|
|
||||||
cmd_parts = [cli_cmd] + run_args
|
cmd_parts = [cli_cmd] + rendered_args
|
||||||
cmd = ' '.join(cmd_parts)
|
cmd = ' '.join(cmd_parts)
|
||||||
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
|
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
|
||||||
context.logfile = StringIO()
|
context.logfile = StringIO()
|
||||||
|
|
|
@ -3,8 +3,9 @@ import os
|
||||||
import click
|
import click
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT
|
from mycli.main import MyCli, cli, thanks_picker
|
||||||
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
|
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
|
||||||
|
from mycli.sqlexecute import ServerInfo
|
||||||
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
|
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
|
||||||
|
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
@ -140,10 +141,7 @@ def test_batch_mode_csv(executor):
|
||||||
|
|
||||||
|
|
||||||
def test_thanks_picker_utf8():
|
def test_thanks_picker_utf8():
|
||||||
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
|
name = thanks_picker()
|
||||||
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
|
|
||||||
|
|
||||||
name = thanks_picker((author_file, sponsor_file))
|
|
||||||
assert name and isinstance(name, str)
|
assert name and isinstance(name, str)
|
||||||
|
|
||||||
|
|
||||||
|
@ -177,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
|
||||||
host = 'test'
|
host = 'test'
|
||||||
user = 'test'
|
user = 'test'
|
||||||
dbname = 'test'
|
dbname = 'test'
|
||||||
|
server_info = ServerInfo.from_version_string('unknown')
|
||||||
port = 0
|
port = 0
|
||||||
|
|
||||||
def server_type(self):
|
def server_type(self):
|
||||||
|
|
|
@ -3,6 +3,7 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pymysql
|
import pymysql
|
||||||
|
|
||||||
|
from mycli.sqlexecute import ServerInfo, ServerSpecies
|
||||||
from .utils import run, dbtest, set_expanded_output, is_expanded_output
|
from .utils import run, dbtest, set_expanded_output, is_expanded_output
|
||||||
|
|
||||||
|
|
||||||
|
@ -270,3 +271,24 @@ def test_multiple_results(executor):
|
||||||
'status': '1 row in set'}
|
'status': '1 row in set'}
|
||||||
]
|
]
|
||||||
assert results == expected
|
assert results == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'version_string, species, parsed_version_string, version',
|
||||||
|
(
|
||||||
|
('5.7.32-35', 'Percona', '5.7.32', 50732),
|
||||||
|
('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
|
||||||
|
('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
|
||||||
|
('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
|
||||||
|
('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
|
||||||
|
('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
|
||||||
|
('unexpected version string', None, '', 0),
|
||||||
|
('', None, '', 0),
|
||||||
|
(None, None, '', 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_version_parsing(version_string, species, parsed_version_string, version):
|
||||||
|
server_info = ServerInfo.from_version_string(version_string)
|
||||||
|
assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
|
||||||
|
assert server_info.version_str == parsed_version_string
|
||||||
|
assert server_info.version == version
|
||||||
|
|
Loading…
Add table
Reference in a new issue