
Adding upstream version 1.29.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-09 19:15:57 +01:00
parent 5bd6a68e8c
commit f9065f1bef
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
68 changed files with 3723 additions and 3336 deletions


@ -1,3 +1,2 @@
[run]
parallel = True
source = mycli


@ -4,34 +4,21 @@ on:
pull_request:
paths-ignore:
- '**.md'
- 'AUTHORS'
jobs:
linux:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [
'3.8',
'3.9',
'3.10',
'3.11',
'3.12',
]
include:
- python-version: '3.8'
os: ubuntu-20.04 # MySQL 8.0.36
- python-version: '3.9'
os: ubuntu-20.04 # MySQL 8.0.36
- python-version: '3.10'
os: ubuntu-22.04 # MySQL 8.0.36
- python-version: '3.11'
os: ubuntu-22.04 # MySQL 8.0.36
- python-version: '3.12'
os: ubuntu-22.04 # MySQL 8.0.36
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v1
with:
version: "latest"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
@ -43,10 +30,7 @@ jobs:
sudo /etc/init.d/mysql start
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
pip install --no-cache-dir -e .
run: uv sync --all-extras -p ${{ matrix.python-version }}
- name: Wait for MySQL connection
run: |
@ -59,13 +43,7 @@ jobs:
PYTEST_PASSWORD: root
PYTEST_HOST: 127.0.0.1
run: |
./setup.py test --pytest-args="--cov-report= --cov=mycli"
uv run tox -e py${{ matrix.python-version }}
- name: Lint
run: |
./setup.py lint --branch=HEAD
- name: Coverage
run: |
coverage combine
coverage report
- name: Run Style Checks
run: uv run tox -e style

.github/workflows/publish.yml (new file)

@ -0,0 +1,94 @@
name: Publish Python Package
on:
release:
types: [created]
permissions:
contents: read
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v1
with:
version: "latest"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Start MySQL
run: |
sudo /etc/init.d/mysql start
- name: Install dependencies
run: uv sync --all-extras -p ${{ matrix.python-version }}
- name: Wait for MySQL connection
run: |
while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do
sleep 5
done
- name: Pytest / behave
env:
PYTEST_PASSWORD: root
PYTEST_HOST: 127.0.0.1
run: |
uv run tox -e py${{ matrix.python-version }}
- name: Run Style Checks
run: uv run tox -e style
build:
runs-on: ubuntu-latest
needs: [test]
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v1
with:
version: "latest"
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: Install dependencies
run: uv sync --all-extras -p 3.13
- name: Build
run: uv build
- name: Store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-packages
path: dist/
publish:
name: Publish to PyPI
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
needs: [build]
environment: release
permissions:
id-token: write
steps:
- name: Download distribution packages
uses: actions/download-artifact@v4
with:
name: python-packages
path: dist/
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1


@ -1,3 +1,34 @@
1.29.2 (2024/12/11)
===================
Internal
--------
* Exclude tests from the python package.
1.29.1 (2024/12/11)
===================
Internal
--------
* Fix the GH actions to publish a new version.
1.29.0 (NEVER RELEASED)
=======================
Bug Fixes
----------
* fix SSL through SSH jump host by using a true python socket for a tunnel
* Fix mycli crash when connecting to Vitess
Internal
---------
* Modernize to use PEP-621. Use `uv` instead of `pip` in GH actions.
* Remove Python 3.8 and add Python 3.13 in test matrix.
1.28.0 (2024/11/10)
======================


@ -98,6 +98,7 @@ Contributors:
* Houston Wong
* Mohamed Rezk
* Ryosuke Kazami
* Cornel Cruceru
Created by:


@ -1 +1,3 @@
__version__ = "1.28.0"
import importlib.metadata
__version__ = importlib.metadata.version("mycli")
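With the package metadata now defined via PEP-621 (per the changelog above), the version string is read from the installed distribution at import time instead of being hard-coded. A minimal sketch of the same pattern; the PackageNotFoundError fallback is an assumption for uninstalled source checkouts, not part of this commit:

    import importlib.metadata

    try:
        __version__ = importlib.metadata.version("mycli")
    except importlib.metadata.PackageNotFoundError:
        # Assumed fallback: running from a source tree that was never installed.
        __version__ = "0.0.0.dev0"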


@ -13,6 +13,7 @@ def cli_is_multiline(mycli):
return False
else:
return not _multiline_exception(doc.text)
return cond
@ -23,33 +24,32 @@ def _multiline_exception(text):
# Multi-statement favorite query is a special case. Because there will
# be a semicolon separating statements, we can't consider semicolon an
# EOL. Let's consider an empty line an EOL instead.
if text.startswith('\\fs'):
return orig.endswith('\n')
if text.startswith("\\fs"):
return orig.endswith("\n")
return (
# Special Command
text.startswith('\\') or
text.startswith("\\")
or
# Delimiter declaration
text.lower().startswith('delimiter') or
text.lower().startswith("delimiter")
or
# Ended with the current delimiter (usually a semicolon)
text.endswith(special.get_current_delimiter()) or
text.endswith('\\g') or
text.endswith('\\G') or
text.endswith(r'\e') or
text.endswith(r'\clip') or
text.endswith(special.get_current_delimiter())
or text.endswith("\\g")
or text.endswith("\\G")
or text.endswith(r"\e")
or text.endswith(r"\clip")
or
# Exit doesn't need a semicolon
(text == 'exit') or
(text == "exit")
or
# Quit doesn't need a semicolon
(text == 'quit') or
(text == "quit")
or
# To all the vim fans out there
(text == ':q') or
(text == ":q")
or
# just a plain enter without any text
(text == '')
(text == "")
)
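For reference, the checks above reduce to a small predicate: dispatch the buffer when it is a special command, a delimiter declaration, ends with the current delimiter or \g/\G/\e/\clip, or is exactly exit, quit, :q, or empty; otherwise keep reading lines. A standalone sketch of that rule set (the explicit delimiter argument is an illustrative stand-in for special.get_current_delimiter(), and the \fs favorite-query case is omitted):

    def ends_statement(text: str, delimiter: str = ";") -> bool:
        # Mirrors _multiline_exception above, minus the \fs favorite-query case.
        text = text.strip()
        return (
            text.startswith("\\")                     # special command
            or text.lower().startswith("delimiter")   # delimiter declaration
            or text.endswith((delimiter, "\\g", "\\G", r"\e", r"\clip"))
            or text in ("exit", "quit", ":q", "")
        )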


@ -11,70 +11,69 @@ logger = logging.getLogger(__name__)
# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
TOKEN_TO_PROMPT_STYLE = {
Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current',
Token.Menu.Completions.Completion: 'completion-menu.completion',
Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current',
Token.Menu.Completions.Meta: 'completion-menu.meta.completion',
Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta',
Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess
Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess
Token.SelectedText: 'selected',
Token.SearchMatch: 'search',
Token.SearchMatch.Current: 'search.current',
Token.Toolbar: 'bottom-toolbar',
Token.Toolbar.Off: 'bottom-toolbar.off',
Token.Toolbar.On: 'bottom-toolbar.on',
Token.Toolbar.Search: 'search-toolbar',
Token.Toolbar.Search.Text: 'search-toolbar.text',
Token.Toolbar.System: 'system-toolbar',
Token.Toolbar.Arg: 'arg-toolbar',
Token.Toolbar.Arg.Text: 'arg-toolbar.text',
Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid',
Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed',
Token.Output.Header: 'output.header',
Token.Output.OddRow: 'output.odd-row',
Token.Output.EvenRow: 'output.even-row',
Token.Output.Null: 'output.null',
Token.Prompt: 'prompt',
Token.Continuation: 'continuation',
Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
Token.Menu.Completions.Completion: "completion-menu.completion",
Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
Token.Menu.Completions.Meta: "completion-menu.meta.completion",
Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
Token.SelectedText: "selected",
Token.SearchMatch: "search",
Token.SearchMatch.Current: "search.current",
Token.Toolbar: "bottom-toolbar",
Token.Toolbar.Off: "bottom-toolbar.off",
Token.Toolbar.On: "bottom-toolbar.on",
Token.Toolbar.Search: "search-toolbar",
Token.Toolbar.Search.Text: "search-toolbar.text",
Token.Toolbar.System: "system-toolbar",
Token.Toolbar.Arg: "arg-toolbar",
Token.Toolbar.Arg.Text: "arg-toolbar.text",
Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
Token.Output.Header: "output.header",
Token.Output.OddRow: "output.odd-row",
Token.Output.EvenRow: "output.even-row",
Token.Output.Null: "output.null",
Token.Prompt: "prompt",
Token.Continuation: "continuation",
}
# reverse dict for cli_helpers, because they still expect Pygments tokens.
PROMPT_STYLE_TO_TOKEN = {
v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()
}
PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
# all tokens that the Pygments MySQL lexer can produce
OVERRIDE_STYLE_TO_TOKEN = {
'sql.comment': Token.Comment,
'sql.comment.multi-line': Token.Comment.Multiline,
'sql.comment.single-line': Token.Comment.Single,
'sql.comment.optimizer-hint': Token.Comment.Special,
'sql.escape': Token.Error,
'sql.keyword': Token.Keyword,
'sql.datatype': Token.Keyword.Type,
'sql.literal': Token.Literal,
'sql.literal.date': Token.Literal.Date,
'sql.symbol': Token.Name,
'sql.quoted-schema-object': Token.Name.Quoted,
'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape,
'sql.constant': Token.Name.Constant,
'sql.function': Token.Name.Function,
'sql.variable': Token.Name.Variable,
'sql.number': Token.Number,
'sql.number.binary': Token.Number.Bin,
'sql.number.float': Token.Number.Float,
'sql.number.hex': Token.Number.Hex,
'sql.number.integer': Token.Number.Integer,
'sql.operator': Token.Operator,
'sql.punctuation': Token.Punctuation,
'sql.string': Token.String,
'sql.string.double-quouted': Token.String.Double,
'sql.string.escape': Token.String.Escape,
'sql.string.single-quoted': Token.String.Single,
'sql.whitespace': Token.Text,
"sql.comment": Token.Comment,
"sql.comment.multi-line": Token.Comment.Multiline,
"sql.comment.single-line": Token.Comment.Single,
"sql.comment.optimizer-hint": Token.Comment.Special,
"sql.escape": Token.Error,
"sql.keyword": Token.Keyword,
"sql.datatype": Token.Keyword.Type,
"sql.literal": Token.Literal,
"sql.literal.date": Token.Literal.Date,
"sql.symbol": Token.Name,
"sql.quoted-schema-object": Token.Name.Quoted,
"sql.quoted-schema-object.escape": Token.Name.Quoted.Escape,
"sql.constant": Token.Name.Constant,
"sql.function": Token.Name.Function,
"sql.variable": Token.Name.Variable,
"sql.number": Token.Number,
"sql.number.binary": Token.Number.Bin,
"sql.number.float": Token.Number.Float,
"sql.number.hex": Token.Number.Hex,
"sql.number.integer": Token.Number.Integer,
"sql.operator": Token.Operator,
"sql.punctuation": Token.Punctuation,
"sql.string": Token.String,
"sql.string.double-quouted": Token.String.Double,
"sql.string.escape": Token.String.Escape,
"sql.string.single-quoted": Token.String.Single,
"sql.whitespace": Token.Text,
}
def parse_pygments_style(token_name, style_object, style_dict):
"""Parse token type and style string.
@ -87,7 +86,7 @@ def parse_pygments_style(token_name, style_object, style_dict):
try:
other_token_type = string_to_tokentype(style_dict[token_name])
return token_type, style_object.styles[other_token_type]
except AttributeError as err:
except AttributeError:
return token_type, style_dict[token_name]
@ -95,45 +94,39 @@ def style_factory(name, cli_style):
try:
style = pygments.styles.get_style_by_name(name)
except ClassNotFound:
style = pygments.styles.get_style_by_name('native')
style = pygments.styles.get_style_by_name("native")
prompt_styles = []
# prompt-toolkit used pygments tokens for styling before, switched to style
# names in 2.0. Convert old token types to new style names, for backwards compatibility.
for token in cli_style:
if token.startswith('Token.'):
if token.startswith("Token."):
# treat as pygments token (1.0)
token_type, style_value = parse_pygments_style(
token, style, cli_style)
token_type, style_value = parse_pygments_style(token, style, cli_style)
if token_type in TOKEN_TO_PROMPT_STYLE:
prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
prompt_styles.append((prompt_style, style_value))
else:
# we don't want to support tokens anymore
logger.error('Unhandled style / class name: %s', token)
logger.error("Unhandled style / class name: %s", token)
else:
# treat as prompt style name (2.0). See default style names here:
# https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
prompt_styles.append((token, cli_style[token]))
override_style = Style([('bottom-toolbar', 'noreverse')])
return merge_styles([
style_from_pygments_cls(style),
override_style,
Style(prompt_styles)
])
override_style = Style([("bottom-toolbar", "noreverse")])
return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)])
def style_factory_output(name, cli_style):
try:
style = pygments.styles.get_style_by_name(name).styles
except ClassNotFound:
style = pygments.styles.get_style_by_name('native').styles
style = pygments.styles.get_style_by_name("native").styles
for token in cli_style:
if token.startswith('Token.'):
token_type, style_value = parse_pygments_style(
token, style, cli_style)
if token.startswith("Token."):
token_type, style_value = parse_pygments_style(token, style, cli_style)
style.update({token_type: style_value})
elif token in PROMPT_STYLE_TO_TOKEN:
token_type = PROMPT_STYLE_TO_TOKEN[token]
@ -143,7 +136,7 @@ def style_factory_output(name, cli_style):
style.update({token_type: cli_style[token]})
else:
# TODO: cli helpers will have to switch to ptk.Style
logger.error('Unhandled style / class name: %s', token)
logger.error("Unhandled style / class name: %s", token)
class OutputStyle(PygmentsStyle):
default_style = ""


@ -6,52 +6,47 @@ from .packages import special
def create_toolbar_tokens_func(mycli, show_fish_help):
"""Return a function that generates the toolbar tokens."""
def get_toolbar_tokens():
result = [('class:bottom-toolbar', ' ')]
result = [("class:bottom-toolbar", " ")]
if mycli.multi_line:
delimiter = special.get_current_delimiter()
result.append(
(
'class:bottom-toolbar',
' ({} [{}] will end the line) '.format(
'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter)
))
"class:bottom-toolbar",
" ({} [{}] will end the line) ".format("Semi-colon" if delimiter == ";" else "Delimiter", delimiter),
)
)
if mycli.multi_line:
result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON '))
result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
else:
result.append(('class:bottom-toolbar.off',
'[F3] Multiline: OFF '))
result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
if mycli.prompt_app.editing_mode == EditingMode.VI:
result.append((
'class:bottom-toolbar.on',
'Vi-mode ({})'.format(_get_vi_mode())
))
result.append(("class:bottom-toolbar.on", "Vi-mode ({})".format(_get_vi_mode())))
if mycli.toolbar_error_message:
result.append(
('class:bottom-toolbar', ' ' + mycli.toolbar_error_message))
result.append(("class:bottom-toolbar", " " + mycli.toolbar_error_message))
mycli.toolbar_error_message = None
if show_fish_help():
result.append(
('class:bottom-toolbar', ' Right-arrow to complete suggestion'))
result.append(("class:bottom-toolbar", " Right-arrow to complete suggestion"))
if mycli.completion_refresher.is_refreshing():
result.append(
('class:bottom-toolbar', ' Refreshing completions...'))
result.append(("class:bottom-toolbar", " Refreshing completions..."))
return result
return get_toolbar_tokens
def _get_vi_mode():
"""Get the current vi mode for display."""
return {
InputMode.INSERT: 'I',
InputMode.NAVIGATION: 'N',
InputMode.REPLACE: 'R',
InputMode.REPLACE_SINGLE: 'R',
InputMode.INSERT_MULTIPLE: 'M',
InputMode.INSERT: "I",
InputMode.NAVIGATION: "N",
InputMode.REPLACE: "R",
InputMode.REPLACE_SINGLE: "R",
InputMode.INSERT_MULTIPLE: "M",
}[get_app().vi_state.input_mode]


@ -3,4 +3,4 @@
import sys
WIN = sys.platform in ('win32', 'cygwin')
WIN = sys.platform in ("win32", "cygwin")


@ -5,8 +5,8 @@ from collections import OrderedDict
from .sqlcompleter import SQLCompleter
from .sqlexecute import SQLExecute, ServerSpecies
class CompletionRefresher(object):
class CompletionRefresher(object):
refreshers = OrderedDict()
def __init__(self):
@ -30,16 +30,14 @@ class CompletionRefresher(object):
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, 'Auto-completion refresh restarted.')]
return [(None, None, None, "Auto-completion refresh restarted.")]
else:
self._completer_thread = threading.Thread(
target=self._bg_refresh,
args=(executor, callbacks, completer_options),
name='completion_refresh')
target=self._bg_refresh, args=(executor, callbacks, completer_options), name="completion_refresh"
)
self._completer_thread.daemon = True
self._completer_thread.start()
return [(None, None, None,
'Auto-completion refresh started in the background.')]
return [(None, None, None, "Auto-completion refresh started in the background.")]
def is_refreshing(self):
return self._completer_thread and self._completer_thread.is_alive()
@ -49,10 +47,22 @@ class CompletionRefresher(object):
# Create a new pgexecute method to populate the completions.
e = sqlexecute
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
e.socket, e.charset, e.local_infile, e.ssl,
e.ssh_user, e.ssh_host, e.ssh_port,
e.ssh_password, e.ssh_key_filename)
executor = SQLExecute(
e.dbname,
e.user,
e.password,
e.host,
e.port,
e.socket,
e.charset,
e.local_infile,
e.ssl,
e.ssh_user,
e.ssh_host,
e.ssh_port,
e.ssh_password,
e.ssh_key_filename,
)
# If callbacks is a single function then push it into a list.
if callable(callbacks):
@ -76,55 +86,68 @@ class CompletionRefresher(object):
for callback in callbacks:
callback(completer)
def refresher(name, refreshers=CompletionRefresher.refreshers):
"""Decorator to add the decorated function to the dictionary of
refreshers. Any function decorated with a @refresher will be executed as
part of the completion refresh routine."""
def wrapper(wrapped):
refreshers[name] = wrapped
return wrapped
return wrapper
@refresher('databases')
@refresher("databases")
def refresh_databases(completer, executor):
completer.extend_database_names(executor.databases())
@refresher('schemata')
@refresher("schemata")
def refresh_schemata(completer, executor):
# schemata - In MySQL Schema is the same as database. But for mycli
# schemata will be the name of the current database.
completer.extend_schemata(executor.dbname)
completer.set_dbname(executor.dbname)
@refresher('tables')
def refresh_tables(completer, executor):
completer.extend_relations(executor.tables(), kind='tables')
completer.extend_columns(executor.table_columns(), kind='tables')
@refresher('users')
@refresher("tables")
def refresh_tables(completer, executor):
table_columns_dbresult = list(executor.table_columns())
completer.extend_relations(table_columns_dbresult, kind="tables")
completer.extend_columns(table_columns_dbresult, kind="tables")
@refresher("users")
def refresh_users(completer, executor):
completer.extend_users(executor.users())
# @refresher('views')
# def refresh_views(completer, executor):
# completer.extend_relations(executor.views(), kind='views')
# completer.extend_columns(executor.view_columns(), kind='views')
@refresher('functions')
@refresher("functions")
def refresh_functions(completer, executor):
completer.extend_functions(executor.functions())
if executor.server_info.species == ServerSpecies.TiDB:
completer.extend_functions(completer.tidb_functions, builtin=True)
@refresher('special_commands')
@refresher("special_commands")
def refresh_special(completer, executor):
completer.extend_special_commands(COMMANDS.keys())
@refresher('show_commands')
@refresher("show_commands")
def refresh_show_commands(completer, executor):
completer.extend_show_items(executor.show_candidates())
@refresher('keywords')
@refresher("keywords")
def refresh_keywords(completer, executor):
if executor.server_info.species == ServerSpecies.TiDB:
completer.extend_keywords(completer.tidb_keywords, replace=True)


@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
def log(logger, level, message):
"""Logs message to stderr if logging isn't initialized."""
if logger.parent.name != 'root':
if logger.parent.name != "root":
logger.log(level, message)
else:
print(message, file=sys.stderr)
@ -49,16 +49,13 @@ def read_config_file(f, list_values=True):
f = os.path.expanduser(f)
try:
config = ConfigObj(f, interpolation=False, encoding='utf8',
list_values=list_values)
config = ConfigObj(f, interpolation=False, encoding="utf8", list_values=list_values)
except ConfigObjError as e:
log(logger, logging.WARNING, "Unable to parse line {0} of config file "
"'{1}'.".format(e.line_number, f))
log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f))
log(logger, logging.WARNING, "Using successfully parsed config values.")
return e.config
except (IOError, OSError) as e:
log(logger, logging.WARNING, "You don't have permission to read "
"config file '{0}'.".format(e.filename))
log(logger, logging.WARNING, "You don't have permission to read " "config file '{0}'.".format(e.filename))
return None
return config
@ -80,15 +77,12 @@ def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
try:
with open(config_file) as f:
include_directives = filter(
lambda s: s.startswith('!includedir'),
f
)
include_directives = filter(lambda s: s.startswith("!includedir"), f)
dirs = map(lambda s: s.strip().split()[-1], include_directives)
dirs = filter(os.path.isdir, dirs)
for dir in dirs:
for filename in os.listdir(dir):
if filename.endswith('.cnf'):
if filename.endswith(".cnf"):
included_configs.append(os.path.join(dir, filename))
except (PermissionError, UnicodeDecodeError):
pass
@ -117,29 +111,31 @@ def read_config_files(files, list_values=True):
def create_default_config(list_values=True):
import mycli
default_config_file = resources.open_text(mycli, 'myclirc')
default_config_file = resources.open_text(mycli, "myclirc")
return read_config_file(default_config_file, list_values=list_values)
def write_default_config(destination, overwrite=False):
import mycli
default_config = resources.read_text(mycli, 'myclirc')
default_config = resources.read_text(mycli, "myclirc")
destination = os.path.expanduser(destination)
if not overwrite and exists(destination):
return
with open(destination, 'w') as f:
with open(destination, "w") as f:
f.write(default_config)
def get_mylogin_cnf_path():
"""Return the path to the login path file or None if it doesn't exist."""
mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE')
mylogin_cnf_path = os.getenv("MYSQL_TEST_LOGIN_FILE")
if mylogin_cnf_path is None:
app_data = os.getenv('APPDATA')
default_dir = os.path.join(app_data, 'MySQL') if app_data else '~'
mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf')
app_data = os.getenv("APPDATA")
default_dir = os.path.join(app_data, "MySQL") if app_data else "~"
mylogin_cnf_path = os.path.join(default_dir, ".mylogin.cnf")
mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path)
@ -159,14 +155,14 @@ def open_mylogin_cnf(name):
"""
try:
with open(name, 'rb') as f:
with open(name, "rb") as f:
plaintext = read_and_decrypt_mylogin_cnf(f)
except (OSError, IOError, ValueError):
logger.error('Unable to open login path file.')
logger.error("Unable to open login path file.")
return None
if not isinstance(plaintext, BytesIO):
logger.error('Unable to read login path file.')
logger.error("Unable to read login path file.")
return None
return TextIOWrapper(plaintext)
@ -181,6 +177,7 @@ def encrypt_mylogin_cnf(plaintext: IO[str]):
https://github.com/isotopp/mysql-config-coder
"""
def realkey(key):
"""Create the AES key from the login key."""
rkey = bytearray(16)
@ -194,10 +191,7 @@ def encrypt_mylogin_cnf(plaintext: IO[str]):
pad_len = buf_len - text_len
pad_chr = bytes(chr(pad_len), "utf8")
plaintext = plaintext.encode() + pad_chr * pad_len
encrypted_text = b''.join(
[aes.encrypt(plaintext[i: i + 16])
for i in range(0, len(plaintext), 16)]
)
encrypted_text = b"".join([aes.encrypt(plaintext[i : i + 16]) for i in range(0, len(plaintext), 16)])
return encrypted_text
LOGIN_KEY_LENGTH = 20
@ -248,7 +242,7 @@ def read_and_decrypt_mylogin_cnf(f):
buf = f.read(4)
if not buf or len(buf) != 4:
logger.error('Login path file is blank or incomplete.')
logger.error("Login path file is blank or incomplete.")
return None
# Read the login key.
@ -258,12 +252,12 @@ def read_and_decrypt_mylogin_cnf(f):
rkey = [0] * 16
for i in range(LOGIN_KEY_LEN):
try:
rkey[i % 16] ^= ord(key[i:i+1])
rkey[i % 16] ^= ord(key[i : i + 1])
except TypeError:
# ord() was unable to get the value of the byte.
logger.error('Unable to generate login path AES key.')
logger.error("Unable to generate login path AES key.")
return None
rkey = struct.pack('16B', *rkey)
rkey = struct.pack("16B", *rkey)
# Create a bytes buffer to hold the plaintext.
plaintext = BytesIO()
@ -274,20 +268,17 @@ def read_and_decrypt_mylogin_cnf(f):
len_buf = f.read(MAX_CIPHER_STORE_LEN)
if len(len_buf) < MAX_CIPHER_STORE_LEN:
break
cipher_len, = struct.unpack("<i", len_buf)
(cipher_len,) = struct.unpack("<i", len_buf)
# Read cipher_len bytes from the file and decrypt.
cipher = f.read(cipher_len)
plain = _remove_pad(
b''.join([aes.decrypt(cipher[i: i + 16])
for i in range(0, cipher_len, 16)])
)
plain = _remove_pad(b"".join([aes.decrypt(cipher[i : i + 16]) for i in range(0, cipher_len, 16)]))
if plain is False:
continue
plaintext.write(plain)
if plaintext.tell() == 0:
logger.error('No data successfully decrypted from login path file.')
logger.error("No data successfully decrypted from login path file.")
return None
plaintext.seek(0)
@ -299,17 +290,17 @@ def str_to_bool(s):
if isinstance(s, bool):
return s
elif not isinstance(s, basestring):
raise TypeError('argument must be a string')
raise TypeError("argument must be a string")
true_values = ('true', 'on', '1')
false_values = ('false', 'off', '0')
true_values = ("true", "on", "1")
false_values = ("false", "off", "0")
if s.lower() in true_values:
return True
elif s.lower() in false_values:
return False
else:
raise ValueError('not a recognized boolean value: {0}'.format(s))
raise ValueError("not a recognized boolean value: {0}".format(s))
def strip_matching_quotes(s):
@ -319,8 +310,7 @@ def strip_matching_quotes(s):
values.
"""
if (isinstance(s, basestring) and len(s) >= 2 and
s[0] == s[-1] and s[0] in ('"', "'")):
if isinstance(s, basestring) and len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
s = s[1:-1]
return s
@ -332,13 +322,13 @@ def _remove_pad(line):
pad_length = ord(line[-1:])
except TypeError:
# ord() was unable to get the value of the byte.
logger.warning('Unable to remove pad.')
logger.warning("Unable to remove pad.")
return False
if pad_length > len(line) or len(set(line[-pad_length:])) != 1:
# Pad length should be less than or equal to the length of the
# plaintext. The pad should have a single unique byte.
logger.warning('Invalid pad found in login path file.')
logger.warning("Invalid pad found in login path file.")
return False
return line[:-pad_length]
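Two details of the .mylogin.cnf handling shown above are worth calling out: the 16-byte AES key is produced by XOR-folding the 20-byte login key, and each decrypted block ends with padding whose length is encoded in its final byte and whose bytes must all be identical. A standalone sketch of just the key-folding step from read_and_decrypt_mylogin_cnf:

    import struct

    LOGIN_KEY_LENGTH = 20

    def fold_login_key(key: bytes) -> bytes:
        """XOR-fold the 20-byte login key into a 16-byte AES key."""
        rkey = [0] * 16
        for i in range(LOGIN_KEY_LENGTH):
            rkey[i % 16] ^= key[i]   # key[i] is an int on Python 3
        return struct.pack("16B", *rkey)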


@ -12,22 +12,22 @@ def mycli_bindings(mycli):
"""Custom key bindings for mycli."""
kb = KeyBindings()
@kb.add('f2')
@kb.add("f2")
def _(event):
"""Enable/Disable SmartCompletion Mode."""
_logger.debug('Detected F2 key.')
_logger.debug("Detected F2 key.")
mycli.completer.smart_completion = not mycli.completer.smart_completion
@kb.add('f3')
@kb.add("f3")
def _(event):
"""Enable/Disable Multiline Mode."""
_logger.debug('Detected F3 key.')
_logger.debug("Detected F3 key.")
mycli.multi_line = not mycli.multi_line
@kb.add('f4')
@kb.add("f4")
def _(event):
"""Toggle between Vi and Emacs mode."""
_logger.debug('Detected F4 key.')
_logger.debug("Detected F4 key.")
if mycli.key_bindings == "vi":
event.app.editing_mode = EditingMode.EMACS
mycli.key_bindings = "emacs"
@ -35,17 +35,17 @@ def mycli_bindings(mycli):
event.app.editing_mode = EditingMode.VI
mycli.key_bindings = "vi"
@kb.add('tab')
@kb.add("tab")
def _(event):
"""Force autocompletion at cursor."""
_logger.debug('Detected <Tab> key.')
_logger.debug("Detected <Tab> key.")
b = event.app.current_buffer
if b.complete_state:
b.complete_next()
else:
b.start_completion(select_first=True)
@kb.add('c-space')
@kb.add("c-space")
def _(event):
"""
Initialize autocompletion at cursor.
@ -55,7 +55,7 @@ def mycli_bindings(mycli):
If the menu is showing, select the next completion.
"""
_logger.debug('Detected <C-Space> key.')
_logger.debug("Detected <C-Space> key.")
b = event.app.current_buffer
if b.complete_state:
@ -63,14 +63,14 @@ def mycli_bindings(mycli):
else:
b.start_completion(select_first=False)
@kb.add('c-x', 'p', filter=emacs_mode)
@kb.add("c-x", "p", filter=emacs_mode)
def _(event):
"""
Prettify and indent current statement, usually into multiple lines.
Only accepts buffers containing single SQL statements.
"""
_logger.debug('Detected <C-x p>/> key.')
_logger.debug("Detected <C-x p>/> key.")
b = event.app.current_buffer
cursorpos_relative = b.cursor_position / max(1, len(b.text))
@ -78,19 +78,18 @@ def mycli_bindings(mycli):
if len(pretty_text) > 0:
b.text = pretty_text
cursorpos_abs = int(round(cursorpos_relative * len(b.text)))
while 0 < cursorpos_abs < len(b.text) \
and b.text[cursorpos_abs] in (' ', '\n'):
while 0 < cursorpos_abs < len(b.text) and b.text[cursorpos_abs] in (" ", "\n"):
cursorpos_abs -= 1
b.cursor_position = min(cursorpos_abs, len(b.text))
@kb.add('c-x', 'u', filter=emacs_mode)
@kb.add("c-x", "u", filter=emacs_mode)
def _(event):
"""
Unprettify and dedent current statement, usually into one line.
Only accepts buffers containing single SQL statements.
"""
_logger.debug('Detected <C-x u>/< key.')
_logger.debug("Detected <C-x u>/< key.")
b = event.app.current_buffer
cursorpos_relative = b.cursor_position / max(1, len(b.text))
@ -98,18 +97,17 @@ def mycli_bindings(mycli):
if len(unpretty_text) > 0:
b.text = unpretty_text
cursorpos_abs = int(round(cursorpos_relative * len(b.text)))
while 0 < cursorpos_abs < len(b.text) \
and b.text[cursorpos_abs] in (' ', '\n'):
while 0 < cursorpos_abs < len(b.text) and b.text[cursorpos_abs] in (" ", "\n"):
cursorpos_abs -= 1
b.cursor_position = min(cursorpos_abs, len(b.text))
@kb.add('c-r', filter=emacs_mode)
@kb.add("c-r", filter=emacs_mode)
def _(event):
"""Search history using fzf or default reverse incremental search."""
_logger.debug('Detected <C-r> key.')
_logger.debug("Detected <C-r> key.")
search_history(event)
@kb.add('enter', filter=completion_is_selected)
@kb.add("enter", filter=completion_is_selected)
def _(event):
"""Makes the enter key work as the tab key only when showing the menu.
@ -118,20 +116,20 @@ def mycli_bindings(mycli):
(accept current selection).
"""
_logger.debug('Detected enter key.')
_logger.debug("Detected enter key.")
event.current_buffer.complete_state = None
b = event.app.current_buffer
b.complete_state = None
@kb.add('escape', 'enter')
@kb.add("escape", "enter")
def _(event):
"""Introduces a line break in multi-line mode, or dispatches the
command in single-line mode."""
_logger.debug('Detected alt-enter key.')
_logger.debug("Detected alt-enter key.")
if mycli.multi_line:
event.app.current_buffer.validate_and_handle()
else:
event.app.current_buffer.insert_text('\n')
event.app.current_buffer.insert_text("\n")
return kb


@ -7,6 +7,5 @@ class MyCliLexer(MySqlLexer):
"""Extends MySQL lexer to add keywords."""
tokens = {
'root': [(r'\brepair\b', Keyword),
(r'\boffset\b', Keyword), inherit],
"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit],
}


@ -5,19 +5,20 @@ import logging
_logger = logging.getLogger(__name__)
def load_ipython_extension(ipython):
def load_ipython_extension(ipython):
# This is called via the ipython command '%load_ext mycli.magic'.
# First, load the sql magic if it isn't already loaded.
if not ipython.find_line_magic('sql'):
ipython.run_line_magic('load_ext', 'sql')
if not ipython.find_line_magic("sql"):
ipython.run_line_magic("load_ext", "sql")
# Register our own magic.
ipython.register_magic_function(mycli_line_magic, 'line', 'mycli')
ipython.register_magic_function(mycli_line_magic, "line", "mycli")
def mycli_line_magic(line):
_logger.debug('mycli magic called: %r', line)
_logger.debug("mycli magic called: %r", line)
parsed = sql.parse.parse(line, {})
# "get" was renamed to "set" in ipython-sql:
# https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
@ -32,17 +33,17 @@ def mycli_line_magic(line):
try:
# A corresponding mycli object already exists
mycli = conn._mycli
_logger.debug('Reusing existing mycli')
_logger.debug("Reusing existing mycli")
except AttributeError:
mycli = MyCli()
u = conn.session.engine.url
_logger.debug('New mycli: %r', str(u))
_logger.debug("New mycli: %r", str(u))
mycli.connect(host=u.host, port=u.port, passwd=u.password, database=u.database, user=u.username, init_command=None)
conn._mycli = mycli
# For convenience, print the connection alias
print('Connected: {}'.format(conn.name))
print("Connected: {}".format(conn.name))
try:
mycli.run_cli()
@ -54,9 +55,9 @@ def mycli_line_magic(line):
q = mycli.query_history[-1]
if q.mutating:
_logger.debug('Mutating query detected -- ignoring')
_logger.debug("Mutating query detected -- ignoring")
return
if q.successful:
ipython = get_ipython()
return ipython.run_cell_magic('sql', line, q.query)
ipython = get_ipython() # noqa: F821
return ipython.run_cell_magic("sql", line, q.query)

File diff suppressed because it is too large.


@ -12,8 +12,7 @@ def suggest_type(full_text, text_before_cursor):
A scope for a column category will be a list of tables.
"""
word_before_cursor = last_word(text_before_cursor,
include='many_punctuations')
word_before_cursor = last_word(text_before_cursor, include="many_punctuations")
identifier = None
@ -25,12 +24,10 @@ def suggest_type(full_text, text_before_cursor):
# partially typed string which renders the smart completion useless because
# it will always return the list of keywords as completion.
if word_before_cursor:
if word_before_cursor.endswith(
'(') or word_before_cursor.startswith('\\'):
if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
parsed = sqlparse.parse(text_before_cursor)
else:
parsed = sqlparse.parse(
text_before_cursor[:-len(word_before_cursor)])
parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])
# word_before_cursor may include a schema qualification, like
# "schema_name.partial_name" or "schema_name.", so parse it
@ -42,7 +39,7 @@ def suggest_type(full_text, text_before_cursor):
else:
parsed = sqlparse.parse(text_before_cursor)
except (TypeError, AttributeError):
return [{'type': 'keyword'}]
return [{"type": "keyword"}]
if len(parsed) > 1:
# Multiple statements being edited -- isolate the current one by
@ -72,13 +69,12 @@ def suggest_type(full_text, text_before_cursor):
# Be careful here because trivial whitespace is parsed as a statement,
# but the statement won't have a first token
tok1 = statement.token_first()
if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')):
if tok1 and (tok1.value == "source" or tok1.value.startswith("\\")):
return suggest_special(text_before_cursor)
last_token = statement and statement.token_prev(len(statement.tokens))[1] or ''
last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""
return suggest_based_on_last_token(last_token, text_before_cursor,
full_text, identifier)
return suggest_based_on_last_token(last_token, text_before_cursor, full_text, identifier)
def suggest_special(text):
@ -87,27 +83,27 @@ def suggest_special(text):
if cmd == text:
# Trying to complete the special command itself
return [{'type': 'special'}]
return [{"type": "special"}]
if cmd in ('\\u', '\\r'):
return [{'type': 'database'}]
if cmd in ("\\u", "\\r"):
return [{"type": "database"}]
if cmd in ('\\T'):
return [{'type': 'table_format'}]
if cmd in ("\\T"):
return [{"type": "table_format"}]
if cmd in ['\\f', '\\fs', '\\fd']:
return [{'type': 'favoritequery'}]
if cmd in ["\\f", "\\fs", "\\fd"]:
return [{"type": "favoritequery"}]
if cmd in ['\\dt', '\\dt+']:
if cmd in ["\\dt", "\\dt+"]:
return [
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'},
{"type": "table", "schema": []},
{"type": "view", "schema": []},
{"type": "schema"},
]
elif cmd in ['\\.', 'source']:
return[{'type': 'file_name'}]
elif cmd in ["\\.", "source"]:
return [{"type": "file_name"}]
return [{'type': 'keyword'}, {'type': 'special'}]
return [{"type": "keyword"}, {"type": "special"}]
def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
@ -127,20 +123,19 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
return suggest_based_on_last_token(prev_keyword, text_before_cursor,
full_text, identifier)
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
elif token is None:
return [{'type': 'keyword'}]
return [{"type": "keyword"}]
else:
token_v = token.value.lower()
is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']])
is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]]) # noqa: E731
if not token:
return [{'type': 'keyword'}, {'type': 'special'}]
return [{"type": "keyword"}, {"type": "special"}]
elif token_v == "*":
return [{'type': 'keyword'}]
elif token_v.endswith('('):
return [{"type": "keyword"}]
elif token_v.endswith("("):
p = sqlparse.parse(text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
@ -155,8 +150,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token('where',
text_before_cursor, full_text, identifier)
column_suggestions = suggest_based_on_last_token("where", text_before_cursor, full_text, identifier)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
@ -167,130 +161,133 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier
prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower()
if prev_tok == 'exists':
return [{'type': 'keyword'}]
if prev_tok == "exists":
return [{"type": "keyword"}]
else:
return column_suggestions
# Get the token before the parens
idx, prev_tok = p.token_prev(len(p.tokens) - 1)
if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = extract_tables(full_text)
# suggest columns that are present in more than one table
return [{'type': 'column', 'tables': tables, 'drop_unique': True}]
elif p.token_first().value.lower() == 'select':
return [{"type": "column", "tables": tables, "drop_unique": True}]
elif p.token_first().value.lower() == "select":
# If the lparen is preceded by a space, chances are we're about to
# do a sub-select.
if last_word(text_before_cursor,
'all_punctuations').startswith('('):
return [{'type': 'keyword'}]
elif p.token_first().value.lower() == 'show':
return [{'type': 'show'}]
if last_word(text_before_cursor, "all_punctuations").startswith("("):
return [{"type": "keyword"}]
elif p.token_first().value.lower() == "show":
return [{"type": "show"}]
# We're probably in a function argument list
return [{'type': 'column', 'tables': extract_tables(full_text)}]
elif token_v in ('set', 'order by', 'distinct'):
return [{'type': 'column', 'tables': extract_tables(full_text)}]
elif token_v == 'as':
return [{"type": "column", "tables": extract_tables(full_text)}]
elif token_v in ("set", "order by", "distinct"):
return [{"type": "column", "tables": extract_tables(full_text)}]
elif token_v == "as":
# Don't suggest anything for an alias
return []
elif token_v in ('show'):
return [{'type': 'show'}]
elif token_v in ('to',):
elif token_v in ("show"):
return [{"type": "show"}]
elif token_v in ("to",):
p = sqlparse.parse(text_before_cursor)[0]
if p.token_first().value.lower() == 'change':
return [{'type': 'change'}]
if p.token_first().value.lower() == "change":
return [{"type": "change"}]
else:
return [{'type': 'user'}]
elif token_v in ('user', 'for'):
return [{'type': 'user'}]
elif token_v in ('select', 'where', 'having'):
return [{"type": "user"}]
elif token_v in ("user", "for"):
return [{"type": "user"}]
elif token_v in ("select", "where", "having"):
# Check for a table alias or schema qualification
parent = (identifier and identifier.get_parent_name()) or []
tables = extract_tables(full_text)
if parent:
tables = [t for t in tables if identifies(parent, *t)]
return [{'type': 'column', 'tables': tables},
{'type': 'table', 'schema': parent},
{'type': 'view', 'schema': parent},
{'type': 'function', 'schema': parent}]
return [
{"type": "column", "tables": tables},
{"type": "table", "schema": parent},
{"type": "view", "schema": parent},
{"type": "function", "schema": parent},
]
else:
aliases = [alias or table for (schema, table, alias) in tables]
return [{'type': 'column', 'tables': tables},
{'type': 'function', 'schema': []},
{'type': 'alias', 'aliases': aliases},
{'type': 'keyword'}]
elif (token_v.endswith('join') and token.is_keyword) or (token_v in
('copy', 'from', 'update', 'into', 'describe', 'truncate',
'desc', 'explain')):
return [
{"type": "column", "tables": tables},
{"type": "function", "schema": []},
{"type": "alias", "aliases": aliases},
{"type": "keyword"},
]
elif (token_v.endswith("join") and token.is_keyword) or (
token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
):
schema = (identifier and identifier.get_parent_name()) or []
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = [{'type': 'table', 'schema': schema}]
suggest = [{"type": "table", "schema": schema}]
if not schema:
# Suggest schemas
suggest.insert(0, {'type': 'schema'})
suggest.insert(0, {"type": "schema"})
# Only tables can be TRUNCATED, otherwise suggest views
if token_v != 'truncate':
suggest.append({'type': 'view', 'schema': schema})
if token_v != "truncate":
suggest.append({"type": "view", "schema": schema})
return suggest
elif token_v in ('table', 'view', 'function'):
elif token_v in ("table", "view", "function"):
# E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
rel_type = token_v
schema = (identifier and identifier.get_parent_name()) or []
if schema:
return [{'type': rel_type, 'schema': schema}]
return [{"type": rel_type, "schema": schema}]
else:
return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
elif token_v == 'on':
return [{"type": "schema"}, {"type": rel_type, "schema": []}]
elif token_v == "on":
tables = extract_tables(full_text) # [(schema, table, alias), ...]
parent = (identifier and identifier.get_parent_name()) or []
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
tables = [t for t in tables if identifies(parent, *t)]
return [{'type': 'column', 'tables': tables},
{'type': 'table', 'schema': parent},
{'type': 'view', 'schema': parent},
{'type': 'function', 'schema': parent}]
return [
{"type": "column", "tables": tables},
{"type": "table", "schema": parent},
{"type": "view", "schema": parent},
{"type": "function", "schema": parent},
]
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = [alias or table for (schema, table, alias) in tables]
suggest = [{'type': 'alias', 'aliases': aliases}]
suggest = [{"type": "alias", "aliases": aliases}]
# The lists of 'aliases' could be empty if we're trying to complete
# a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
# In that case we just suggest all tables.
if not aliases:
suggest.append({'type': 'table', 'schema': parent})
suggest.append({"type": "table", "schema": parent})
return suggest
elif token_v in ('use', 'database', 'template', 'connect'):
elif token_v in ("use", "database", "template", "connect"):
# "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return [{'type': 'database'}]
elif token_v == 'tableformat':
return [{'type': 'table_format'}]
elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']:
return [{"type": "database"}]
elif token_v == "tableformat":
return [{"type": "table_format"}]
elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
if prev_keyword:
return suggest_based_on_last_token(
prev_keyword, text_before_cursor, full_text, identifier)
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
else:
return []
else:
return [{'type': 'keyword'}]
return [{"type": "keyword"}]
def identifies(id, schema, table, alias):
return id == alias or id == table or (
schema and (id == schema + '.' + table))
return id == alias or id == table or (schema and (id == schema + "." + table))
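suggest_type returns a list of suggestion descriptors (dicts with a "type" key plus optional scope such as tables or schema) that the completer later turns into concrete completions. A rough usage illustration; the import path is assumed to be mycli.packages.completion_engine and the commented output is indicative rather than exhaustive:

    from mycli.packages.completion_engine import suggest_type

    # Cursor just after SELECT: expect column/function/alias/keyword
    # suggestions scoped to the tables in the statement (here: abc).
    print(suggest_type("select  from abc", "select "))
    # e.g. [{'type': 'column', 'tables': [(None, 'abc', None)]},
    #       {'type': 'function', 'schema': []},
    #       {'type': 'alias', 'aliases': ['abc']},
    #       {'type': 'keyword'}]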


@ -38,7 +38,7 @@ def complete_path(curr_dir, last_dir):
"""
if not last_dir or curr_dir.startswith(last_dir):
return curr_dir
elif last_dir == '~':
elif last_dir == "~":
return os.path.join(last_dir, curr_dir)
@ -51,7 +51,7 @@ def parse_path(root_dir):
:return: tuple of (string, string, int)
"""
base_dir, last_dir, position = '', '', 0
base_dir, last_dir, position = "", "", 0
if root_dir:
base_dir, last_dir = os.path.split(root_dir)
position = -len(last_dir) if last_dir else 0
@ -69,9 +69,9 @@ def suggest_path(root_dir):
"""
if not root_dir:
return [os.path.abspath(os.sep), '~', os.curdir, os.pardir]
return [os.path.abspath(os.sep), "~", os.curdir, os.pardir]
if '~' in root_dir:
if "~" in root_dir:
root_dir = os.path.expanduser(root_dir)
if not os.path.exists(root_dir):
@ -100,7 +100,7 @@ def guess_socket_location():
for r, dirs, files in os.walk(directory, topdown=True):
for filename in files:
name, ext = os.path.splitext(filename)
if name.startswith("mysql") and name != "mysqlx" and ext in ('.socket', '.sock'):
if name.startswith("mysql") and name != "mysqlx" and ext in (".socket", ".sock"):
return os.path.join(r, filename)
dirs[:] = [d for d in dirs if d.startswith("mysql")]
return None


@ -12,16 +12,19 @@ class Paramiko:
def __getattr__(self, name):
import sys
from textwrap import dedent
print(dedent("""
To enable certain SSH features you need to install paramiko:
print(
dedent("""
To enable certain SSH features you need to install paramiko and sshtunnel:
pip install paramiko
pip install paramiko sshtunnel
It is required for the following configuration options:
--list-ssh-config
--ssh-config-host
--ssh-host
"""))
""")
)
sys.exit(1)
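This Paramiko placeholder is an optional-dependency shim: any attribute access on the stand-in prints install instructions and exits, so the cost of the extra dependency is only paid when an SSH feature is actually used. A generic sketch of the same idea (class and variable names are illustrative):

    import sys
    from textwrap import dedent

    class MissingDependency:
        """Stand-in that fails loudly only when the feature is actually used."""

        def __init__(self, package, install_hint):
            self._package = package
            self._install_hint = install_hint

        def __getattr__(self, name):
            print(dedent(f"""
                The '{self._package}' package is required for this feature:
                    {self._install_hint}
            """))
            sys.exit(1)

    # e.g. paramiko = MissingDependency("paramiko", "pip install paramiko sshtunnel")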


@ -4,18 +4,18 @@ from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
cleanup_regex = {
# This matches only alphanumerics and underscores.
'alphanum_underscore': re.compile(r'(\w+)$'),
# This matches everything except spaces, parens, colon, and comma
'many_punctuations': re.compile(r'([^():,\s]+)$'),
# This matches everything except spaces, parens, colon, comma, and period
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
# This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'),
# This matches only alphanumerics and underscores.
"alphanum_underscore": re.compile(r"(\w+)$"),
# This matches everything except spaces, parens, colon, and comma
"many_punctuations": re.compile(r"([^():,\s]+)$"),
# This matches everything except spaces, parens, colon, comma, and period
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
# This matches everything except a space.
"all_punctuations": re.compile(r"([^\s]+)$"),
}
def last_word(text, include='alphanum_underscore'):
def last_word(text, include="alphanum_underscore"):
r"""
Find the last word in a sentence.
@ -47,18 +47,18 @@ def last_word(text, include='alphanum_underscore'):
'def'
"""
if not text: # Empty string
return ''
if not text: # Empty string
return ""
if text[-1].isspace():
return ''
return ""
else:
regex = cleanup_regex[include]
matches = regex.search(text)
if matches:
return matches.group(0)
else:
return ''
return ""
# This code is borrowed from sqlparse example script.
@ -67,11 +67,11 @@ def is_subselect(parsed):
if not parsed.is_group:
return False
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
'UPDATE', 'CREATE', 'DELETE'):
if item.ttype is DML and item.value.upper() in ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE"):
return True
return False
def extract_from_part(parsed, stop_at_punctuation=True):
tbl_prefix_seen = False
for item in parsed.tokens:
@ -85,7 +85,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
# "ON" is a keyword and will trigger the next elif condition.
# So instead of stopping the loop when finding an "ON", skip it
# eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi'
elif item.ttype is Keyword and item.value.upper() == 'ON':
elif item.ttype is Keyword and item.value.upper() == "ON":
tbl_prefix_seen = False
continue
# An incomplete nested select won't be recognized correctly as a
@ -96,24 +96,28 @@ def extract_from_part(parsed, stop_at_punctuation=True):
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and (
not item.value.upper() == 'FROM') and (
not item.value.upper().endswith('JOIN')):
elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")):
return
else:
yield item
elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
elif (item.ttype is Keyword or item.ttype is Keyword.DML) and item.value.upper() in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
"JOIN",
):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
if (identifier.ttype is Keyword and
identifier.value.upper() == 'FROM'):
if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
tbl_prefix_seen = True
break
def extract_table_identifiers(token_stream):
"""yields tuples of (schema_name, table_name, table_alias)"""
@ -141,6 +145,7 @@ def extract_table_identifiers(token_stream):
elif isinstance(item, Function):
yield (None, item.get_name(), item.get_name())
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
"""Extract the table names from an SQL statement.
@ -156,27 +161,27 @@ def extract_tables(sql):
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
insert_stmt = parsed[0].token_first().value.lower() == "insert"
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
return list(extract_table_identifiers(stream))
def find_prev_keyword(sql):
""" Find the last sql keyword in an SQL statement
"""Find the last sql keyword in an SQL statement
Returns the value of the last keyword, and the text of the query with
everything after the last keyword stripped
"""
if not sql.strip():
return None, ''
return None, ""
parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten())
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
for t in reversed(flattened):
if t.value == '(' or (t.is_keyword and (
t.value.upper() not in logical_operators)):
if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index thows an error
@ -189,10 +194,10 @@ def find_prev_keyword(sql):
# Combine the string values of all tokens in the original list
# up to and including the target keyword token t, to produce a
# query string with everything after the keyword token removed
text = ''.join(tok.value for tok in flattened[:idx+1])
text = "".join(tok.value for tok in flattened[: idx + 1])
return t, text
return None, ''
return None, ""
def query_starts_with(query, prefixes):
@ -212,31 +217,25 @@ def queries_start_with(queries, prefixes):
def query_has_where_clause(query):
"""Check if the query contains a where-clause."""
return any(
isinstance(token, sqlparse.sql.Where)
for token_list in sqlparse.parse(query)
for token in token_list
)
return any(isinstance(token, sqlparse.sql.Where) for token_list in sqlparse.parse(query) for token in token_list)
def is_destructive(queries):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
for query in sqlparse.split(queries):
if query:
if query_starts_with(query, keywords) is True:
return True
elif query_starts_with(
query, ['update']
) is True and not query_has_where_clause(query):
elif query_starts_with(query, ["update"]) is True and not query_has_where_clause(query):
return True
return False
if __name__ == '__main__':
sql = 'select * from (select t. from tabl t'
print (extract_tables(sql))
if __name__ == "__main__":
sql = "select * from (select t. from tabl t"
print(extract_tables(sql))
def is_dropping_database(queries, dbname):
@ -258,9 +257,7 @@ def is_dropping_database(queries, dbname):
"database",
"schema",
):
database_token = next(
(t for t in query.tokens if isinstance(t, Identifier)), None
)
database_token = next((t for t in query.tokens if isinstance(t, Identifier)), None)
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
result = keywords[0].normalized == "DROP"
return result
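is_destructive treats any input that starts with drop, shutdown, delete, truncate, or alter as destructive, plus UPDATE statements that lack a WHERE clause; is_dropping_database further checks whether the named current database is the one being dropped. A quick usage sketch (the import path mycli.packages.parseutils is an assumption):

    from mycli.packages.parseutils import is_destructive, is_dropping_database

    assert is_destructive("drop table users;")
    assert is_destructive("update users set active = 0;")          # no WHERE clause
    assert not is_destructive("update users set active = 0 where id = 1;")
    assert not is_destructive("select * from users;")

    # True only when the database being dropped matches the current one.
    assert is_dropping_database("drop database analytics;", "analytics")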


@ -4,20 +4,20 @@ from .parseutils import is_destructive
class ConfirmBoolParamType(click.ParamType):
name = 'confirmation'
name = "confirmation"
def convert(self, value, param, ctx):
if isinstance(value, bool):
return bool(value)
value = value.lower()
if value in ('yes', 'y'):
if value in ("yes", "y"):
return True
elif value in ('no', 'n'):
elif value in ("no", "n"):
return False
self.fail('%s is not a valid boolean' % value, param, ctx)
self.fail("%s is not a valid boolean" % value, param, ctx)
def __repr__(self):
return 'BOOL'
return "BOOL"
BOOLEAN_TYPE = ConfirmBoolParamType()
@ -32,8 +32,7 @@ def confirm_destructive_query(queries):
* False if the query is destructive and the user doesn't want to proceed.
"""
prompt_text = ("You're about to run a destructive command.\n"
"Do you want to proceed? (y/n)")
prompt_text = "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
if is_destructive(queries) and sys.stdin.isatty():
return prompt(prompt_text, type=BOOLEAN_TYPE)


@ -1,10 +1,12 @@
__all__ = []
def export(defn):
"""Decorator to explicitly mark functions that are exposed in a lib."""
globals()[defn.__name__] = defn
__all__.append(defn.__name__)
return defn
from . import dbcommands
from . import iocommands
from . import dbcommands # noqa: E402 F401
from . import iocommands # noqa: E402 F401


@ -10,24 +10,23 @@ from pymysql import ProgrammingError
log = logging.getLogger(__name__)
@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.',
arg_type=PARSED_QUERY, case_sensitive=True)
@special_command("\\dt", "\\dt[+] [table]", "List or describe tables.", arg_type=PARSED_QUERY, case_sensitive=True)
def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
if arg:
query = 'SHOW FIELDS FROM {0}'.format(arg)
query = "SHOW FIELDS FROM {0}".format(arg)
else:
query = 'SHOW TABLES'
query = "SHOW TABLES"
log.debug(query)
cur.execute(query)
tables = cur.fetchall()
status = ''
status = ""
if cur.description:
headers = [x[0] for x in cur.description]
else:
return [(None, None, None, '')]
return [(None, None, None, "")]
if verbose and arg:
query = 'SHOW CREATE TABLE {0}'.format(arg)
query = "SHOW CREATE TABLE {0}".format(arg)
log.debug(query)
cur.execute(query)
status = cur.fetchone()[1]
@ -35,128 +34,121 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
return [(None, tables, headers, status)]
@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True)
@special_command("\\l", "\\l", "List databases.", arg_type=RAW_QUERY, case_sensitive=True)
def list_databases(cur, **_):
query = 'SHOW DATABASES'
query = "SHOW DATABASES"
log.debug(query)
cur.execute(query)
if cur.description:
headers = [x[0] for x in cur.description]
return [(None, cur, headers, '')]
return [(None, cur, headers, "")]
else:
return [(None, None, None, '')]
return [(None, None, None, "")]
@special_command('status', '\\s', 'Get status information from the server.',
arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True)
@special_command("status", "\\s", "Get status information from the server.", arg_type=RAW_QUERY, aliases=("\\s",), case_sensitive=True)
def status(cur, **_):
query = 'SHOW GLOBAL STATUS;'
query = "SHOW GLOBAL STATUS;"
log.debug(query)
try:
cur.execute(query)
except ProgrammingError:
# Fallback in case the query fails, as it does with MySQL 4
query = 'SHOW STATUS;'
query = "SHOW STATUS;"
log.debug(query)
cur.execute(query)
status = dict(cur.fetchall())
query = 'SHOW GLOBAL VARIABLES;'
query = "SHOW GLOBAL VARIABLES;"
log.debug(query)
cur.execute(query)
variables = dict(cur.fetchall())
# prepare in case keys are bytes, as with Python 3 and MySQL 4
if (isinstance(list(variables)[0], bytes) and
isinstance(list(status)[0], bytes)):
variables = {k.decode('utf-8'): v.decode('utf-8') for k, v
in variables.items()}
status = {k.decode('utf-8'): v.decode('utf-8') for k, v
in status.items()}
if isinstance(list(variables)[0], bytes) and isinstance(list(status)[0], bytes):
variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in variables.items()}
status = {k.decode("utf-8"): v.decode("utf-8") for k, v in status.items()}
# Create output buffers.
title = []
output = []
footer = []
title.append('--------------')
title.append("--------------")
# Output the mycli client information.
implementation = platform.python_implementation()
version = platform.python_version()
client_info = []
client_info.append('mycli {0},'.format(__version__))
client_info.append('running on {0} {1}'.format(implementation, version))
title.append(' '.join(client_info) + '\n')
client_info.append("mycli {0},".format(__version__))
client_info.append("running on {0} {1}".format(implementation, version))
title.append(" ".join(client_info) + "\n")
# Build the output that will be displayed as a table.
output.append(('Connection id:', cur.connection.thread_id()))
output.append(("Connection id:", cur.connection.thread_id()))
query = 'SELECT DATABASE(), USER();'
query = "SELECT DATABASE(), USER();"
log.debug(query)
cur.execute(query)
db, user = cur.fetchone()
if db is None:
db = ''
db = ""
output.append(('Current database:', db))
output.append(('Current user:', user))
output.append(("Current database:", db))
output.append(("Current user:", user))
if iocommands.is_pager_enabled():
if 'PAGER' in os.environ:
pager = os.environ['PAGER']
if "PAGER" in os.environ:
pager = os.environ["PAGER"]
else:
pager = 'System default'
pager = "System default"
else:
pager = 'stdout'
output.append(('Current pager:', pager))
pager = "stdout"
output.append(("Current pager:", pager))
output.append(('Server version:', '{0} {1}'.format(
variables['version'], variables['version_comment'])))
output.append(('Protocol version:', variables['protocol_version']))
output.append(("Server version:", "{0} {1}".format(variables["version"], variables["version_comment"])))
output.append(("Protocol version:", variables["protocol_version"]))
if 'unix' in cur.connection.host_info.lower():
if "unix" in cur.connection.host_info.lower():
host_info = cur.connection.host_info
else:
host_info = '{0} via TCP/IP'.format(cur.connection.host)
host_info = "{0} via TCP/IP".format(cur.connection.host)
output.append(('Connection:', host_info))
output.append(("Connection:", host_info))
query = ('SELECT @@character_set_server, @@character_set_database, '
'@@character_set_client, @@character_set_connection LIMIT 1;')
query = "SELECT @@character_set_server, @@character_set_database, " "@@character_set_client, @@character_set_connection LIMIT 1;"
log.debug(query)
cur.execute(query)
charset = cur.fetchone()
output.append(('Server characterset:', charset[0]))
output.append(('Db characterset:', charset[1]))
output.append(('Client characterset:', charset[2]))
output.append(('Conn. characterset:', charset[3]))
output.append(("Server characterset:", charset[0]))
output.append(("Db characterset:", charset[1]))
output.append(("Client characterset:", charset[2]))
output.append(("Conn. characterset:", charset[3]))
if 'TCP/IP' in host_info:
output.append(('TCP port:', cur.connection.port))
if "TCP/IP" in host_info:
output.append(("TCP port:", cur.connection.port))
else:
output.append(('UNIX socket:', variables['socket']))
output.append(("UNIX socket:", variables["socket"]))
if 'Uptime' in status:
output.append(('Uptime:', format_uptime(status['Uptime'])))
if "Uptime" in status:
output.append(("Uptime:", format_uptime(status["Uptime"])))
if 'Threads_connected' in status:
if "Threads_connected" in status:
# Print the current server statistics.
stats = []
stats.append('Connections: {0}'.format(status['Threads_connected']))
if 'Queries' in status:
stats.append('Queries: {0}'.format(status['Queries']))
stats.append('Slow queries: {0}'.format(status['Slow_queries']))
stats.append('Opens: {0}'.format(status['Opened_tables']))
if 'Flush_commands' in status:
stats.append('Flush tables: {0}'.format(status['Flush_commands']))
stats.append('Open tables: {0}'.format(status['Open_tables']))
if 'Queries' in status:
queries_per_second = int(status['Queries']) / int(status['Uptime'])
stats.append('Queries per second avg: {:.3f}'.format(
queries_per_second))
stats = ' '.join(stats)
footer.append('\n' + stats)
stats.append("Connections: {0}".format(status["Threads_connected"]))
if "Queries" in status:
stats.append("Queries: {0}".format(status["Queries"]))
stats.append("Slow queries: {0}".format(status["Slow_queries"]))
stats.append("Opens: {0}".format(status["Opened_tables"]))
if "Flush_commands" in status:
stats.append("Flush tables: {0}".format(status["Flush_commands"]))
stats.append("Open tables: {0}".format(status["Open_tables"]))
if "Queries" in status:
queries_per_second = int(status["Queries"]) / int(status["Uptime"])
stats.append("Queries per second avg: {:.3f}".format(queries_per_second))
stats = " ".join(stats)
footer.append("\n" + stats)
footer.append('--------------')
return [('\n'.join(title), output, '', '\n'.join(footer))]
footer.append("--------------")
return [("\n".join(title), output, "", "\n".join(footer))]

View file

@ -4,7 +4,7 @@ import sqlparse
class DelimiterCommand(object):
def __init__(self):
self._delimiter = ';'
self._delimiter = ";"
def _split(self, sql):
"""Temporary workaround until sqlparse.split() learns about custom
@ -12,22 +12,19 @@ class DelimiterCommand(object):
placeholder = "\ufffc" # unicode object replacement character
if self._delimiter == ';':
if self._delimiter == ";":
return sqlparse.split(sql)
# We must find a string that original sql does not contain.
# Most likely, our placeholder is enough, but if not, keep looking
while placeholder in sql:
placeholder += placeholder[0]
sql = sql.replace(';', placeholder)
sql = sql.replace(self._delimiter, ';')
sql = sql.replace(";", placeholder)
sql = sql.replace(self._delimiter, ";")
split = sqlparse.split(sql)
return [
stmt.replace(';', self._delimiter).replace(placeholder, ';')
for stmt in split
]
return [stmt.replace(";", self._delimiter).replace(placeholder, ";") for stmt in split]
def queries_iter(self, input):
"""Iterate over queries in the input string."""
@ -49,7 +46,7 @@ class DelimiterCommand(object):
# re-split everything, and if we previously stripped
# the delimiter, append it to the end
if self._delimiter != delimiter:
combined_statement = ' '.join([sql] + queries)
combined_statement = " ".join([sql] + queries)
if trailing_delimiter:
combined_statement += delimiter
queries = self._split(combined_statement)[1:]
@ -63,13 +60,13 @@ class DelimiterCommand(object):
word of it.
"""
match = arg and re.search(r'[^\s]+', arg)
match = arg and re.search(r"[^\s]+", arg)
if not match:
message = 'Missing required argument, delimiter'
message = "Missing required argument, delimiter"
return [(None, None, None, message)]
delimiter = match.group()
if delimiter.lower() == 'delimiter':
if delimiter.lower() == "delimiter":
return [(None, None, None, 'Invalid delimiter "delimiter"')]
self._delimiter = delimiter
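The placeholder trick above can be hard to follow from the diff alone; stripped down to its core it looks like this (function name is illustrative, only sqlparse is assumed):
import sqlparse
def split_with_custom_delimiter(sql, delimiter):
    """Split `sql` on `delimiter` by temporarily mapping it to ';',
    the only statement separator sqlparse.split() understands."""
    if delimiter == ";":
        return sqlparse.split(sql)
    placeholder = "\ufffc"            # object replacement character
    while placeholder in sql:         # keep growing until it cannot collide
        placeholder += placeholder[0]
    sql = sql.replace(";", placeholder).replace(delimiter, ";")
    return [stmt.replace(";", delimiter).replace(placeholder, ";")
            for stmt in sqlparse.split(sql)]
# split_with_custom_delimiter("select 1$$ select 2$$", "$$")
#   -> ['select 1$$', 'select 2$$']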

View file

@ -1,8 +1,7 @@
class FavoriteQueries(object):
section_name = "favorite_queries"
section_name = 'favorite_queries'
usage = '''
usage = """
Favorite Queries are a way to save frequently used queries
with a short name.
Examples:
@ -29,7 +28,7 @@ Examples:
# Delete a favorite query.
> \\fd simple
simple: Deleted
'''
"""
# Class-level variable, for convenience to use as a singleton.
instance = None
@ -48,7 +47,7 @@ Examples:
return self.config.get(self.section_name, {}).get(name, None)
def save(self, name, query):
self.config.encoding = 'utf-8'
self.config.encoding = "utf-8"
if self.section_name not in self.config:
self.config[self.section_name] = {}
self.config[self.section_name][name] = query
@ -58,6 +57,6 @@ Examples:
try:
del self.config[self.section_name][name]
except KeyError:
return '%s: Not Found.' % name
return "%s: Not Found." % name
self.config.write()
return '%s: Deleted' % name
return "%s: Deleted" % name

View file

@ -34,6 +34,7 @@ def set_timing_enabled(val):
global TIMING_ENABLED
TIMING_ENABLED = val
@export
def set_pager_enabled(val):
global PAGER_ENABLED
@ -44,33 +45,35 @@ def set_pager_enabled(val):
def is_pager_enabled():
return PAGER_ENABLED
@export
@special_command('pager', '\\P [command]',
'Set PAGER. Print the query results via PAGER.',
arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
@special_command(
"pager", "\\P [command]", "Set PAGER. Print the query results via PAGER.", arg_type=PARSED_QUERY, aliases=("\\P",), case_sensitive=True
)
def set_pager(arg, **_):
if arg:
os.environ['PAGER'] = arg
msg = 'PAGER set to %s.' % arg
os.environ["PAGER"] = arg
msg = "PAGER set to %s." % arg
set_pager_enabled(True)
else:
if 'PAGER' in os.environ:
msg = 'PAGER set to %s.' % os.environ['PAGER']
if "PAGER" in os.environ:
msg = "PAGER set to %s." % os.environ["PAGER"]
else:
# This uses click's default per echo_via_pager.
msg = 'Pager enabled.'
msg = "Pager enabled."
set_pager_enabled(True)
return [(None, None, None, msg)]
@export
@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
@special_command("nopager", "\\n", "Disable pager, print to stdout.", arg_type=NO_QUERY, aliases=("\\n",), case_sensitive=True)
def disable_pager():
set_pager_enabled(False)
return [(None, None, None, 'Pager disabled.')]
return [(None, None, None, "Pager disabled.")]
@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
@special_command("\\timing", "\\t", "Toggle timing of commands.", arg_type=NO_QUERY, aliases=("\\t",), case_sensitive=True)
def toggle_timing():
global TIMING_ENABLED
TIMING_ENABLED = not TIMING_ENABLED
@ -78,21 +81,26 @@ def toggle_timing():
message += "on." if TIMING_ENABLED else "off."
return [(None, None, None, message)]
@export
def is_timing_enabled():
return TIMING_ENABLED
@export
def set_expanded_output(val):
global use_expanded_output
use_expanded_output = val
@export
def is_expanded_output():
return use_expanded_output
_logger = logging.getLogger(__name__)
@export
def editor_command(command):
"""
@ -101,12 +109,13 @@ def editor_command(command):
"""
# It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
# for both conditions.
return command.strip().endswith('\\e') or command.strip().startswith('\\e')
return command.strip().endswith("\\e") or command.strip().startswith("\\e")
@export
def get_filename(sql):
if sql.strip().startswith('\\e'):
command, _, filename = sql.partition(' ')
if sql.strip().startswith("\\e"):
command, _, filename = sql.partition(" ")
return filename.strip() or None
@ -118,9 +127,9 @@ def get_editor_query(sql):
# The reason we can't simply do .strip('\e') is that it strips characters,
# not a substring. So it'll strip "e" in the end of the sql also!
# Ex: "select * from style\e" -> "select * from styl".
pattern = re.compile(r'(^\\e|\\e$)')
pattern = re.compile(r"(^\\e|\\e$)")
while pattern.search(sql):
sql = pattern.sub('', sql)
sql = pattern.sub("", sql)
return sql
@ -135,25 +144,24 @@ def open_external_editor(filename=None, sql=None):
"""
message = None
filename = filename.strip().split(' ', 1)[0] if filename else None
filename = filename.strip().split(" ", 1)[0] if filename else None
sql = sql or ''
MARKER = '# Type your query above this line.\n'
sql = sql or ""
MARKER = "# Type your query above this line.\n"
# Populate the editor buffer with the partial sql (if available) and a
# placeholder comment.
query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
filename=filename, extension='.sql')
query = click.edit("{sql}\n\n{marker}".format(sql=sql, marker=MARKER), filename=filename, extension=".sql")
if filename:
try:
with open(filename) as f:
query = f.read()
except IOError:
message = 'Error reading file: %s.' % filename
message = "Error reading file: %s." % filename
if query is not None:
query = query.split(MARKER, 1)[0].rstrip('\n')
query = query.split(MARKER, 1)[0].rstrip("\n")
else:
# Don't return None for the caller to deal with.
# Empty string is ok.
@ -171,7 +179,7 @@ def clip_command(command):
"""
# It is possible to have `\clip` or `SELECT * FROM \clip`. So we check
# for both conditions.
return command.strip().endswith('\\clip') or command.strip().startswith('\\clip')
return command.strip().endswith("\\clip") or command.strip().startswith("\\clip")
@export
@ -181,9 +189,9 @@ def get_clip_query(sql):
# The reason we can't simply do .strip('\clip') is that it strips characters,
# not a substring. So it'll strip "c" in the end of the sql also!
pattern = re.compile(r'(^\\clip|\\clip$)')
pattern = re.compile(r"(^\\clip|\\clip$)")
while pattern.search(sql):
sql = pattern.sub('', sql)
sql = pattern.sub("", sql)
return sql
@ -192,26 +200,26 @@ def get_clip_query(sql):
def copy_query_to_clipboard(sql=None):
"""Send query to the clipboard."""
sql = sql or ''
sql = sql or ""
message = None
try:
pyperclip.copy(u'{sql}'.format(sql=sql))
pyperclip.copy("{sql}".format(sql=sql))
except RuntimeError as e:
message = 'Error clipping query: %s.' % e.strerror
message = "Error clipping query: %s." % e.strerror
return message
@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
@special_command("\\f", "\\f [name [args..]]", "List or execute favorite queries.", arg_type=PARSED_QUERY, case_sensitive=True)
def execute_favorite_query(cur, arg, **_):
"""Returns (title, rows, headers, status)"""
if arg == '':
if arg == "":
for result in list_favorite_queries():
yield result
"""Parse out favorite name and optional substitution parameters"""
name, _, arg_str = arg.partition(' ')
name, _, arg_str = arg.partition(" ")
args = shlex.split(arg_str)
query = FavoriteQueries.instance.get(name)
@ -224,8 +232,8 @@ def execute_favorite_query(cur, arg, **_):
yield (None, None, None, arg_error)
else:
for sql in sqlparse.split(query):
sql = sql.rstrip(';')
title = '> %s' % (sql)
sql = sql.rstrip(";")
title = "> %s" % (sql)
cur.execute(sql)
if cur.description:
headers = [x[0] for x in cur.description]
@ -233,60 +241,60 @@ def execute_favorite_query(cur, arg, **_):
else:
yield (title, None, None, None)
def list_favorite_queries():
"""List of all favorite queries.
Returns (title, rows, headers, status)"""
headers = ["Name", "Query"]
rows = [(r, FavoriteQueries.instance.get(r))
for r in FavoriteQueries.instance.list()]
rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()]
if not rows:
status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
status = "\nNo favorite queries found." + FavoriteQueries.instance.usage
else:
status = ''
return [('', rows, headers, status)]
status = ""
return [("", rows, headers, status)]
def subst_favorite_query_args(query, args):
"""replace positional parameters ($1...$N) in query."""
for idx, val in enumerate(args):
subst_var = '$' + str(idx + 1)
subst_var = "$" + str(idx + 1)
if subst_var not in query:
return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query]
return [None, "query does not have substitution parameter " + subst_var + ":\n " + query]
query = query.replace(subst_var, val)
match = re.search(r'\$\d+', query)
match = re.search(r"\$\d+", query)
if match:
return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query]
return [None, "missing substitution for " + match.group(0) + " in query:\n " + query]
return [query, None]
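A quick worked example of the positional substitution above (the saved query and values are made up):
# query saved as a favorite:
#   "select * from users where id = $1 and status = '$2'"
#
# subst_favorite_query_args(query, ["42", "active"])
#   -> ["select * from users where id = 42 and status = 'active'", None]
#
# subst_favorite_query_args(query, ["42"])
#   -> [None, "missing substitution for $2 in query:\n  select * from users ..."]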
@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
@special_command("\\fs", "\\fs name query", "Save a favorite query.")
def save_favorite_query(arg, **_):
"""Save a new favorite query.
Returns (title, rows, headers, status)"""
usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage
if not arg:
return [(None, None, None, usage)]
name, _, query = arg.partition(' ')
name, _, query = arg.partition(" ")
# If either name or query is missing then print the usage and complain.
if (not name) or (not query):
return [(None, None, None,
usage + 'Err: Both name and query are required.')]
return [(None, None, None, usage + "Err: Both name and query are required.")]
FavoriteQueries.instance.save(name, query)
return [(None, None, None, "Saved.")]
@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
def delete_favorite_query(arg, **_):
"""Delete an existing favorite query."""
usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage
if not arg:
return [(None, None, None, usage)]
@ -295,8 +303,7 @@ def delete_favorite_query(arg, **_):
return [(None, None, None, status)]
@special_command('system', 'system [command]',
'Execute a system shell command.')
@special_command("system", "system [command]", "Execute a system shell command.")
def execute_system_command(arg, **_):
"""Execute a system shell command."""
usage = "Syntax: system [command].\n"
@ -306,13 +313,13 @@ def execute_system_command(arg, **_):
try:
command = arg.strip()
if command.startswith('cd'):
if command.startswith("cd"):
ok, error_message = handle_cd_command(arg)
if not ok:
return [(None, None, None, error_message)]
return [(None, None, None, '')]
return [(None, None, None, "")]
args = arg.split(' ')
args = arg.split(" ")
process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
response = output if not error else error
@ -324,25 +331,24 @@ def execute_system_command(arg, **_):
return [(None, None, None, response)]
except OSError as e:
return [(None, None, None, 'OSError: %s' % e.strerror)]
return [(None, None, None, "OSError: %s" % e.strerror)]
def parseargfile(arg):
if arg.startswith('-o '):
if arg.startswith("-o "):
mode = "w"
filename = arg[3:]
else:
mode = 'a'
mode = "a"
filename = arg
if not filename:
raise TypeError('You must provide a filename.')
raise TypeError("You must provide a filename.")
return {'file': os.path.expanduser(filename), 'mode': mode}
return {"file": os.path.expanduser(filename), "mode": mode}
@special_command('tee', 'tee [-o] filename',
'Append all results to an output file (overwrite using -o).')
@special_command("tee", "tee [-o] filename", "Append all results to an output file (overwrite using -o).")
def set_tee(arg, **_):
global tee_file
@ -353,6 +359,7 @@ def set_tee(arg, **_):
return [(None, None, None, "")]
@export
def close_tee():
global tee_file
@ -361,31 +368,29 @@ def close_tee():
tee_file = None
@special_command('notee', 'notee', 'Stop writing results to an output file.')
@special_command("notee", "notee", "Stop writing results to an output file.")
def no_tee(arg, **_):
close_tee()
return [(None, None, None, "")]
@export
def write_tee(output):
global tee_file
if tee_file:
click.echo(output, file=tee_file, nl=False)
click.echo(u'\n', file=tee_file, nl=False)
click.echo("\n", file=tee_file, nl=False)
tee_file.flush()
@special_command('\\once', '\\o [-o] filename',
'Append next result to an output file (overwrite using -o).',
aliases=('\\o', ))
@special_command("\\once", "\\o [-o] filename", "Append next result to an output file (overwrite using -o).", aliases=("\\o",))
def set_once(arg, **_):
global once_file, written_to_once_file
try:
once_file = open(**parseargfile(arg))
except (IOError, OSError) as e:
raise OSError("Cannot write to file '{}': {}".format(
e.filename, e.strerror))
raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
written_to_once_file = False
return [(None, None, None, "")]
@ -396,7 +401,7 @@ def write_once(output):
global once_file, written_to_once_file
if output and once_file:
click.echo(output, file=once_file, nl=False)
click.echo(u"\n", file=once_file, nl=False)
click.echo("\n", file=once_file, nl=False)
once_file.flush()
written_to_once_file = True
@ -410,22 +415,22 @@ def unset_once_if_written():
once_file = None
@special_command('\\pipe_once', '\\| command',
'Send next result to a subprocess.',
aliases=('\\|', ))
@special_command("\\pipe_once", "\\| command", "Send next result to a subprocess.", aliases=("\\|",))
def set_pipe_once(arg, **_):
global pipe_once_process, written_to_pipe_once_process
pipe_once_cmd = shlex.split(arg)
if len(pipe_once_cmd) == 0:
raise OSError("pipe_once requires a command")
written_to_pipe_once_process = False
pipe_once_process = subprocess.Popen(pipe_once_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=1,
encoding='UTF-8',
universal_newlines=True)
pipe_once_process = subprocess.Popen(
pipe_once_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=1,
encoding="UTF-8",
universal_newlines=True,
)
return [(None, None, None, "")]
@ -435,11 +440,10 @@ def write_pipe_once(output):
if output and pipe_once_process:
try:
click.echo(output, file=pipe_once_process.stdin, nl=False)
click.echo(u"\n", file=pipe_once_process.stdin, nl=False)
click.echo("\n", file=pipe_once_process.stdin, nl=False)
except (IOError, OSError) as e:
pipe_once_process.terminate()
raise OSError(
"Failed writing to pipe_once subprocess: {}".format(e.strerror))
raise OSError("Failed writing to pipe_once subprocess: {}".format(e.strerror))
written_to_pipe_once_process = True
@ -450,18 +454,14 @@ def unset_pipe_once_if_written():
if written_to_pipe_once_process:
(stdout_data, stderr_data) = pipe_once_process.communicate()
if len(stdout_data) > 0:
print(stdout_data.rstrip(u"\n"))
print(stdout_data.rstrip("\n"))
if len(stderr_data) > 0:
print(stderr_data.rstrip(u"\n"))
print(stderr_data.rstrip("\n"))
pipe_once_process = None
written_to_pipe_once_process = False
@special_command(
'watch',
'watch [seconds] [-c] query',
'Executes the query every [seconds] seconds (by default 5).'
)
@special_command("watch", "watch [seconds] [-c] query", "Executes the query every [seconds] seconds (by default 5).")
def watch_query(arg, **kwargs):
usage = """Syntax: watch [seconds] [-c] query.
* seconds: The interval at which the query will be repeated, in seconds.
@ -480,27 +480,24 @@ def watch_query(arg, **kwargs):
# Oops, we parsed all the arguments without finding a statement
yield (None, None, None, usage)
return
(current_arg, _, arg) = arg.partition(' ')
(current_arg, _, arg) = arg.partition(" ")
try:
seconds = float(current_arg)
continue
except ValueError:
pass
if current_arg == '-c':
if current_arg == "-c":
clear_screen = True
continue
statement = '{0!s} {1!s}'.format(current_arg, arg)
statement = "{0!s} {1!s}".format(current_arg, arg)
destructive_prompt = confirm_destructive_query(statement)
if destructive_prompt is False:
click.secho("Wise choice!")
return
elif destructive_prompt is True:
click.secho("Your call!")
cur = kwargs['cur']
sql_list = [
(sql.rstrip(';'), "> {0!s}".format(sql))
for sql in sqlparse.split(statement)
]
cur = kwargs["cur"]
sql_list = [(sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)]
old_pager_enabled = is_pager_enabled()
while True:
if clear_screen:
@ -509,7 +506,7 @@ def watch_query(arg, **kwargs):
# Somewhere in the code the pager is activated after every yield,
# so we disable it in every iteration
set_pager_enabled(False)
for (sql, title) in sql_list:
for sql, title in sql_list:
cur.execute(sql)
if cur.description:
headers = [x[0] for x in cur.description]
@ -527,7 +524,7 @@ def watch_query(arg, **kwargs):
@export
@special_command('delimiter', None, 'Change SQL delimiter.')
@special_command("delimiter", None, "Change SQL delimiter.")
def set_delimiter(arg, **_):
return delimiter_command.set(arg)

View file

@ -9,43 +9,43 @@ NO_QUERY = 0
PARSED_QUERY = 1
RAW_QUERY = 2
SpecialCommand = namedtuple('SpecialCommand',
['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
'case_sensitive'])
SpecialCommand = namedtuple("SpecialCommand", ["handler", "command", "shortcut", "description", "arg_type", "hidden", "case_sensitive"])
COMMANDS = {}
@export
class CommandNotFound(Exception):
pass
@export
def parse_special_command(sql):
command, _, arg = sql.partition(' ')
verbose = '+' in command
command = command.strip().replace('+', '')
command, _, arg = sql.partition(" ")
verbose = "+" in command
command = command.strip().replace("+", "")
return (command, verbose, arg.strip())
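The parsing above is what gives special commands their optional '+' verbose flag; for example (the table name is illustrative):
# parse_special_command("\\dt+ django_migrations")
#   -> ("\\dt", True, "django_migrations")
# parse_special_command("status")
#   -> ("status", False, "")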
@export
def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
hidden=False, case_sensitive=False, aliases=()):
def wrapper(wrapped):
register_special_command(wrapped, command, shortcut, description,
arg_type, hidden, case_sensitive, aliases)
return wrapped
return wrapper
@export
def register_special_command(handler, command, shortcut, description,
arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
def special_command(command, shortcut, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
def wrapper(wrapped):
register_special_command(wrapped, command, shortcut, description, arg_type, hidden, case_sensitive, aliases)
return wrapped
return wrapper
@export
def register_special_command(
handler, command, shortcut, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()
):
cmd = command.lower() if not case_sensitive else command
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
arg_type, hidden, case_sensitive)
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, hidden, case_sensitive)
for alias in aliases:
cmd = alias.lower() if not case_sensitive else alias
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
arg_type, case_sensitive=case_sensitive,
hidden=True)
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, case_sensitive=case_sensitive, hidden=True)
@export
def execute(cur, sql):
@ -62,11 +62,11 @@ def execute(cur, sql):
except KeyError:
special_cmd = COMMANDS[command.lower()]
if special_cmd.case_sensitive:
raise CommandNotFound('Command not found: %s' % command)
raise CommandNotFound("Command not found: %s" % command)
# "help <SQL KEYWORD> is a special case. We want built-in help, not
# mycli help here.
if command == 'help' and arg:
if command == "help" and arg:
return show_keyword_help(cur=cur, arg=arg)
if special_cmd.arg_type == NO_QUERY:
@ -76,9 +76,10 @@ def execute(cur, sql):
elif special_cmd.arg_type == RAW_QUERY:
return special_cmd.handler(cur=cur, query=sql)
@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
@special_command("help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?"))
def show_help(): # All the parameters are ignored.
headers = ['Command', 'Shortcut', 'Description']
headers = ["Command", "Shortcut", "Description"]
result = []
for _, value in sorted(COMMANDS.items()):
@ -86,6 +87,7 @@ def show_help(): # All the parameters are ignored.
result.append((value.command, value.shortcut, value.description))
return [(None, result, headers, None)]
def show_keyword_help(cur, arg):
"""
Call the built-in "show <command>", to display help for an SQL keyword.
@ -99,22 +101,19 @@ def show_keyword_help(cur, arg):
cur.execute(query)
if cur.description and cur.rowcount > 0:
headers = [x[0] for x in cur.description]
return [(None, cur, headers, '')]
return [(None, cur, headers, "")]
else:
return [(None, None, None, 'No help found for {0}.'.format(keyword))]
return [(None, None, None, "No help found for {0}.".format(keyword))]
@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', ))
@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY)
@special_command("exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q",))
@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY)
def quit(*_args):
raise EOFError
@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).',
arg_type=NO_QUERY, case_sensitive=True)
@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.',
arg_type=NO_QUERY, case_sensitive=True)
@special_command('\\G', '\\G', 'Display current query results vertically.',
arg_type=NO_QUERY, case_sensitive=True)
@special_command("\\e", "\\e", "Edit command with editor (uses $EDITOR).", arg_type=NO_QUERY, case_sensitive=True)
@special_command("\\clip", "\\clip", "Copy query to the system clipboard.", arg_type=NO_QUERY, case_sensitive=True)
@special_command("\\G", "\\G", "Display current query results vertically.", arg_type=NO_QUERY, case_sensitive=True)
def stub():
raise NotImplementedError
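Putting the decorator and registry above together, a client-side command is added by decorating a handler that returns the usual (title, rows, headers, status) tuples. A hedged sketch (the \now command and its handler are invented for illustration; the import path assumes mycli's package layout):
from mycli.packages.special.main import PARSED_QUERY, special_command
@special_command("\\now", "\\now", "Show the current server time.",
                 arg_type=PARSED_QUERY, case_sensitive=True)
def show_server_time(cur, arg=None, **_):
    # Run a trivial query and hand back the cursor plus its headers,
    # exactly like the built-in commands above do.
    cur.execute("SELECT NOW()")
    headers = [x[0] for x in cur.description]
    return [(None, cur, headers, "")]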

View file

@ -1,20 +1,22 @@
import os
import subprocess
def handle_cd_command(arg):
"""Handles a `cd` shell command by calling python's os.chdir."""
CD_CMD = 'cd'
tokens = arg.split(CD_CMD + ' ')
CD_CMD = "cd"
tokens = arg.split(CD_CMD + " ")
directory = tokens[-1] if len(tokens) > 1 else None
if not directory:
return False, "No folder name was provided."
try:
os.chdir(directory)
subprocess.call(['pwd'])
subprocess.call(["pwd"])
return True, None
except OSError as e:
return False, e.strerror
def format_uptime(uptime_in_seconds):
"""Format number of seconds into human-readable string.
@ -32,15 +34,15 @@ def format_uptime(uptime_in_seconds):
uptime_values = []
for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')):
for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")):
if value == 0 and not uptime_values:
# Don't include a value/unit if the unit isn't applicable to
# the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
continue
elif value == 1 and unit.endswith('s'):
elif value == 1 and unit.endswith("s"):
# Remove the "s" if the unit is singular.
unit = unit[:-1]
uptime_values.append('{0} {1}'.format(value, unit))
uptime_values.append("{0} {1}".format(value, unit))
uptime = ' '.join(uptime_values)
uptime = " ".join(uptime_values)
return uptime
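A couple of concrete values for the uptime formatting above, assuming the usual seconds-to-days/hours/min/sec breakdown computed earlier in the function (that part is outside this hunk):
# format_uptime(59)     -> '59 sec'          (leading zero units are skipped)
# format_uptime(90061)  -> '1 day 1 hour 1 min 1 sec'   (plural units singularized)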

View file

@ -2,8 +2,12 @@
from mycli.packages.parseutils import extract_tables
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
'sql-update-2', )
supported_formats = (
"sql-insert",
"sql-update",
"sql-update-1",
"sql-update-2",
)
preprocessors = ()
@ -25,19 +29,18 @@ def adapter(data, headers, table_format=None, **kwargs):
table_name = table[1]
else:
table_name = "`DUAL`"
if table_format == 'sql-insert':
if table_format == "sql-insert":
h = "`, `".join(headers)
yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h)
prefix = " "
for d in data:
values = ", ".join(escape_for_sql_statement(v)
for i, v in enumerate(d))
values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d))
yield "{}({})".format(prefix, values)
if prefix == " ":
prefix = ", "
yield ";"
if table_format.startswith('sql-update'):
s = table_format.split('-')
if table_format.startswith("sql-update"):
s = table_format.split("-")
keys = 1
if len(s) > 2:
keys = int(s[-1])
@ -49,8 +52,7 @@ def adapter(data, headers, table_format=None, **kwargs):
if prefix == " ":
prefix = ", "
f = "`{}` = {}"
where = (f.format(headers[i], escape_for_sql_statement(
d[i])) for i in range(keys))
where = (f.format(headers[i], escape_for_sql_statement(d[i])) for i in range(keys))
yield "WHERE {};".format(" AND ".join(where))
@ -58,5 +60,4 @@ def register_new_formatter(TabularOutputFormatter):
global formatter
formatter = TabularOutputFormatter
for sql_format in supported_formats:
TabularOutputFormatter.register_new_formatter(
sql_format, adapter, preprocessors, {'table_format': sql_format})
TabularOutputFormatter.register_new_formatter(sql_format, adapter, preprocessors, {"table_format": sql_format})
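For orientation, the sql-insert branch above emits one INSERT statement per result set. With headers ["id", "name"], two data rows, and a query whose FROM clause resolves to `users`, the yielded lines have roughly this shape (exact value quoting comes from escape_for_sql_statement, which is outside this hunk):
# INSERT INTO `users` (`id`, `name`) VALUES
#  (<escaped id>, <escaped name>)
# , (<escaped id>, <escaped name>)
# ;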

View file

@ -29,7 +29,7 @@ def search_history(event: KeyPressEvent):
formatted_history_items = []
original_history_items = []
for item, timestamp in history_items_with_timestamp:
formatted_item = item.replace('\n', ' ')
formatted_item = item.replace("\n", " ")
timestamp = timestamp.split(".")[0] if "." in timestamp else timestamp
formatted_history_items.append(f"{timestamp} {formatted_item}")
original_history_items.append(item)

View file

@ -1,5 +1,5 @@
import os
from typing import Iterable, Union, List, Tuple
from typing import Union, List, Tuple
from prompt_toolkit.history import FileHistory

File diff suppressed because it is too large

View file

@ -5,31 +5,29 @@ import re
import pymysql
from .packages import special
from pymysql.constants import FIELD_TYPE
from pymysql.converters import (convert_datetime,
convert_timedelta, convert_date, conversions,
decoders)
from pymysql.converters import convert_datetime, convert_timedelta, convert_date, conversions, decoders
try:
import paramiko
import paramiko # noqa: F401
import sshtunnel
except ImportError:
from mycli.packages.paramiko_stub import paramiko
pass
_logger = logging.getLogger(__name__)
FIELD_TYPES = decoders.copy()
FIELD_TYPES.update({
FIELD_TYPE.NULL: type(None)
})
FIELD_TYPES.update({FIELD_TYPE.NULL: type(None)})
ERROR_CODE_ACCESS_DENIED = 1045
class ServerSpecies(enum.Enum):
MySQL = 'MySQL'
MariaDB = 'MariaDB'
Percona = 'Percona'
TiDB = 'TiDB'
Unknown = 'MySQL'
MySQL = "MySQL"
MariaDB = "MariaDB"
Percona = "Percona"
TiDB = "TiDB"
Unknown = "MySQL"
class ServerInfo:
@ -43,7 +41,7 @@ class ServerInfo:
if not version_str or not isinstance(version_str, str):
return 0
try:
major, minor, patch = version_str.split('.')
major, minor, patch = version_str.split(".")
except ValueError:
return 0
else:
@ -52,55 +50,67 @@ class ServerInfo:
@classmethod
def from_version_string(cls, version_string):
if not version_string:
return cls(ServerSpecies.Unknown, '')
return cls(ServerSpecies.Unknown, "")
re_species = (
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
(r'[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)', ServerSpecies.TiDB),
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
ServerSpecies.Percona),
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
ServerSpecies.MySQL),
(r"(?P<version>[0-9\.]+)-MariaDB", ServerSpecies.MariaDB),
(r"[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)", ServerSpecies.TiDB),
(r"(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)", ServerSpecies.Percona),
(r"(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)", ServerSpecies.MySQL),
)
for regexp, species in re_species:
match = re.search(regexp, version_string)
if match is not None:
parsed_version = match.group('version')
parsed_version = match.group("version")
detected_species = species
break
else:
detected_species = ServerSpecies.Unknown
parsed_version = ''
parsed_version = ""
return cls(detected_species, parsed_version)
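A few example inputs for the version parsing above, worked through the regexes in order (version strings are illustrative):
# ServerInfo.from_version_string("10.6.16-MariaDB")
#   -> species MariaDB, version_str "10.6.16"
# ServerInfo.from_version_string("8.0.36-0ubuntu0.22.04.1")
#   -> species MySQL, version_str "8.0.36"
# ServerInfo.from_version_string("")
#   -> species Unknown (displayed as "MySQL"), version_str ""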
def __str__(self):
if self.species:
return f'{self.species.value} {self.version_str}'
return f"{self.species.value} {self.version_str}"
else:
return self.version_str
class SQLExecute(object):
databases_query = """SHOW DATABASES"""
databases_query = '''SHOW DATABASES'''
tables_query = '''SHOW TABLES'''
tables_query = """SHOW TABLES"""
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
users_query = """SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user"""
functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns
table_columns_query = """select TABLE_NAME, COLUMN_NAME from information_schema.columns
where table_schema = '%s'
order by table_name,ordinal_position'''
order by table_name,ordinal_position"""
def __init__(self, database, user, password, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename, init_command=None):
def __init__(
self,
database,
user,
password,
host,
port,
socket,
charset,
local_infile,
ssl,
ssh_user,
ssh_host,
ssh_port,
ssh_password,
ssh_key_filename,
init_command=None,
):
self.dbname = database
self.user = user
self.password = password
@ -120,52 +130,79 @@ class SQLExecute(object):
self.init_command = init_command
self.connect()
def connect(self, database=None, user=None, password=None, host=None,
port=None, socket=None, charset=None, local_infile=None,
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
ssh_password=None, ssh_key_filename=None, init_command=None):
db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
host = (host or self.host)
port = (port or self.port)
socket = (socket or self.socket)
charset = (charset or self.charset)
local_infile = (local_infile or self.local_infile)
ssl = (ssl or self.ssl)
ssh_user = (ssh_user or self.ssh_user)
ssh_host = (ssh_host or self.ssh_host)
ssh_port = (ssh_port or self.ssh_port)
ssh_password = (ssh_password or self.ssh_password)
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
init_command = (init_command or self.init_command)
def connect(
self,
database=None,
user=None,
password=None,
host=None,
port=None,
socket=None,
charset=None,
local_infile=None,
ssl=None,
ssh_host=None,
ssh_port=None,
ssh_user=None,
ssh_password=None,
ssh_key_filename=None,
init_command=None,
):
db = database or self.dbname
user = user or self.user
password = password or self.password
host = host or self.host
port = port or self.port
socket = socket or self.socket
charset = charset or self.charset
local_infile = local_infile or self.local_infile
ssl = ssl or self.ssl
ssh_user = ssh_user or self.ssh_user
ssh_host = ssh_host or self.ssh_host
ssh_port = ssh_port or self.ssh_port
ssh_password = ssh_password or self.ssh_password
ssh_key_filename = ssh_key_filename or self.ssh_key_filename
init_command = init_command or self.init_command
_logger.debug(
'Connection DB Params: \n'
'\tdatabase: %r'
'\tuser: %r'
'\thost: %r'
'\tport: %r'
'\tsocket: %r'
'\tcharset: %r'
'\tlocal_infile: %r'
'\tssl: %r'
'\tssh_user: %r'
'\tssh_host: %r'
'\tssh_port: %r'
'\tssh_password: %r'
'\tssh_key_filename: %r'
'\tinit_command: %r',
db, user, host, port, socket, charset, local_infile, ssl,
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
init_command
"Connection DB Params: \n"
"\tdatabase: %r"
"\tuser: %r"
"\thost: %r"
"\tport: %r"
"\tsocket: %r"
"\tcharset: %r"
"\tlocal_infile: %r"
"\tssl: %r"
"\tssh_user: %r"
"\tssh_host: %r"
"\tssh_port: %r"
"\tssh_password: %r"
"\tssh_key_filename: %r"
"\tinit_command: %r",
db,
user,
host,
port,
socket,
charset,
local_infile,
ssl,
ssh_user,
ssh_host,
ssh_port,
ssh_password,
ssh_key_filename,
init_command,
)
conv = conversions.copy()
conv.update({
FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
})
conv.update(
{
FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
}
)
defer_connect = False
@ -181,29 +218,45 @@ class SQLExecute(object):
ssl_context = self._create_ssl_ctx(ssl)
conn = pymysql.connect(
database=db, user=user, password=password, host=host, port=port,
unix_socket=socket, use_unicode=True, charset=charset,
autocommit=True, client_flag=client_flag,
local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli",
defer_connect=defer_connect, init_command=init_command
database=db,
user=user,
password=password,
host=host,
port=port,
unix_socket=socket,
use_unicode=True,
charset=charset,
autocommit=True,
client_flag=client_flag,
local_infile=local_infile,
conv=conv,
ssl=ssl_context,
program_name="mycli",
defer_connect=defer_connect,
init_command=init_command,
)
if ssh_host:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, ssh_password,
key_filename=ssh_key_filename
)
chan = client.get_transport().open_channel(
'direct-tcpip',
(host, port),
('0.0.0.0', 0),
)
conn.connect(chan)
##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel
#####
# instead let's open a tunnel and rewrite host:port to local bind
try:
chan = sshtunnel.SSHTunnelForwarder(
(ssh_host, ssh_port),
ssh_username=ssh_user,
ssh_pkey=ssh_key_filename,
ssh_password=ssh_password,
remote_bind_address=(host, port),
)
chan.start()
if hasattr(self, 'conn'):
conn.host = chan.local_bind_host
conn.port = chan.local_bind_port
conn.connect()
except Exception as e:
raise e
if hasattr(self, "conn"):
self.conn.close()
self.conn = conn
# Update them after the connection is made to ensure that it was a
@ -235,24 +288,24 @@ class SQLExecute(object):
# Split the sql into separate queries and run each one.
# Unless it's saving a favorite query, in which case we
# want to save them all together.
if statement.startswith('\\fs'):
if statement.startswith("\\fs"):
components = [statement]
else:
components = special.split_queries(statement)
for sql in components:
# \G is treated specially since we have to set the expanded output.
if sql.endswith('\\G'):
if sql.endswith("\\G"):
special.set_expanded_output(True)
sql = sql[:-2].strip()
cur = self.conn.cursor()
try: # Special command
_logger.debug('Trying a dbspecial command. sql: %r', sql)
try: # Special command
_logger.debug("Trying a dbspecial command. sql: %r", sql)
for result in special.execute(cur, sql):
yield result
except special.CommandNotFound: # Regular SQL
_logger.debug('Regular sql statement. sql: %r', sql)
_logger.debug("Regular sql statement. sql: %r", sql)
cur.execute(sql)
while True:
yield self.get_result(cur)
@ -271,12 +324,11 @@ class SQLExecute(object):
# e.g. SELECT or SHOW.
if cursor.description is not None:
headers = [x[0] for x in cursor.description]
status = '{0} row{1} in set'
status = "{0} row{1} in set"
else:
_logger.debug('No rows in result.')
status = 'Query OK, {0} row{1} affected'
status = status.format(cursor.rowcount,
'' if cursor.rowcount == 1 else 's')
_logger.debug("No rows in result.")
status = "Query OK, {0} row{1} affected"
status = status.format(cursor.rowcount, "" if cursor.rowcount == 1 else "s")
return (title, cursor if cursor.description else None, headers, status)
@ -284,7 +336,7 @@ class SQLExecute(object):
"""Yields table names"""
with self.conn.cursor() as cur:
_logger.debug('Tables Query. sql: %r', self.tables_query)
_logger.debug("Tables Query. sql: %r", self.tables_query)
cur.execute(self.tables_query)
for row in cur:
yield row
@ -292,14 +344,14 @@ class SQLExecute(object):
def table_columns(self):
"""Yields (table name, column name) pairs"""
with self.conn.cursor() as cur:
_logger.debug('Columns Query. sql: %r', self.table_columns_query)
_logger.debug("Columns Query. sql: %r", self.table_columns_query)
cur.execute(self.table_columns_query % self.dbname)
for row in cur:
yield row
def databases(self):
with self.conn.cursor() as cur:
_logger.debug('Databases Query. sql: %r', self.databases_query)
_logger.debug("Databases Query. sql: %r", self.databases_query)
cur.execute(self.databases_query)
return [x[0] for x in cur.fetchall()]
@ -307,31 +359,31 @@ class SQLExecute(object):
"""Yields tuples of (schema_name, function_name)"""
with self.conn.cursor() as cur:
_logger.debug('Functions Query. sql: %r', self.functions_query)
_logger.debug("Functions Query. sql: %r", self.functions_query)
cur.execute(self.functions_query % self.dbname)
for row in cur:
yield row
def show_candidates(self):
with self.conn.cursor() as cur:
_logger.debug('Show Query. sql: %r', self.show_candidates_query)
_logger.debug("Show Query. sql: %r", self.show_candidates_query)
try:
cur.execute(self.show_candidates_query)
except pymysql.DatabaseError as e:
_logger.error('No show completions due to %r', e)
yield ''
_logger.error("No show completions due to %r", e)
yield ""
else:
for row in cur:
yield (row[0].split(None, 1)[-1], )
yield (row[0].split(None, 1)[-1],)
def users(self):
with self.conn.cursor() as cur:
_logger.debug('Users Query. sql: %r', self.users_query)
_logger.debug("Users Query. sql: %r", self.users_query)
try:
cur.execute(self.users_query)
except pymysql.DatabaseError as e:
_logger.error('No user completions due to %r', e)
yield ''
_logger.error("No user completions due to %r", e)
yield ""
else:
for row in cur:
yield row
@ -343,17 +395,17 @@ class SQLExecute(object):
def reset_connection_id(self):
# Remember current connection id
_logger.debug('Get current connection id')
_logger.debug("Get current connection id")
try:
res = self.run('select connection_id()')
res = self.run("select connection_id()")
for title, cur, headers, status in res:
self.connection_id = cur.fetchone()[0]
except Exception as e:
# See #1054
self.connection_id = -1
_logger.error('Failed to get connection id: %s', e)
_logger.error("Failed to get connection id: %s", e)
else:
_logger.debug('Current connection id: %s', self.connection_id)
_logger.debug("Current connection id: %s", self.connection_id)
def change_db(self, db):
self.conn.select_db(db)
@ -392,6 +444,6 @@ class SQLExecute(object):
ctx.minimum_version = ssl.TLSVersion.TLSv1_3
ctx.maximum_version = ssl.TLSVersion.TLSv1_3
else:
_logger.error('Invalid tls version: %s', tls_version)
_logger.error("Invalid tls version: %s", tls_version)
return ctx
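The connect() change above replaces the raw paramiko channel with an sshtunnel-managed local forward. Stripped of the mycli specifics (deferred connect, instance bookkeeping), the general pattern is the following sketch; host names, paths, and credentials are placeholders:
import os
import pymysql
import sshtunnel
# Forward a local port to a MySQL server that is only reachable from the SSH host.
tunnel = sshtunnel.SSHTunnelForwarder(
    ("bastion.example.com", 22),
    ssh_username="deploy",
    ssh_pkey=os.path.expanduser("~/.ssh/id_ed25519"),
    remote_bind_address=("db.internal", 3306),
)
tunnel.start()
# Connect through the tunnel's local bind instead of the remote host directly.
conn = pymysql.connect(
    host=tunnel.local_bind_host,
    port=tunnel.local_bind_port,
    user="appuser",
    password="secret",
    database="appdb",
)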

pyproject.toml (new file, 59 lines)
View file

@ -0,0 +1,59 @@
[project]
name = "mycli"
dynamic = ["version"]
description = "CLI for MySQL Database. With auto-completion and syntax highlighting."
readme = "README.md"
requires-python = ">=3.7"
license = { text = "BSD" }
authors = [{ name = "Mycli Core Team", email = "mycli-dev@googlegroups.com" }]
urls = { homepage = "http://mycli.net" }
dependencies = [
"click >= 7.0",
"cryptography >= 1.0.0",
"Pygments>=1.6",
"prompt_toolkit>=3.0.6,<4.0.0",
"PyMySQL >= 0.9.2",
"sqlparse>=0.3.0,<0.5.0",
"sqlglot>=5.1.3",
"configobj >= 5.0.5",
"cli_helpers[styles] >= 2.2.1",
"pyperclip >= 1.8.1",
"pyaes >= 1.6.1",
"pyfzf >= 0.3.1",
"importlib_resources >= 5.0.0; python_version<'3.9'",
]
[build-system]
requires = [
"setuptools>=64.0",
"setuptools-scm>=8;python_version>='3.8'",
"setuptools-scm<8;python_version<'3.8'",
]
build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
[project.optional-dependencies]
ssh = ["paramiko", "sshtunnel"]
dev = [
"behave>=1.2.6",
"coverage>=7.2.7",
"pexpect>=4.9.0",
"pytest>=7.4.4",
"pytest-cov>=4.1.0",
"tox>=4.8.0",
"pdbpp>=0.10.3",
]
[project.scripts]
mycli = "mycli.main:cli"
[tool.setuptools.package-data]
mycli = ["myclirc", "AUTHORS", "SPONSORS"]
[tool.setuptools.packages.find]
include = ["mycli*"]
[tool.ruff]
line-length = 140

View file

@ -1,119 +0,0 @@
"""A script to publish a release of mycli to PyPI."""
from optparse import OptionParser
import re
import subprocess
import sys
import click
DEBUG = False
CONFIRM_STEPS = False
DRY_RUN = False
def skip_step():
"""
Asks for user's response whether to run a step. Default is yes.
:return: boolean
"""
global CONFIRM_STEPS
if CONFIRM_STEPS:
return not click.confirm('--- Run this step?', default=True)
return False
def run_step(*args):
"""
Prints out the command and asks if it should be run.
If yes (default), runs it.
:param args: list of strings (command and args)
"""
global DRY_RUN
cmd = args
print(' '.join(cmd))
if skip_step():
print('--- Skipping...')
elif DRY_RUN:
print('--- Pretending to run...')
else:
subprocess.check_output(cmd)
def version(version_file):
_version_re = re.compile(
r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)')
with open(version_file) as f:
ver = _version_re.search(f.read()).group('version')
return ver
def commit_for_release(version_file, ver):
run_step('git', 'reset')
run_step('git', 'add', version_file)
run_step('git', 'commit', '--message',
'Releasing version {}'.format(ver))
def create_git_tag(tag_name):
run_step('git', 'tag', tag_name)
def create_distribution_files():
run_step('python', 'setup.py', 'sdist', 'bdist_wheel')
def upload_distribution_files():
run_step('twine', 'upload', 'dist/*')
def push_to_github():
run_step('git', 'push', 'origin', 'main')
def push_tags_to_github():
run_step('git', 'push', '--tags', 'origin')
def checklist(questions):
for question in questions:
if not click.confirm('--- {}'.format(question), default=False):
sys.exit(1)
if __name__ == '__main__':
if DEBUG:
subprocess.check_output = lambda x: x
ver = version('mycli/__init__.py')
parser = OptionParser()
parser.add_option(
"-c", "--confirm-steps", action="store_true", dest="confirm_steps",
default=False, help=("Confirm every step. If the step is not "
"confirmed, it will be skipped.")
)
parser.add_option(
"-d", "--dry-run", action="store_true", dest="dry_run",
default=False, help="Print out, but not actually run any steps."
)
popts, pargs = parser.parse_args()
CONFIRM_STEPS = popts.confirm_steps
DRY_RUN = popts.dry_run
print('Releasing Version:', ver)
if not click.confirm('Are you sure?', default=False):
sys.exit(1)
commit_for_release('mycli/__init__.py', ver)
create_git_tag('v{}'.format(ver))
create_distribution_files()
push_to_github()
push_tags_to_github()
upload_distribution_files()

View file

@ -10,6 +10,7 @@ colorama>=0.4.1
git+https://github.com/hayd/pep8radius.git # --error-status option not released
click>=7.0
paramiko==2.11.0
sshtunnel==0.4.0
pyperclip>=1.8.1
importlib_resources>=5.0.0
pyaes>=1.6.1

View file

@ -1,18 +0,0 @@
[bdist_wheel]
universal = 1
[tool:pytest]
addopts = --capture=sys
--showlocals
--doctest-modules
--doctest-ignore-import-errors
--ignore=setup.py
--ignore=mycli/magic.py
--ignore=mycli/packages/parseutils.py
--ignore=test/features
[pep8]
rev = master
docformatter = True
diff = True
error-status = True

setup.py (deleted, 127 lines)
View file

@ -1,127 +0,0 @@
#!/usr/bin/env python
import ast
import re
import subprocess
import sys
from setuptools import Command, find_packages, setup
from setuptools.command.test import test as TestCommand
_version_re = re.compile(r'__version__\s+=\s+(.*)')
with open('mycli/__init__.py') as f:
version = ast.literal_eval(_version_re.search(
f.read()).group(1))
description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.'
install_requirements = [
'click >= 7.0',
# Pinning cryptography is not needed after paramiko 2.11.0. Correct it
'cryptography >= 1.0.0',
# 'Pygments>=1.6,<=2.11.1',
'Pygments>=1.6',
'prompt_toolkit>=3.0.6,<4.0.0',
'PyMySQL >= 0.9.2',
'sqlparse>=0.3.0,<0.5.0',
'sqlglot>=5.1.3',
'configobj >= 5.0.5',
'cli_helpers[styles] >= 2.2.1',
'pyperclip >= 1.8.1',
'pyaes >= 1.6.1',
'pyfzf >= 0.3.1',
]
if sys.version_info.minor < 9:
install_requirements.append('importlib_resources >= 5.0.0')
class lint(Command):
description = 'check code against PEP 8 (and fix violations)'
user_options = [
('branch=', 'b', 'branch/revision to compare against (e.g. main)'),
('fix', 'f', 'fix the violations in place'),
('error-status', 'e', 'return an error code on failed PEP check'),
]
def initialize_options(self):
"""Set the default options."""
self.branch = 'main'
self.fix = False
self.error_status = True
def finalize_options(self):
pass
def run(self):
cmd = 'pep8radius {}'.format(self.branch)
if self.fix:
cmd += ' --in-place'
if self.error_status:
cmd += ' --error-status'
sys.exit(subprocess.call(cmd, shell=True))
class test(TestCommand):
user_options = [
('pytest-args=', 'a', 'Arguments to pass to pytest'),
('behave-args=', 'b', 'Arguments to pass to pytest')
]
def initialize_options(self):
TestCommand.initialize_options(self)
self.pytest_args = ''
self.behave_args = '--no-capture'
def run_tests(self):
unit_test_errno = subprocess.call(
'pytest test/ ' + self.pytest_args,
shell=True
)
cli_errno = subprocess.call(
'behave test/features ' + self.behave_args,
shell=True
)
subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False)
sys.exit(unit_test_errno or cli_errno)
setup(
name='mycli',
author='Mycli Core Team',
author_email='mycli-dev@googlegroups.com',
version=version,
url='http://mycli.net',
packages=find_packages(exclude=['test*']),
package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']},
description=description,
long_description=description,
install_requires=install_requirements,
entry_points={
'console_scripts': ['mycli = mycli.main:cli'],
},
cmdclass={'lint': lint, 'test': test},
python_requires=">=3.7",
classifiers=[
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Operating System :: Unix',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: SQL',
'Topic :: Database',
'Topic :: Database :: Front-Ends',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries :: Python Modules',
],
extras_require={
'ssh': ['paramiko'],
},
)

View file

@ -1,13 +1,12 @@
import pytest
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
from .utils import HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT
import mycli.sqlexecute
@pytest.fixture(scope="function")
def connection():
create_db('mycli_test_db')
connection = db_connection('mycli_test_db')
create_db("mycli_test_db")
connection = db_connection("mycli_test_db")
yield connection
connection.close()
@ -22,8 +21,18 @@ def cursor(connection):
@pytest.fixture
def executor(connection):
return mycli.sqlexecute.SQLExecute(
database='mycli_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
database="mycli_test_db",
user=USER,
host=HOST,
password=PASSWORD,
port=PORT,
socket=None,
charset=CHARSET,
local_infile=False,
ssl=None,
ssh_user=SSH_USER,
ssh_host=SSH_HOST,
ssh_port=SSH_PORT,
ssh_password=None,
ssh_key_filename=None,
)

View file

@ -1,8 +1,7 @@
import pymysql
def create_db(hostname='localhost', port=3306, username=None,
password=None, dbname=None):
def create_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
"""Create test database.
:param hostname: string
@ -14,17 +13,12 @@ def create_db(hostname='localhost', port=3306, username=None,
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
cr.execute('drop database if exists ' + dbname)
cr.execute('create database ' + dbname)
cr.execute("drop database if exists " + dbname)
cr.execute("create database " + dbname)
cn.close()
@ -44,20 +38,13 @@ def create_cn(hostname, port, password, username, dbname):
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
return cn
def drop_db(hostname='localhost', port=3306, username=None,
password=None, dbname=None):
def drop_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
"""Drop database.
:param hostname: string
@ -68,17 +55,11 @@ def drop_db(hostname='localhost', port=3306, username=None,
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
cr.execute('drop database if exists ' + dbname)
cr.execute("drop database if exists " + dbname)
close_cn(cn)

View file

@@ -9,96 +9,72 @@ import pexpect
from steps.wrappers import run_cli, wait_prompt
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log")
SELF_CONNECTING_FEATURES = (
'test/features/connection.feature',
)
SELF_CONNECTING_FEATURES = ("test/features/connection.feature",)
MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
MY_CNF_PATH = os.path.expanduser("~/.my.cnf")
MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup"
MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf")
MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup"
def get_db_name_from_context(context):
return context.config.userdata.get(
'my_test_db', None
) or "mycli_behave_tests"
return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests"
def before_all(context):
"""Set env parameters."""
os.environ['LINES'] = "100"
os.environ['COLUMNS'] = "100"
os.environ['EDITOR'] = 'ex'
os.environ['LC_ALL'] = 'en_US.UTF-8'
os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
os.environ['MYCLI_HISTFILE'] = os.devnull
os.environ["LINES"] = "100"
os.environ["COLUMNS"] = "100"
os.environ["EDITOR"] = "ex"
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
os.environ["MYCLI_HISTFILE"] = os.devnull
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
# test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
# login_path_file = os.path.join(test_dir, "mylogin.cnf")
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
'.coveragerc')
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc")
context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]])
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = get_db_name_from_context(context)
db_name_full = '{0}_{1}'.format(db_name, vi)
db_name_full = "{0}_{1}".format(db_name, vi)
# Store get params from config/environment variables
context.conf = {
'host': context.config.userdata.get(
'my_test_host',
os.getenv('PYTEST_HOST', 'localhost')
),
'port': context.config.userdata.get(
'my_test_port',
int(os.getenv('PYTEST_PORT', '3306'))
),
'user': context.config.userdata.get(
'my_test_user',
os.getenv('PYTEST_USER', 'root')
),
'pass': context.config.userdata.get(
'my_test_pass',
os.getenv('PYTEST_PASSWORD', None)
),
'cli_command': context.config.userdata.get(
'my_cli_command', None) or
sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
'dbname': db_name,
'dbname_tmp': db_name_full + '_tmp',
'vi': vi,
'pager_boundary': '---boundary---',
"host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")),
"port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))),
"user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")),
"pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)),
"cli_command": context.config.userdata.get("my_cli_command", None)
or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
"dbname": db_name,
"dbname_tmp": db_name_full + "_tmp",
"vi": vi,
"pager_boundary": "---boundary---",
}
_, my_cnf = mkstemp()
with open(my_cnf, 'w') as f:
with open(my_cnf, "w") as f:
f.write(
'[client]\n'
'pager={0} {1} {2}\n'.format(
sys.executable, os.path.join(context.package_root,
'test/features/wrappager.py'),
context.conf['pager_boundary'])
"[client]\n" "pager={0} {1} {2}\n".format(
sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"]
)
)
context.conf['defaults-file'] = my_cnf
context.conf['myclirc'] = os.path.join(context.package_root, 'test',
'myclirc')
context.conf["defaults-file"] = my_cnf
context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc")
context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
context.conf['user'],
context.conf['pass'],
context.conf['dbname'])
context.cn = dbutils.create_db(
context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]
)
context.fixture_data = fixutils.read_fixture_files()
@@ -106,12 +82,10 @@ def before_all(context):
def after_all(context):
"""Unset env parameters."""
dbutils.close_cn(context.cn)
dbutils.drop_db(context.conf['host'], context.conf['port'],
context.conf['user'], context.conf['pass'],
context.conf['dbname'])
dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"])
# Restore env vars.
#for k, v in context.pgenv.items():
# for k, v in context.pgenv.items():
# if k in os.environ and v is None:
# del os.environ[k]
# elif v:
@@ -123,8 +97,8 @@ def before_step(context, _):
def before_scenario(context, arg):
with open(test_log_file, 'w') as f:
f.write('')
with open(test_log_file, "w") as f:
f.write("")
if arg.location.filename not in SELF_CONNECTING_FEATURES:
run_cli(context)
wait_prompt(context)
@@ -140,23 +114,18 @@ def after_scenario(context, _):
"""Cleans up after each test complete."""
with open(test_log_file) as f:
for line in f:
if 'error' in line.lower():
raise RuntimeError(f'Error in log file: {line}')
if "error" in line.lower():
raise RuntimeError(f"Error in log file: {line}")
if hasattr(context, 'cli') and not context.exit_sent:
if hasattr(context, "cli") and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
user = context.conf['user']
host = context.conf['host']
user = context.conf["user"]
host = context.conf["host"]
dbname = context.currentdb
context.cli.expect_exact(
'{0}@{1}:{2}>'.format(
user, host, dbname
),
timeout=5
)
context.cli.sendcontrol('c')
context.cli.sendcontrol('d')
context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
context.cli.expect_exact(pexpect.EOF, timeout=5)
if os.path.exists(MY_CNF_BACKUP_PATH):

View file

@@ -1,5 +1,4 @@
import os
import io
def read_fixture_lines(filename):
@@ -20,9 +19,9 @@ def read_fixture_files():
fixture_dict = {}
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, 'fixture_data/')
fixture_dir = os.path.join(current_dir, "fixture_data/")
for filename in os.listdir(fixture_dir):
if filename not in ['.', '..']:
if filename not in [".", ".."]:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)

View file

@@ -6,41 +6,42 @@ import wrappers
from utils import parse_cli_args_to_dict
@when('we run dbcli with {arg}')
@when("we run dbcli with {arg}")
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
@when('we execute a small query')
@when("we execute a small query")
def step_execute_small_query(context):
context.cli.sendline('select 1')
context.cli.sendline("select 1")
@when('we execute a large query')
@when("we execute a large query")
def step_execute_large_query(context):
context.cli.sendline(
'select {}'.format(','.join([str(n) for n in range(1, 50)])))
context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
@then('we see small results in horizontal format')
@then("we see small results in horizontal format")
def step_see_small_results(context):
wrappers.expect_pager(context, dedent("""\
wrappers.expect_pager(
context,
dedent("""\
+---+\r
| 1 |\r
+---+\r
| 1 |\r
+---+\r
\r
"""), timeout=5)
wrappers.expect_exact(context, '1 row in set', timeout=2)
"""),
timeout=5,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)
@then('we see large results in vertical format')
@then("we see large results in vertical format")
def step_see_large_results(context):
rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
expected = ('***************************[ 1. row ]'
'***************************\r\n' +
'{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
rows = ["{n:3}| {n}".format(n=str(n)) for n in range(1, 50)]
expected = "***************************[ 1. row ]" "***************************\r\n" + "{}\r\n".format("\r\n".join(rows) + "\r\n")
wrappers.expect_pager(context, expected, timeout=10)
wrappers.expect_exact(context, '1 row in set', timeout=2)
wrappers.expect_exact(context, "1 row in set", timeout=2)

View file

@@ -5,18 +5,18 @@ to call the step in "*.feature" file.
"""
from behave import when
from behave import when, then
from textwrap import dedent
import tempfile
import wrappers
@when('we run dbcli')
@when("we run dbcli")
def step_run_cli(context):
wrappers.run_cli(context)
@when('we wait for prompt')
@when("we wait for prompt")
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@@ -24,77 +24,75 @@ def step_wait_prompt(context):
@when('we send "ctrl + d"')
def step_ctrl_d(context):
"""Send Ctrl + D to hopefully exit."""
context.cli.sendcontrol('d')
context.cli.sendcontrol("d")
context.exit_sent = True
@when('we send "\?" command')
@when(r'we send "\?" command')
def step_send_help(context):
"""Send \?
r"""Send \?
to see help.
"""
context.cli.sendline('\\?')
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
context.cli.sendline("\\?")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when(u'we send source command')
@when("we send source command")
def step_send_source_command(context):
with tempfile.NamedTemporaryFile() as f:
f.write(b'\?')
f.write(b"\\?")
f.flush()
context.cli.sendline('\. {0}'.format(f.name))
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
context.cli.sendline("\\. {0}".format(f.name))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when(u'we run query to check application_name')
@when("we run query to check application_name")
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
)
@then(u'we see found')
@then("we see found")
def step_see_found(context):
wrappers.expect_exact(
context,
context.conf['pager_boundary'] + '\r' + dedent('''
context.conf["pager_boundary"]
+ "\r"
+ dedent("""
+-------+\r
| found |\r
+-------+\r
| found |\r
+-------+\r
\r
''') + context.conf['pager_boundary'],
timeout=5
""")
+ context.conf["pager_boundary"],
timeout=5,
)
@then(u'we confirm the destructive warning')
def step_confirm_destructive_command(context):
@then("we confirm the destructive warning")
def step_confirm_destructive_command(context): # noqa
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline('y')
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline("y")
@when(u'we answer the destructive warning with "{confirmation}"')
def step_confirm_destructive_command(context, confirmation):
@when('we answer the destructive warning with "{confirmation}"')
def step_confirm_destructive_command(context, confirmation): # noqa
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline(confirmation)
@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
def step_confirm_destructive_command(context, confirmation, text):
@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
def step_confirm_destructive_command(context, confirmation, text): # noqa
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline(confirmation)
wrappers.expect_exact(context, text, timeout=2)
# we must exit the Click loop, or the feature will hang
context.cli.sendline('n')
context.cli.sendline("n")

View file

@@ -1,9 +1,7 @@
import io
import os
import shlex
from behave import when, then
import pexpect
import wrappers
from test.features.steps.utils import parse_cli_args_to_dict
@@ -12,60 +10,44 @@ from test.utils import HOST, PORT, USER, PASSWORD
from mycli.config import encrypt_mylogin_cnf
TEST_LOGIN_PATH = 'test_login_path'
TEST_LOGIN_PATH = "test_login_path"
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
@when('we run mycli without arguments "{excluded_args}"')
def step_run_cli_without_args(context, excluded_args, exact_args=''):
wrappers.run_cli(
context,
run_args=parse_cli_args_to_dict(exact_args),
exclude_args=parse_cli_args_to_dict(excluded_args).keys()
)
def step_run_cli_without_args(context, excluded_args, exact_args=""):
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys())
@then('status contains "{expression}"')
def status_contains(context, expression):
wrappers.expect_exact(context, f'{expression}', timeout=5)
wrappers.expect_exact(context, f"{expression}", timeout=5)
# Normally, the shutdown after scenario waits for the prompt.
# But we may have changed the prompt, depending on parameters,
# so let's wait for its last character
context.cli.expect_exact('>')
context.cli.expect_exact(">")
context.atprompt = True
@when('we create my.cnf file')
@when("we create my.cnf file")
def step_create_my_cnf_file(context):
my_cnf = (
'[client]\n'
f'host = {HOST}\n'
f'port = {PORT}\n'
f'user = {USER}\n'
f'password = {PASSWORD}\n'
)
with open(MY_CNF_PATH, 'w') as f:
my_cnf = "[client]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
with open(MY_CNF_PATH, "w") as f:
f.write(my_cnf)
@when('we create mylogin.cnf file')
@when("we create mylogin.cnf file")
def step_create_mylogin_cnf_file(context):
os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
mylogin_cnf = (
f'[{TEST_LOGIN_PATH}]\n'
f'host = {HOST}\n'
f'port = {PORT}\n'
f'user = {USER}\n'
f'password = {PASSWORD}\n'
)
with open(MYLOGIN_CNF_PATH, 'wb') as f:
os.environ.pop("MYSQL_TEST_LOGIN_FILE", None)
mylogin_cnf = f"[{TEST_LOGIN_PATH}]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
with open(MYLOGIN_CNF_PATH, "wb") as f:
input_file = io.StringIO(mylogin_cnf)
f.write(encrypt_mylogin_cnf(input_file).read())
@then('we are logged in')
@then("we are logged in")
def we_are_logged_in(context):
db_name = get_db_name_from_context(context)
context.cli.expect_exact(f'{db_name}>', timeout=5)
context.cli.expect_exact(f"{db_name}>", timeout=5)
context.atprompt = True

View file

@@ -11,105 +11,99 @@ import wrappers
from behave import when, then
@when('we create database')
@when("we create database")
def step_db_create(context):
"""Send create database."""
context.cli.sendline('create database {0};'.format(
context.conf['dbname_tmp']))
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.response = {
'database_name': context.conf['dbname_tmp']
}
context.response = {"database_name": context.conf["dbname_tmp"]}
@when('we drop database')
@when("we drop database")
def step_db_drop(context):
"""Send drop database."""
context.cli.sendline('drop database {0};'.format(
context.conf['dbname_tmp']))
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
@when('we connect to test database')
@when("we connect to test database")
def step_db_connect_test(context):
"""Send connect to database."""
db_name = context.conf['dbname']
db_name = context.conf["dbname"]
context.currentdb = db_name
context.cli.sendline('use {0};'.format(db_name))
context.cli.sendline("use {0};".format(db_name))
@when('we connect to quoted test database')
@when("we connect to quoted test database")
def step_db_connect_quoted_tmp(context):
"""Send connect to database."""
db_name = context.conf['dbname']
db_name = context.conf["dbname"]
context.currentdb = db_name
context.cli.sendline('use `{0}`;'.format(db_name))
context.cli.sendline("use `{0}`;".format(db_name))
@when('we connect to tmp database')
@when("we connect to tmp database")
def step_db_connect_tmp(context):
"""Send connect to database."""
db_name = context.conf['dbname_tmp']
db_name = context.conf["dbname_tmp"]
context.currentdb = db_name
context.cli.sendline('use {0}'.format(db_name))
context.cli.sendline("use {0}".format(db_name))
@when('we connect to dbserver')
@when("we connect to dbserver")
def step_db_connect_dbserver(context):
"""Send connect to database."""
context.currentdb = 'mysql'
context.cli.sendline('use mysql')
context.currentdb = "mysql"
context.cli.sendline("use mysql")
@then('dbcli exits')
@then("dbcli exits")
def step_wait_exit(context):
"""Make sure the cli exits."""
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
@then('we see dbcli prompt')
@then("we see dbcli prompt")
def step_see_prompt(context):
"""Wait to see the prompt."""
user = context.conf['user']
host = context.conf['host']
user = context.conf["user"]
host = context.conf["host"]
dbname = context.currentdb
wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
wrappers.wait_prompt(context, "{0}@{1}:{2}> ".format(user, host, dbname))
@then('we see help output')
@then("we see help output")
def step_see_help(context):
for expected_line in context.fixture_data['help_commands.txt']:
for expected_line in context.fixture_data["help_commands.txt"]:
wrappers.expect_exact(context, expected_line, timeout=1)
@then('we see database created')
@then("we see database created")
def step_see_db_created(context):
"""Wait to see create database output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see database dropped')
@then("we see database dropped")
def step_see_db_dropped(context):
"""Wait to see drop database output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
@then('we see database dropped and no default database')
@then("we see database dropped and no default database")
def step_see_db_dropped_no_default(context):
"""Wait to see drop database output."""
user = context.conf['user']
host = context.conf['host']
database = '(none)'
user = context.conf["user"]
host = context.conf["host"]
database = "(none)"
context.currentdb = None
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
wrappers.wait_prompt(context, "{0}@{1}:{2}>".format(user, host, database))
@then('we see database connected')
@then("we see database connected")
def step_see_db_connected(context):
"""Wait to see drop database output."""
wrappers.expect_exact(
context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, '"', timeout=2)
wrappers.expect_exact(context, ' as user "{0}"'.format(
context.conf['user']), timeout=2)
wrappers.expect_exact(context, ' as user "{0}"'.format(context.conf["user"]), timeout=2)

View file

@@ -10,103 +10,109 @@ from behave import when, then
from textwrap import dedent
@when('we create table')
@when("we create table")
def step_create_table(context):
"""Send create table."""
context.cli.sendline('create table a(x text);')
context.cli.sendline("create table a(x text);")
@when('we insert into table')
@when("we insert into table")
def step_insert_into_table(context):
"""Send insert into table."""
context.cli.sendline('''insert into a(x) values('xxx');''')
context.cli.sendline("""insert into a(x) values('xxx');""")
@when('we update table')
@when("we update table")
def step_update_table(context):
"""Send insert into table."""
context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
@when('we select from table')
@when("we select from table")
def step_select_from_table(context):
"""Send select from table."""
context.cli.sendline('select * from a;')
context.cli.sendline("select * from a;")
@when('we delete from table')
@when("we delete from table")
def step_delete_from_table(context):
"""Send deete from table."""
context.cli.sendline('''delete from a where x = 'yyy';''')
context.cli.sendline("""delete from a where x = 'yyy';""")
@when('we drop table')
@when("we drop table")
def step_drop_table(context):
"""Send drop table."""
context.cli.sendline('drop table a;')
context.cli.sendline("drop table a;")
@then('we see table created')
@then("we see table created")
def step_see_table_created(context):
"""Wait to see create table output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
@then('we see record inserted')
@then("we see record inserted")
def step_see_record_inserted(context):
"""Wait to see insert output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see record updated')
@then("we see record updated")
def step_see_record_updated(context):
"""Wait to see update output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see data selected')
@then("we see data selected")
def step_see_data_selected(context):
"""Wait to see select output."""
wrappers.expect_pager(
context, dedent("""\
context,
dedent("""\
+-----+\r
| x |\r
+-----+\r
| yyy |\r
+-----+\r
\r
"""), timeout=2)
wrappers.expect_exact(context, '1 row in set', timeout=2)
"""),
timeout=2,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)
@then('we see record deleted')
@then("we see record deleted")
def step_see_data_deleted(context):
"""Wait to see delete output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see table dropped')
@then("we see table dropped")
def step_see_table_dropped(context):
"""Wait to see drop output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
@when('we select null')
@when("we select null")
def step_select_null(context):
"""Send select null."""
context.cli.sendline('select null;')
context.cli.sendline("select null;")
@then('we see null selected')
@then("we see null selected")
def step_see_null_selected(context):
"""Wait to see null output."""
wrappers.expect_pager(
context, dedent("""\
context,
dedent("""\
+--------+\r
| NULL |\r
+--------+\r
| <null> |\r
+--------+\r
\r
"""), timeout=2)
wrappers.expect_exact(context, '1 row in set', timeout=2)
"""),
timeout=2,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)

View file

@@ -5,101 +5,93 @@ from behave import when, then
from textwrap import dedent
@when('we start external editor providing a file name')
@when("we start external editor providing a file name")
def step_edit_file(context):
"""Edit file with external editor."""
context.editor_file_name = os.path.join(
context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"]))
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline('\e {0}'.format(
os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
wrappers.expect_exact(context, '\r\n:', timeout=2)
context.cli.sendline("\\e {0}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
wrappers.expect_exact(context, "\r\n:", timeout=2)
@when('we type "{query}" in the editor')
def step_edit_type_sql(context, query):
context.cli.sendline('i')
context.cli.sendline("i")
context.cli.sendline(query)
context.cli.sendline('.')
wrappers.expect_exact(context, '\r\n:', timeout=2)
context.cli.sendline(".")
wrappers.expect_exact(context, "\r\n:", timeout=2)
@when('we exit the editor')
@when("we exit the editor")
def step_edit_quit(context):
context.cli.sendline('x')
context.cli.sendline("x")
wrappers.expect_exact(context, "written", timeout=2)
@then('we see "{query}" in prompt')
def step_edit_done_sql(context, query):
for match in query.split(' '):
for match in query.split(" "):
wrappers.expect_exact(context, match, timeout=5)
# Cleanup the command line.
context.cli.sendcontrol('c')
context.cli.sendcontrol("c")
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
@when(u'we tee output')
@when("we tee output")
def step_tee_ouptut(context):
context.tee_file_name = os.path.join(
context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]))
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline('tee {0}'.format(
os.path.basename(context.tee_file_name)))
context.cli.sendline("tee {0}".format(os.path.basename(context.tee_file_name)))
@when(u'we select "select {param}"')
@when('we select "select {param}"')
def step_query_select_number(context, param):
context.cli.sendline(u'select {}'.format(param))
wrappers.expect_pager(context, dedent(u"""\
context.cli.sendline("select {}".format(param))
wrappers.expect_pager(
context,
dedent(
"""\
+{dashes}+\r
| {param} |\r
+{dashes}+\r
| {param} |\r
+{dashes}+\r
\r
""".format(param=param, dashes='-' * (len(param) + 2))
), timeout=5)
wrappers.expect_exact(context, '1 row in set', timeout=2)
@then(u'we see result "{result}"')
def step_see_result(context, result):
wrappers.expect_exact(
context,
u"| {} |".format(result),
timeout=2
""".format(param=param, dashes="-" * (len(param) + 2))
),
timeout=5,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)
@when(u'we query "{query}"')
@then('we see result "{result}"')
def step_see_result(context, result):
wrappers.expect_exact(context, "| {} |".format(result), timeout=2)
@when('we query "{query}"')
def step_query(context, query):
context.cli.sendline(query)
@when(u'we notee output')
@when("we notee output")
def step_notee_output(context):
context.cli.sendline('notee')
context.cli.sendline("notee")
@then(u'we see 123456 in tee output')
@then("we see 123456 in tee output")
def step_see_123456_in_ouput(context):
with open(context.tee_file_name) as f:
assert '123456' in f.read()
assert "123456" in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
@then(u'delimiter is set to "{delimiter}"')
@then('delimiter is set to "{delimiter}"')
def delimiter_is_set(context, delimiter):
wrappers.expect_exact(
context,
u'Changed delimiter to {}'.format(delimiter),
timeout=2
)
wrappers.expect_exact(context, "Changed delimiter to {}".format(delimiter), timeout=2)

View file

@@ -9,82 +9,79 @@ import wrappers
from behave import when, then
@when('we save a named query')
@when("we save a named query")
def step_save_named_query(context):
"""Send \fs command."""
context.cli.sendline('\\fs foo SELECT 12345')
context.cli.sendline("\\fs foo SELECT 12345")
@when('we use a named query')
@when("we use a named query")
def step_use_named_query(context):
"""Send \f command."""
context.cli.sendline('\\f foo')
context.cli.sendline("\\f foo")
@when('we delete a named query')
@when("we delete a named query")
def step_delete_named_query(context):
"""Send \fd command."""
context.cli.sendline('\\fd foo')
context.cli.sendline("\\fd foo")
@then('we see the named query saved')
@then("we see the named query saved")
def step_see_named_query_saved(context):
"""Wait to see query saved."""
wrappers.expect_exact(context, 'Saved.', timeout=2)
wrappers.expect_exact(context, "Saved.", timeout=2)
@then('we see the named query executed')
@then("we see the named query executed")
def step_see_named_query_executed(context):
"""Wait to see select output."""
wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
wrappers.expect_exact(context, "SELECT 12345", timeout=2)
@then('we see the named query deleted')
@then("we see the named query deleted")
def step_see_named_query_deleted(context):
"""Wait to see query deleted."""
wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
wrappers.expect_exact(context, "foo: Deleted", timeout=2)
@when('we save a named query with parameters')
@when("we save a named query with parameters")
def step_save_named_query_with_parameters(context):
"""Send \fs command for query with parameters."""
context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
@when('we use named query with parameters')
@when("we use named query with parameters")
def step_use_named_query_with_parameters(context):
"""Send \f command with parameters."""
context.cli.sendline('\\f foo_args 101 second "third value"')
@then('we see the named query with parameters executed')
@then("we see the named query with parameters executed")
def step_see_named_query_with_parameters_executed(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'SELECT 101, "second", "third value"', timeout=2)
wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2)
@when('we use named query with too few parameters')
@when("we use named query with too few parameters")
def step_use_named_query_with_too_few_parameters(context):
"""Send \f command with missing parameters."""
context.cli.sendline('\\f foo_args 101')
context.cli.sendline("\\f foo_args 101")
@then('we see the named query with parameters fail with missing parameters')
@then("we see the named query with parameters fail with missing parameters")
def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'missing substitution for $2 in query:', timeout=2)
wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2)
@when('we use named query with too many parameters')
@when("we use named query with too many parameters")
def step_use_named_query_with_too_many_parameters(context):
"""Send \f command with extra parameters."""
context.cli.sendline('\\f foo_args 101 102 103 104')
context.cli.sendline("\\f foo_args 101 102 103 104")
@then('we see the named query with parameters fail with extra parameters')
@then("we see the named query with parameters fail with extra parameters")
def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'query does not have substitution parameter $4:', timeout=2)
wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2)

View file

@@ -9,10 +9,10 @@ import wrappers
from behave import when, then
@when('we refresh completions')
@when("we refresh completions")
def step_refresh_completions(context):
"""Send refresh command."""
context.cli.sendline('rehash')
context.cli.sendline("rehash")
@then('we see text "{text}"')
@@ -20,8 +20,8 @@ def step_see_text(context, text):
"""Wait to see given text message."""
wrappers.expect_exact(context, text, timeout=2)
@then('we see completions refresh started')
@then("we see completions refresh started")
def step_see_refresh_started(context):
"""Wait to see refresh output."""
wrappers.expect_exact(
context, 'Auto-completion refresh started in the background.', timeout=2)
wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2)

View file

@@ -4,8 +4,8 @@ import shlex
def parse_cli_args_to_dict(cli_args: str):
args_dict = {}
for arg in shlex.split(cli_args):
if '=' in arg:
key, value = arg.split('=')
if "=" in arg:
key, value = arg.split("=")
args_dict[key] = value
else:
args_dict[arg] = None

View file

@@ -18,10 +18,9 @@ def expect_exact(context, expected, timeout):
timedout = True
if timedout:
# Strip color codes out of the output.
actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
'', context.cli.before)
actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
raise Exception(
textwrap.dedent('''\
textwrap.dedent("""\
Expected:
---
{0!r}
@@ -34,17 +33,12 @@ def expect_exact(context, expected, timeout):
---
{2!r}
---
''').format(
expected,
actual,
context.logfile.getvalue()
)
""").format(expected, actual, context.logfile.getvalue())
)
def expect_pager(context, expected, timeout):
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
context.conf['pager_boundary'], expected), timeout=timeout)
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), timeout=timeout)
def run_cli(context, run_args=None, exclude_args=None):
@@ -63,55 +57,49 @@ def run_cli(context, run_args=None, exclude_args=None):
else:
rendered_args.append(key)
if conf.get('host', None):
add_arg('host', '-h', conf['host'])
if conf.get('user', None):
add_arg('user', '-u', conf['user'])
if conf.get('pass', None):
add_arg('pass', '-p', conf['pass'])
if conf.get('port', None):
add_arg('port', '-P', str(conf['port']))
if conf.get('dbname', None):
add_arg('dbname', '-D', conf['dbname'])
if conf.get('defaults-file', None):
add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
if conf.get('myclirc', None):
add_arg('myclirc', '--myclirc', conf['myclirc'])
if conf.get('login_path'):
add_arg('login_path', '--login-path', conf['login_path'])
if conf.get("host", None):
add_arg("host", "-h", conf["host"])
if conf.get("user", None):
add_arg("user", "-u", conf["user"])
if conf.get("pass", None):
add_arg("pass", "-p", conf["pass"])
if conf.get("port", None):
add_arg("port", "-P", str(conf["port"]))
if conf.get("dbname", None):
add_arg("dbname", "-D", conf["dbname"])
if conf.get("defaults-file", None):
add_arg("defaults_file", "--defaults-file", conf["defaults-file"])
if conf.get("myclirc", None):
add_arg("myclirc", "--myclirc", conf["myclirc"])
if conf.get("login_path"):
add_arg("login_path", "--login-path", conf["login_path"])
for arg_name, arg_value in conf.items():
if arg_name.startswith('-'):
if arg_name.startswith("-"):
add_arg(arg_name, arg_name, arg_value)
try:
cli_cmd = context.conf['cli_command']
cli_cmd = context.conf["cli_command"]
except KeyError:
cli_cmd = (
'{0!s} -c "'
'import coverage ; '
'coverage.process_startup(); '
'import mycli.main; '
'mycli.main.cli()'
'"'
).format(sys.executable)
cli_cmd = ('{0!s} -c "' "import coverage ; " "coverage.process_startup(); " "import mycli.main; " "mycli.main.cli()" '"').format(
sys.executable
)
cmd_parts = [cli_cmd] + rendered_args
cmd = ' '.join(cmd_parts)
cmd = " ".join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = context.conf['dbname']
context.currentdb = context.conf["dbname"]
def wait_prompt(context, prompt=None):
"""Make sure prompt is displayed."""
if prompt is None:
user = context.conf['user']
host = context.conf['host']
user = context.conf["user"]
host = context.conf["host"]
dbname = context.currentdb
prompt = '{0}@{1}:{2}>'.format(
user, host, dbname),
prompt = ("{0}@{1}:{2}>".format(user, host, dbname),)
expect_exact(context, prompt, timeout=5)
context.atprompt = True

View file

@@ -153,6 +153,7 @@ output.null = "#808080"
# Favorite queries.
[favorite_queries]
check = 'select "✔"'
foo_args = 'SELECT $1, "$2", "$3"'
# Use the -d option to reference a DSN.
# Special characters in passwords and other strings can be escaped with URL encoding.

View file

@@ -1,4 +1,5 @@
"""Test the mycli.clistyle module."""
import pytest
from pygments.style import Style
@@ -10,9 +11,9 @@ from mycli.clistyle import style_factory
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory():
"""Test that a Pygments Style class is created."""
header = 'bold underline #ansired'
cli_style = {'Token.Output.Header': header}
style = style_factory('default', cli_style)
header = "bold underline #ansired"
cli_style = {"Token.Output.Header": header}
style = style_factory("default", cli_style)
assert isinstance(style(), Style)
assert Token.Output.Header in style.styles
@@ -22,6 +23,6 @@ def test_style_factory():
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory_unknown_name():
"""Test that an unrecognized name will not throw an error."""
style = style_factory('foobar', {})
style = style_factory("foobar", {})
assert isinstance(style(), Style)

View file

@@ -8,494 +8,528 @@ def sorted_dicts(dicts):
def test_select_suggests_cols_with_visible_table_scope():
suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_select_suggests_cols_with_qualified_table_scope():
suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [('sch', 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [("sch", "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE (',
'SELECT * FROM tabl WHERE foo = ',
'SELECT * FROM tabl WHERE bar OR ',
'SELECT * FROM tabl WHERE foo = 1 AND ',
'SELECT * FROM tabl WHERE (bar > 10 AND ',
'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
'SELECT * FROM tabl WHERE 10 < ',
'SELECT * FROM tabl WHERE foo BETWEEN ',
'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM tabl WHERE ",
"SELECT * FROM tabl WHERE (",
"SELECT * FROM tabl WHERE foo = ",
"SELECT * FROM tabl WHERE bar OR ",
"SELECT * FROM tabl WHERE foo = 1 AND ",
"SELECT * FROM tabl WHERE (bar > 10 AND ",
"SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
"SELECT * FROM tabl WHERE 10 < ",
"SELECT * FROM tabl WHERE foo BETWEEN ",
"SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
],
)
def test_where_suggests_columns_functions(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE foo IN (',
'SELECT * FROM tabl WHERE foo IN (bar, ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM tabl WHERE foo IN (",
"SELECT * FROM tabl WHERE foo IN (bar, ",
],
)
def test_where_in_suggests_columns(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_where_equals_any_suggests_columns_or_keywords():
text = 'SELECT * FROM tabl WHERE foo = ANY('
text = "SELECT * FROM tabl WHERE foo = ANY("
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_lparen_suggests_cols():
suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_operand_inside_function_suggests_cols1():
suggestion = suggest_type(
'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ")
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_operand_inside_function_suggests_cols2():
suggestion = suggest_type(
'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ")
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type('SELECT ', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': []},
{'type': 'column', 'tables': []},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT ", "SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": []},
{"type": "column", "tables": []},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM ',
'INSERT INTO ',
'COPY ',
'UPDATE ',
'DESCRIBE ',
'DESC ',
'EXPLAIN ',
'SELECT * FROM foo JOIN ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM ",
"INSERT INTO ",
"COPY ",
"UPDATE ",
"DESCRIBE ",
"DESC ",
"EXPLAIN ",
"SELECT * FROM foo JOIN ",
],
)
def test_expression_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
@pytest.mark.parametrize('expression', [
'SELECT * FROM sch.',
'INSERT INTO sch.',
'COPY sch.',
'UPDATE sch.',
'DESCRIBE sch.',
'DESC sch.',
'EXPLAIN sch.',
'SELECT * FROM foo JOIN sch.',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM sch.",
"INSERT INTO sch.",
"COPY sch.",
"UPDATE sch.",
"DESCRIBE sch.",
"DESC sch.",
"EXPLAIN sch.",
"SELECT * FROM foo JOIN sch.",
],
)
def test_expression_suggests_qualified_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 'sch'},
{'type': 'view', 'schema': 'sch'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}])
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "schema"}])
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 'sch'}])
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}])
def test_distinct_suggests_cols():
suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
assert suggestions == [{'type': 'column', 'tables': []}]
suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ")
assert suggestions == [{"type": "column", "tables": []}]
def test_col_comma_suggests_cols():
suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tbl']},
{'type': 'column', 'tables': [(None, 'tbl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tbl"]},
{"type": "column", "tables": [(None, "tbl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type('SELECT a, b FROM tbl1, ',
'SELECT a, b FROM tbl1, ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_insert_into_lparen_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (")
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_insert_into_lparen_partial_text_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_insert_into_lparen_comma_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_partially_typed_col_name_suggests_col_names():
suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
'SELECT * FROM tabl WHERE col_n')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'table', 'schema': 'tabl'},
{'type': 'view', 'schema': 'tabl'},
{'type': 'function', 'schema': 'tabl'}])
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "table", "schema": "tabl"},
{"type": "view", "schema": "tabl"},
{"type": "function", "schema": "tabl"},
]
)
def test_dot_suggests_cols_of_an_alias():
suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
'SELECT t1.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 't1'},
{'type': 'view', 'schema': 't1'},
{'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
{'type': 'function', 'schema': 't1'}])
suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "table", "schema": "t1"},
{"type": "view", "schema": "t1"},
{"type": "column", "tables": [(None, "tabl1", "t1")]},
{"type": "function", "schema": "t1"},
]
)
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
'SELECT t1.a, t2.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
{'type': 'table', 'schema': 't2'},
{'type': 'view', 'schema': 't2'},
{'type': 'function', 'schema': 't2'}])
suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "tabl2", "t2")]},
{"type": "table", "schema": "t2"},
{"type": "view", "schema": "t2"},
{"type": "function", "schema": "t2"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM (',
'SELECT * FROM foo WHERE EXISTS (',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
'SELECT 1 AS',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (",
"SELECT * FROM foo WHERE EXISTS (",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (",
"SELECT 1 AS",
],
)
def test_sub_select_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == [{'type': 'keyword'}]
assert suggestion == [{"type": "keyword"}]
@pytest.mark.parametrize('expression', [
'SELECT * FROM (S',
'SELECT * FROM foo WHERE EXISTS (S',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (S",
"SELECT * FROM foo WHERE EXISTS (S",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
],
)
def test_sub_select_partial_text_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == [{'type': 'keyword'}]
assert suggestion == [{"type": "keyword"}]
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
assert suggestions == [
{'type': 'column', 'tables': [(None, 'foo', 'f')]},
{'type': 'table', 'schema': 'f'},
{'type': 'view', 'schema': 'f'},
{'type': 'function', 'schema': 'f'}]
{"type": "column", "tables": [(None, "foo", "f")]},
{"type": "table", "schema": "f"},
{"type": "view", "schema": "f"},
{"type": "function", "schema": "f"},
]
@pytest.mark.parametrize('expression', [
'SELECT * FROM (SELECT * FROM ',
'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (SELECT * FROM ",
"SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
],
)
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_sub_select_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
'SELECT * FROM (SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['abc']},
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["abc"]},
{"type": "column", "tables": [(None, "abc", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.xfail
def test_sub_select_multiple_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
'SELECT * FROM (SELECT a, ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function', 'schema': []}])
suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ")
assert sorted_dicts(suggestions) == sorted_dicts(
[{"type": "column", "tables": [(None, "abc", None)]}, {"type": "function", "schema": []}]
)
def test_sub_select_dot_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
'SELECT * FROM (SELECT t.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', 't')]},
{'type': 'table', 'schema': 't'},
{'type': 'view', 'schema': 't'},
{'type': 'function', 'schema': 't'}])
suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "tabl", "t")]},
{"type": "table", "schema": "t"},
{"type": "view", "schema": "t"},
{"type": "function", "schema": "t"},
]
)
@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
@pytest.mark.parametrize("tbl_alias", ["", "foo"])
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
suggestion = suggest_type(text, text)
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
])
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
],
)
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', 'a')]},
{'type': 'table', 'schema': 'a'},
{'type': 'view', 'schema': 'a'},
{'type': 'function', 'schema': 'a'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "abc", "a")]},
{"type": "table", "schema": "a"},
{"type": "view", "schema": "a"},
{"type": "function", "schema": "a"},
]
)
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.id = d.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
])
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.id = d.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
],
)
def test_join_alias_dot_suggests_cols2(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'def', 'd')]},
{'type': 'table', 'schema': 'd'},
{'type': 'view', 'schema': 'd'},
{'type': 'function', 'schema': 'd'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "def", "d")]},
{"type": "table", "schema": "d"},
{"type": "view", "schema": "d"},
{"type": "function", "schema": "d"},
]
)
@pytest.mark.parametrize('sql', [
'select a.x, b.y from abc a join bcd b on ',
'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
])
@pytest.mark.parametrize(
"sql",
[
"select a.x, b.y from abc a join bcd b on ",
"select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
],
)
def test_on_suggests_aliases(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
@pytest.mark.parametrize('sql', [
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
])
@pytest.mark.parametrize(
"sql",
[
"select abc.x, bcd.y from abc join bcd on ",
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
],
)
def test_on_suggests_tables(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
@pytest.mark.parametrize('sql', [
'select a.x, b.y from abc a join bcd b on a.id = ',
'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
])
@pytest.mark.parametrize(
"sql",
[
"select a.x, b.y from abc a join bcd b on a.id = ",
"select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
],
)
def test_on_suggests_aliases_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
@pytest.mark.parametrize('sql', [
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
])
@pytest.mark.parametrize(
"sql",
[
"select abc.x, bcd.y from abc join bcd on ",
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
],
)
def test_on_suggests_tables_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
@pytest.mark.parametrize('col_list', ['', 'col1, '])
@pytest.mark.parametrize("col_list", ["", "col1, "])
def test_join_using_suggests_common_columns(col_list):
text = 'select * from abc inner join def using (' + col_list
assert suggest_type(text, text) == [
{'type': 'column',
'tables': [(None, 'abc', None), (None, 'def', None)],
'drop_unique': True}]
text = "select * from abc inner join def using (" + col_list
assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}]
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.',
])
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.",
],
)
def test_two_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'ghi', 'g')]},
{'type': 'table', 'schema': 'g'},
{'type': 'view', 'schema': 'g'},
{'type': 'function', 'schema': 'g'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "ghi", "g")]},
{"type": "table", "schema": "g"},
{"type": "view", "schema": "g"},
{"type": "function", "schema": "g"},
]
)
def test_2_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ',
'select * from a; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
suggestions = suggest_type('select * from a; select from b',
'select * from a; select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['b']},
{'type': 'column', 'tables': [(None, 'b', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("select * from a; select from b", "select * from a; select ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["b"]},
{"type": "column", "tables": [(None, "b", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
# Should work even if first statement is invalid
suggestions = suggest_type('select * from; select * from ',
'select * from; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from; select * from ", "select * from; select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_2_statements_1st_current():
suggestions = suggest_type('select * from ; select * from b',
'select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from ; select * from b", "select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
suggestions = suggest_type('select from a; select * from b',
'select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['a']},
{'type': 'column', 'tables': [(None, 'a', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("select from a; select * from b", "select ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["a"]},
{"type": "column", "tables": [(None, "a", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_3_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ; select * from c',
'select * from a; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
suggestions = suggest_type('select * from a; select from b; select * from c',
'select * from a; select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['b']},
{'type': 'column', 'tables': [(None, 'b', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["b"]},
{"type": "column", "tables": [(None, "b", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_create_db_with_template():
suggestions = suggest_type('create database foo with template ',
'create database foo with template ')
suggestions = suggest_type("create database foo with template ", "create database foo with template ")
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"])
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
assert sorted_dicts(suggestions) == \
sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}])
def test_specials_not_included_after_initial_token():
suggestions = suggest_type('create table foo (dt d',
'create table foo (dt d')
suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}])
def test_drop_schema_qualified_table_suggests_only_tables():
text = 'DROP TABLE schema_name.table_name'
text = "DROP TABLE schema_name.table_name"
suggestions = suggest_type(text, text)
assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
assert suggestions == [{"type": "table", "schema": "schema_name"}]
@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
@pytest.mark.parametrize("text", [",", " ,", "sel ,"])
def test_handle_pre_completion_comma_gracefully(text):
suggestions = suggest_type(text, text)
@@ -503,53 +537,59 @@ def test_handle_pre_completion_comma_gracefully(text):
def test_cross_join():
text = 'select * from v1 cross join v2 JOIN v1.id, '
text = "select * from v1 cross join v2 JOIN v1.id, "
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
@pytest.mark.parametrize('expression', [
'SELECT 1 AS ',
'SELECT 1 FROM tabl AS ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT 1 AS ",
"SELECT 1 FROM tabl AS ",
],
)
def test_after_as(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set()
@pytest.mark.parametrize('expression', [
'\\. ',
'select 1; \\. ',
'select 1;\\. ',
'select 1 ; \\. ',
'source ',
'truncate table test; source ',
'truncate table test ; source ',
'truncate table test;source ',
])
@pytest.mark.parametrize(
"expression",
[
"\\. ",
"select 1; \\. ",
"select 1;\\. ",
"select 1 ; \\. ",
"source ",
"truncate table test; source ",
"truncate table test ; source ",
"truncate table test;source ",
],
)
def test_source_is_file(expression):
suggestions = suggest_type(expression, expression)
assert suggestions == [{'type': 'file_name'}]
assert suggestions == [{"type": "file_name"}]
@pytest.mark.parametrize("expression", [
"\\f ",
])
@pytest.mark.parametrize(
"expression",
[
"\\f ",
],
)
def test_favorite_name_suggestion(expression):
suggestions = suggest_type(expression, expression)
assert suggestions == [{'type': 'favoritequery'}]
assert suggestions == [{"type": "favoritequery"}]
def test_order_by():
text = 'select * from foo order by '
text = "select * from foo order by "
suggestions = suggest_type(text, text)
assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]
assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}]
def test_quoted_where():
text = "'where i=';"
suggestions = suggest_type(text, text)
assert suggestions == [{'type': 'keyword'}]
assert suggestions == [{"type": "keyword"}]

View file

@@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
@pytest.fixture
def refresher():
from mycli.completion_refresher import CompletionRefresher
return CompletionRefresher()
@@ -18,8 +19,7 @@ def test_ctor(refresher):
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
'special_commands', 'show_commands', 'keywords']
expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"]
assert expected_handlers == actual_handlers
@@ -32,12 +32,12 @@ def test_refresh_called_once(refresher):
callbacks = Mock()
sqlexecute = Mock()
with patch.object(refresher, '_bg_refresh') as bg_refresh:
with patch.object(refresher, "_bg_refresh") as bg_refresh:
actual = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
assert actual[0][3] == 'Auto-completion refresh started in the background.'
assert actual[0][3] == "Auto-completion refresh started in the background."
bg_refresh.assert_called_with(sqlexecute, callbacks, {})
@@ -61,13 +61,13 @@ def test_refresh_called_twice(refresher):
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
assert actual1[0][3] == 'Auto-completion refresh started in the background.'
assert actual1[0][3] == "Auto-completion refresh started in the background."
actual2 = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
assert actual2[0][3] == 'Auto-completion refresh restarted.'
assert actual2[0][3] == "Auto-completion refresh restarted."
def test_refresh_with_callbacks(refresher):
@@ -80,9 +80,9 @@ def test_refresh_with_callbacks(refresher):
sqlexecute_class = Mock()
sqlexecute = Mock()
with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert (callbacks[0].call_count == 1)
assert callbacks[0].call_count == 1

View file

@@ -1,4 +1,5 @@
"""Unit tests for the mycli.config module."""
from io import BytesIO, StringIO, TextIOWrapper
import os
import struct
@@ -6,21 +7,26 @@ import sys
import tempfile
import pytest
from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
read_and_decrypt_mylogin_cnf, read_config_file,
str_to_bool, strip_matching_quotes)
from mycli.config import (
get_mylogin_cnf_path,
open_mylogin_cnf,
read_and_decrypt_mylogin_cnf,
read_config_file,
str_to_bool,
strip_matching_quotes,
)
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
'mylogin.cnf'))
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "mylogin.cnf"))
def open_bmylogin_cnf(name):
"""Open contents of *name* in a BytesIO buffer."""
with open(name, 'rb') as f:
with open(name, "rb") as f:
buf = BytesIO()
buf.write(f.read())
return buf
def test_read_mylogin_cnf():
"""Tests that a login path file can be read and decrypted."""
mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
@@ -28,7 +34,7 @@ def test_read_mylogin_cnf():
assert isinstance(mylogin_cnf, TextIOWrapper)
contents = mylogin_cnf.read()
for word in ('[test]', 'user', 'password', 'host', 'port'):
for word in ("[test]", "user", "password", "host", "port"):
assert word in contents
@@ -46,7 +52,7 @@ def test_corrupted_login_key():
buf.seek(4)
# Write null bytes over half the login key
buf.write(b'\0\0\0\0\0\0\0\0\0\0')
buf.write(b"\0\0\0\0\0\0\0\0\0\0")
buf.seek(0)
mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
@@ -63,58 +69,58 @@ def test_corrupted_pad():
# Skip option group
len_buf = buf.read(4)
cipher_len, = struct.unpack("<i", len_buf)
(cipher_len,) = struct.unpack("<i", len_buf)
buf.read(cipher_len)
# Corrupt the pad for the user line
len_buf = buf.read(4)
cipher_len, = struct.unpack("<i", len_buf)
(cipher_len,) = struct.unpack("<i", len_buf)
buf.read(cipher_len - 1)
buf.write(b'\0')
buf.write(b"\0")
buf.seek(0)
mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
contents = mylogin_cnf.read()
for word in ('[test]', 'password', 'host', 'port'):
for word in ("[test]", "password", "host", "port"):
assert word in contents
assert 'user' not in contents
assert "user" not in contents
def test_get_mylogin_cnf_path():
"""Tests that the path for .mylogin.cnf is detected."""
original_env = None
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
is_windows = sys.platform == 'win32'
if "MYSQL_TEST_LOGIN_FILE" in os.environ:
original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
is_windows = sys.platform == "win32"
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
if login_cnf_path is not None:
assert login_cnf_path.endswith('.mylogin.cnf')
assert login_cnf_path.endswith(".mylogin.cnf")
if is_windows is True:
assert 'MySQL' in login_cnf_path
assert "MySQL" in login_cnf_path
else:
home_dir = os.path.expanduser('~')
home_dir = os.path.expanduser("~")
assert login_cnf_path.startswith(home_dir)
def test_alternate_get_mylogin_cnf_path():
"""Tests that the alternate path for .mylogin.cnf is detected."""
original_env = None
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
if "MYSQL_TEST_LOGIN_FILE" in os.environ:
original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
_, temp_path = tempfile.mkstemp()
os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
os.environ["MYSQL_TEST_LOGIN_FILE"] = temp_path
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
assert temp_path == login_cnf_path
@@ -124,17 +130,17 @@ def test_str_to_bool():
assert str_to_bool(False) is False
assert str_to_bool(True) is True
assert str_to_bool('False') is False
assert str_to_bool('True') is True
assert str_to_bool('TRUE') is True
assert str_to_bool('1') is True
assert str_to_bool('0') is False
assert str_to_bool('on') is True
assert str_to_bool('off') is False
assert str_to_bool('off') is False
assert str_to_bool("False") is False
assert str_to_bool("True") is True
assert str_to_bool("TRUE") is True
assert str_to_bool("1") is True
assert str_to_bool("0") is False
assert str_to_bool("on") is True
assert str_to_bool("off") is False
assert str_to_bool("off") is False
with pytest.raises(ValueError):
str_to_bool('foo')
str_to_bool("foo")
with pytest.raises(TypeError):
str_to_bool(None)
@@ -143,19 +149,19 @@ def test_str_to_bool():
def test_read_config_file_list_values_default():
"""Test that reading a config file uses list_values by default."""
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f)
assert config['main']['weather'] == u"cloudy with a chance of meatballs"
assert config["main"]["weather"] == "cloudy with a chance of meatballs"
def test_read_config_file_list_values_off():
"""Test that you can disable list_values when reading a config file."""
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f, list_values=False)
assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
assert config["main"]["weather"] == "'cloudy with a chance of meatballs'"
def test_strip_quotes_with_matching_quotes():
@@ -177,7 +183,7 @@ def test_strip_quotes_with_unmatching_quotes():
def test_strip_quotes_with_empty_string():
"""Test that an empty string is handled during unquoting."""
assert '' == strip_matching_quotes('')
assert "" == strip_matching_quotes("")
def test_strip_quotes_with_none():

View file

@@ -4,39 +4,32 @@ from mycli.packages.special.utils import format_uptime
def test_u_suggests_databases():
suggestions = suggest_type('\\u ', '\\u ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'database'}])
suggestions = suggest_type("\\u ", "\\u ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
def test_describe_table():
suggestions = suggest_type('\\dt', '\\dt ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("\\dt", "\\dt ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_list_or_show_create_tables():
suggestions = suggest_type('\\dt+', '\\dt+ ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("\\dt+", "\\dt+ ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_format_uptime():
seconds = 59
assert '59 sec' == format_uptime(seconds)
assert "59 sec" == format_uptime(seconds)
seconds = 120
assert '2 min 0 sec' == format_uptime(seconds)
assert "2 min 0 sec" == format_uptime(seconds)
seconds = 54890
assert '15 hours 14 min 50 sec' == format_uptime(seconds)
assert "15 hours 14 min 50 sec" == format_uptime(seconds)
seconds = 598244
assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds)
seconds = 522600
assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)
assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds)

View file

@@ -13,52 +13,62 @@ from textwrap import dedent
from collections import namedtuple
from tempfile import NamedTemporaryFile
from textwrap import dedent
test_dir = os.path.abspath(os.path.dirname(__file__))
project_dir = os.path.dirname(test_dir)
default_config_file = os.path.join(project_dir, 'test', 'myclirc')
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
default_config_file = os.path.join(project_dir, "test", "myclirc")
login_path_file = os.path.join(test_dir, "mylogin.cnf")
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
'--password', PASSWORD, '--myclirc', default_config_file,
'--defaults-file', default_config_file,
'mycli_test_db']
os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file
CLI_ARGS = [
"--user",
USER,
"--host",
HOST,
"--port",
PORT,
"--password",
PASSWORD,
"--myclirc",
default_config_file,
"--defaults-file",
default_config_file,
"mycli_test_db",
]
@dbtest
def test_execute_arg(executor):
run(executor, 'create table test (a text)')
run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql])
assert result.exit_code == 0
assert 'abc' in result.output
assert "abc" in result.output
result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql])
assert result.exit_code == 0
assert 'abc' in result.output
assert "abc" in result.output
expected = 'a\nabc\n'
expected = "a\nabc\n"
assert expected in result.output
@dbtest
def test_execute_arg_with_table(executor):
run(executor, 'create table test (a text)')
run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"])
expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n"
assert result.exit_code == 0
assert expected in result.output
@@ -66,12 +76,12 @@ def test_execute_arg_with_table(executor):
@dbtest
def test_execute_arg_with_csv(executor):
run(executor, 'create table test (a text)')
run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"])
expected = '"a"\n"abc"\n'
assert result.exit_code == 0
@@ -80,35 +90,29 @@ def test_execute_arg_with_csv(executor):
@dbtest
def test_batch_mode(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
sql = (
'select count(*) from test;\n'
'select * from test limit 1;'
)
sql = "select count(*) from test;\n" "select * from test limit 1;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
assert result.exit_code == 0
assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
assert "count(*)\n3\na\nabc\n" in "".join(result.output)
@dbtest
def test_batch_mode_table(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
sql = (
'select count(*) from test;\n'
'select * from test limit 1;'
)
sql = "select count(*) from test;\n" "select * from test limit 1;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql)
expected = (dedent("""\
expected = dedent("""\
+----------+
| count(*) |
+----------+
@@ -118,7 +122,7 @@ def test_batch_mode_table(executor):
| a |
+-----+
| abc |
+-----+"""))
+-----+""")
assert result.exit_code == 0
assert expected in result.output
@@ -126,14 +130,13 @@ def test_batch_mode_table(executor):
@dbtest
def test_batch_mode_csv(executor):
run(executor, '''create table test(a text, b text)''')
run(executor,
'''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
run(executor, """create table test(a text, b text)""")
run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""")
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
@@ -150,15 +153,15 @@ def test_help_strings_end_with_periods():
"""Make sure click options have help text that end with a period."""
for param in cli.params:
if isinstance(param, click.core.Option):
assert hasattr(param, 'help')
assert param.help.endswith('.')
assert hasattr(param, "help")
assert param.help.endswith(".")
def test_command_descriptions_end_with_periods():
"""Make sure that mycli commands' descriptions end with a period."""
MyCli()
for _, command in SPECIAL_COMMANDS.items():
assert command[3].endswith('.')
assert command[3].endswith(".")
def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
@@ -166,23 +169,23 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
clickoutput = ""
m = MyCli(myclirc=default_config_file)
class TestOutput():
class TestOutput:
def get_size(self):
size = namedtuple('Size', 'rows columns')
size = namedtuple("Size", "rows columns")
size.columns, size.rows = terminal_size
return size
class TestExecute():
host = 'test'
user = 'test'
dbname = 'test'
server_info = ServerInfo.from_version_string('unknown')
class TestExecute:
host = "test"
user = "test"
dbname = "test"
server_info = ServerInfo.from_version_string("unknown")
port = 0
def server_type(self):
return ['test']
return ["test"]
class PromptBuffer():
class PromptBuffer:
output = TestOutput()
m.prompt_app = PromptBuffer()
@@ -199,8 +202,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
global clickoutput
clickoutput += s + "\n"
monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
monkeypatch.setattr(click, 'secho', secho)
monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
monkeypatch.setattr(click, "secho", secho)
m.output(testdata)
if clickoutput.endswith("\n"):
clickoutput = clickoutput[:-1]
@@ -208,59 +211,29 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
def test_conditional_pager(monkeypatch):
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
" ")
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ")
# User didn't set pager, output doesn't fit screen -> pager
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=False,
expect_pager=True
)
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True)
# User didn't set pager, output fits screen -> no pager
output(
monkeypatch,
terminal_size=(20, 20),
testdata=testdata,
explicit_pager=False,
expect_pager=False
)
output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False)
# User manually configured pager, output doesn't fit screen -> pager
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=True,
expect_pager=True
)
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True)
# User manually configured pager, output fits screen -> pager
output(
monkeypatch,
terminal_size=(20, 20),
testdata=testdata,
explicit_pager=True,
expect_pager=True
)
output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True)
SPECIAL_COMMANDS['nopager'].handler()
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=False,
expect_pager=False
)
SPECIAL_COMMANDS['pager'].handler('')
SPECIAL_COMMANDS["nopager"].handler()
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False)
SPECIAL_COMMANDS["pager"].handler("")
def test_reserved_space_is_integer(monkeypatch):
"""Make sure that reserved space is returned as an integer."""
def stub_terminal_size():
return (5, 5)
with monkeypatch.context() as m:
m.setattr(shutil, 'get_terminal_size', stub_terminal_size)
m.setattr(shutil, "get_terminal_size", stub_terminal_size)
mycli = MyCli()
assert isinstance(mycli.get_reserved_space(), int)
@@ -268,18 +241,20 @@ def test_reserved_space_is_integer(monkeypatch):
def test_list_dsn():
runner = CliRunner()
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w",delete=False) as myclirc:
myclirc.write(dedent("""\
with NamedTemporaryFile(mode="w", delete=False) as myclirc:
myclirc.write(
dedent("""\
[alias_dsn]
test = mysql://test/test
"""))
""")
)
myclirc.flush()
args = ['--list-dsn', '--myclirc', myclirc.name]
args = ["--list-dsn", "--myclirc", myclirc.name]
result = runner.invoke(cli, args=args)
assert result.output == "test\n"
result = runner.invoke(cli, args=args + ['--verbose'])
result = runner.invoke(cli, args=args + ["--verbose"])
assert result.output == "test : mysql://test/test\n"
# delete=False means we should try to clean up
try:
if os.path.exists(myclirc.name):
@@ -287,41 +262,41 @@ def test_list_dsn():
except Exception as e:
print(f"An error occurred while attempting to delete the file: {e}")
def test_prettify_statement():
statement = 'SELECT 1'
statement = "SELECT 1"
m = MyCli()
pretty_statement = m.handle_prettify_binding(statement)
assert pretty_statement == 'SELECT\n 1;'
assert pretty_statement == "SELECT\n 1;"
def test_unprettify_statement():
statement = 'SELECT\n 1'
statement = "SELECT\n 1"
m = MyCli()
unpretty_statement = m.handle_unprettify_binding(statement)
assert unpretty_statement == 'SELECT 1;'
assert unpretty_statement == "SELECT 1;"
def test_list_ssh_config():
runner = CliRunner()
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
ssh_config.write(dedent("""\
with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
ssh_config.write(
dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
"""))
""")
)
ssh_config.flush()
args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name]
result = runner.invoke(cli, args=args)
assert "test\n" in result.output
result = runner.invoke(cli, args=args + ['--verbose'])
result = runner.invoke(cli, args=args + ["--verbose"])
assert "test : test.example.com\n" in result.output
# delete=False means we should try to clean up
try:
if os.path.exists(ssh_config.name):
@@ -343,7 +318,7 @@ def test_dsn(monkeypatch):
pass
class MockMyCli:
config = {'alias_dsn': {}}
config = {"alias_dsn": {}}
def __init__(self, **args):
self.logger = Logger()
@@ -357,97 +332,109 @@ def test_dsn(monkeypatch):
pass
import mycli.main
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
runner = CliRunner()
# When a user supplies a DSN as database argument to mycli,
# use these values.
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
)
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "dsn_user" and \
MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
MockMyCli.connect_args["host"] == "dsn_host" and \
MockMyCli.connect_args["port"] == 1 and \
MockMyCli.connect_args["database"] == "dsn_database"
assert (
MockMyCli.connect_args["user"] == "dsn_user"
and MockMyCli.connect_args["passwd"] == "dsn_passwd"
and MockMyCli.connect_args["host"] == "dsn_host"
and MockMyCli.connect_args["port"] == 1
and MockMyCli.connect_args["database"] == "dsn_database"
)
MockMyCli.connect_args = None
# When a user supplies a DSN as database argument to mycli,
# and used command line arguments, use the command line
# arguments.
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
"--user", "arg_user",
"--password", "arg_password",
"--host", "arg_host",
"--port", "3",
"--database", "arg_database",
])
result = runner.invoke(
mycli.main.cli,
args=[
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
"--user",
"arg_user",
"--password",
"arg_password",
"--host",
"arg_host",
"--port",
"3",
"--database",
"arg_database",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "arg_user" and \
MockMyCli.connect_args["passwd"] == "arg_password" and \
MockMyCli.connect_args["host"] == "arg_host" and \
MockMyCli.connect_args["port"] == 3 and \
MockMyCli.connect_args["database"] == "arg_database"
assert (
MockMyCli.connect_args["user"] == "arg_user"
and MockMyCli.connect_args["passwd"] == "arg_password"
and MockMyCli.connect_args["host"] == "arg_host"
and MockMyCli.connect_args["port"] == 3
and MockMyCli.connect_args["database"] == "arg_database"
)
MockMyCli.config = {
'alias_dsn': {
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
}
}
MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn),
# use these values.
result = runner.invoke(cli, args=['--dsn', 'test'])
result = runner.invoke(cli, args=["--dsn", "test"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "alias_dsn_user" and \
MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
MockMyCli.connect_args["host"] == "alias_dsn_host" and \
MockMyCli.connect_args["port"] == 4 and \
MockMyCli.connect_args["database"] == "alias_dsn_database"
assert (
MockMyCli.connect_args["user"] == "alias_dsn_user"
and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd"
and MockMyCli.connect_args["host"] == "alias_dsn_host"
and MockMyCli.connect_args["port"] == 4
and MockMyCli.connect_args["database"] == "alias_dsn_database"
)
MockMyCli.config = {
'alias_dsn': {
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
}
}
MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn)
# and used command line arguments, use the command line arguments.
result = runner.invoke(cli, args=[
'--dsn', 'test', '',
"--user", "arg_user",
"--password", "arg_password",
"--host", "arg_host",
"--port", "5",
"--database", "arg_database",
])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "arg_user" and \
MockMyCli.connect_args["passwd"] == "arg_password" and \
MockMyCli.connect_args["host"] == "arg_host" and \
MockMyCli.connect_args["port"] == 5 and \
MockMyCli.connect_args["database"] == "arg_database"
# Use a DSN without password
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user@dsn_host:6/dsn_database"]
result = runner.invoke(
cli,
args=[
"--dsn",
"test",
"",
"--user",
"arg_user",
"--password",
"arg_password",
"--host",
"arg_host",
"--port",
"5",
"--database",
"arg_database",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "dsn_user" and \
MockMyCli.connect_args["passwd"] is None and \
MockMyCli.connect_args["host"] == "dsn_host" and \
MockMyCli.connect_args["port"] == 6 and \
MockMyCli.connect_args["database"] == "dsn_database"
assert (
MockMyCli.connect_args["user"] == "arg_user"
and MockMyCli.connect_args["passwd"] == "arg_password"
and MockMyCli.connect_args["host"] == "arg_host"
and MockMyCli.connect_args["port"] == 5
and MockMyCli.connect_args["database"] == "arg_database"
)
# Use a DSN without password
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user@dsn_host:6/dsn_database"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["user"] == "dsn_user"
and MockMyCli.connect_args["passwd"] is None
and MockMyCli.connect_args["host"] == "dsn_host"
and MockMyCli.connect_args["port"] == 6
and MockMyCli.connect_args["database"] == "dsn_database"
)
def test_ssh_config(monkeypatch):
@@ -463,7 +450,7 @@ def test_ssh_config(monkeypatch):
pass
class MockMyCli:
config = {'alias_dsn': {}}
config = {"alias_dsn": {}}
def __init__(self, **args):
self.logger = Logger()
@@ -477,58 +464,62 @@ def test_ssh_config(monkeypatch):
pass
import mycli.main
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
runner = CliRunner()
# Setup temporary configuration
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
ssh_config.write(dedent("""\
with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
ssh_config.write(
dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
"""))
""")
)
ssh_config.flush()
# When a user supplies a ssh config.
result = runner.invoke(mycli.main.cli, args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test"
])
assert result.exit_code == 0, result.output + \
" " + str(result.exception)
assert \
MockMyCli.connect_args["ssh_user"] == "joe" and \
MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
MockMyCli.connect_args["ssh_port"] == 22222 and \
MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser(
"~") + "/.ssh/gateway"
result = runner.invoke(mycli.main.cli, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["ssh_user"] == "joe"
and MockMyCli.connect_args["ssh_host"] == "test.example.com"
and MockMyCli.connect_args["ssh_port"] == 22222
and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway"
)
# When a user supplies a ssh config host as argument to mycli,
# and used command line arguments, use the command line
# arguments.
result = runner.invoke(mycli.main.cli, args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test",
"--ssh-user", "arg_user",
"--ssh-host", "arg_host",
"--ssh-port", "3",
"--ssh-key-filename", "/path/to/key"
])
assert result.exit_code == 0, result.output + \
" " + str(result.exception)
assert \
MockMyCli.connect_args["ssh_user"] == "arg_user" and \
MockMyCli.connect_args["ssh_host"] == "arg_host" and \
MockMyCli.connect_args["ssh_port"] == 3 and \
MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
result = runner.invoke(
mycli.main.cli,
args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test",
"--ssh-user",
"arg_user",
"--ssh-host",
"arg_host",
"--ssh-port",
"3",
"--ssh-key-filename",
"/path/to/key",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["ssh_user"] == "arg_user"
and MockMyCli.connect_args["ssh_host"] == "arg_host"
and MockMyCli.connect_args["ssh_port"] == 3
and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
)
# delete=False means we should try to clean up
try:
if os.path.exists(ssh_config.name):
@@ -542,9 +533,7 @@ def test_init_command_arg(executor):
init_command = "set sql_select_limit=1000"
sql = 'show variables like "sql_select_limit";'
runner = CliRunner()
result = runner.invoke(
cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
)
result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
expected = "sql_select_limit\t1000\n"
assert result.exit_code == 0
@@ -553,18 +542,13 @@ def test_init_command_arg(executor):
@dbtest
def test_init_command_multiple_arg(executor):
init_command = 'set sql_select_limit=2000; set max_join_size=20000'
sql = (
'show variables like "sql_select_limit";\n'
'show variables like "max_join_size"'
)
init_command = "set sql_select_limit=2000; set max_join_size=20000"
sql = 'show variables like "sql_select_limit";\n' 'show variables like "max_join_size"'
runner = CliRunner()
result = runner.invoke(
cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
)
result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
expected_sql_select_limit = 'sql_select_limit\t2000\n'
expected_max_join_size = 'max_join_size\t20000\n'
expected_sql_select_limit = "sql_select_limit\t2000\n"
expected_max_join_size = "max_join_size\t20000\n"
assert result.exit_code == 0
assert expected_sql_select_limit in result.output

View file

@@ -6,56 +6,48 @@ from prompt_toolkit.document import Document
@pytest.fixture
def completer():
import mycli.sqlcompleter as sqlcompleter
return sqlcompleter.SQLCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from unittest.mock import Mock
return Mock()
def test_empty_string_completion(completer, complete_event):
text = ''
text = ""
position = 0
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(map(Completion, completer.all_completions))
def test_select_keyword_completion(completer, complete_event):
text = 'SEL'
position = len('SEL')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([Completion(text='SELECT', start_position=-3)])
text = "SEL"
position = len("SEL")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list([Completion(text="SELECT", start_position=-3)])
def test_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "SELECT MA"
position = len("SELECT MA")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert sorted(x.text for x in result) == ["MASTER", "MAX"]
def test_column_name_completion(completer, complete_event):
text = 'SELECT FROM users'
position = len('SELECT ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "SELECT FROM users"
position = len("SELECT ")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(map(Completion, completer.all_completions))
def test_special_name_completion(completer, complete_event):
text = '\\'
position = len('\\')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "\\"
position = len("\\")
result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
# Special commands will NOT be suggested during naive completion mode.
assert result == set()

View file

@@ -1,67 +1,72 @@
import pytest
from mycli.packages.parseutils import (
extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
is_dropping_database)
extract_tables,
query_starts_with,
queries_start_with,
is_destructive,
query_has_where_clause,
is_dropping_database,
)
def test_empty_string():
tables = extract_tables('')
tables = extract_tables("")
assert tables == []
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
assert tables == [(None, 'abc', None)]
tables = extract_tables("select * from abc")
assert tables == [(None, "abc", None)]
def test_simple_select_single_table_schema_qualified():
tables = extract_tables('select * from abc.def')
assert tables == [('abc', 'def', None)]
tables = extract_tables("select * from abc.def")
assert tables == [("abc", "def", None)]
def test_simple_select_multiple_tables():
tables = extract_tables('select * from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
tables = extract_tables("select * from abc, def")
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables('select * from abc.def, ghi.jkl')
assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
tables = extract_tables("select * from abc.def, ghi.jkl")
assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)]
def test_simple_select_with_cols_single_table():
tables = extract_tables('select a,b from abc')
assert tables == [(None, 'abc', None)]
tables = extract_tables("select a,b from abc")
assert tables == [(None, "abc", None)]
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables('select a,b from abc.def')
assert tables == [('abc', 'def', None)]
tables = extract_tables("select a,b from abc.def")
assert tables == [("abc", "def", None)]
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables('select a,b from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
tables = extract_tables("select a,b from abc, def")
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_simple_select_with_cols_multiple_tables_with_schema():
tables = extract_tables('select a,b from abc.def, def.ghi')
assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
tables = extract_tables("select a,b from abc.def, def.ghi")
assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)]
def test_select_with_hanging_comma_single_table():
tables = extract_tables('select a, from abc')
assert tables == [(None, 'abc', None)]
tables = extract_tables("select a, from abc")
assert tables == [(None, "abc", None)]
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables('select a, from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
tables = extract_tables("select a, from abc, def")
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")]
def test_simple_insert_single_table():
@@ -69,97 +74,80 @@ def test_simple_insert_single_table():
# sqlparse mistakenly assigns an alias to the table
# assert tables == [(None, 'abc', None)]
assert tables == [(None, 'abc', 'abc')]
assert tables == [(None, "abc", "abc")]
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == [('abc', 'def', None)]
assert tables == [("abc", "def", None)]
def test_simple_update_table():
tables = extract_tables('update abc set id = 1')
assert tables == [(None, 'abc', None)]
tables = extract_tables("update abc set id = 1")
assert tables == [(None, "abc", None)]
def test_simple_update_table_with_schema():
tables = extract_tables('update abc.def set id = 1')
assert tables == [('abc', 'def', None)]
tables = extract_tables("update abc.def set id = 1")
assert tables == [("abc", "def", None)]
def test_join_table():
tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num")
assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")]
def test_join_table_schema_qualified():
tables = extract_tables(
'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")]
def test_join_as_table():
tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
assert tables == [(None, 'my_table', 'm')]
tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
assert tables == [(None, "my_table", "m")]
def test_query_starts_with():
query = 'USE test;'
assert query_starts_with(query, ('use', )) is True
query = "USE test;"
assert query_starts_with(query, ("use",)) is True
query = 'DROP DATABASE test;'
assert query_starts_with(query, ('use', )) is False
query = "DROP DATABASE test;"
assert query_starts_with(query, ("use",)) is False
def test_query_starts_with_comment():
query = '# comment\nUSE test;'
assert query_starts_with(query, ('use', )) is True
query = "# comment\nUSE test;"
assert query_starts_with(query, ("use",)) is True
def test_queries_start_with():
sql = (
'# comment\n'
'show databases;'
'use foo;'
)
assert queries_start_with(sql, ('show', 'select')) is True
assert queries_start_with(sql, ('use', 'drop')) is True
assert queries_start_with(sql, ('delete', 'update')) is False
sql = "# comment\n" "show databases;" "use foo;"
assert queries_start_with(sql, ("show", "select")) is True
assert queries_start_with(sql, ("use", "drop")) is True
assert queries_start_with(sql, ("delete", "update")) is False
def test_is_destructive():
sql = (
'use test;\n'
'show databases;\n'
'drop database foo;'
)
sql = "use test;\n" "show databases;\n" "drop database foo;"
assert is_destructive(sql) is True
def test_is_destructive_update_with_where_clause():
sql = (
'use test;\n'
'show databases;\n'
'UPDATE test SET x = 1 WHERE id = 1;'
)
sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1 WHERE id = 1;"
assert is_destructive(sql) is False
def test_is_destructive_update_without_where_clause():
sql = (
'use test;\n'
'show databases;\n'
'UPDATE test SET x = 1;'
)
sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1;"
assert is_destructive(sql) is True
@pytest.mark.parametrize(
('sql', 'has_where_clause'),
("sql", "has_where_clause"),
[
('update test set dummy = 1;', False),
('update test set dummy = 1 where id = 1);', True),
("update test set dummy = 1;", False),
("update test set dummy = 1 where id = 1);", True),
],
)
def test_query_has_where_clause(sql, has_where_clause):
@@ -167,24 +155,20 @@ def test_query_has_where_clause(sql, has_where_clause):
@pytest.mark.parametrize(
('sql', 'dbname', 'is_dropping'),
("sql", "dbname", "is_dropping"),
[
('select bar from foo', 'foo', False),
('drop database "foo";', '`foo`', True),
('drop schema foo', 'foo', True),
('drop schema foo', 'bar', False),
('drop database bar', 'foo', False),
('drop database foo', None, False),
('drop database foo; create database foo', 'foo', False),
('drop database foo; create database bar', 'foo', True),
('select bar from foo; drop database bazz', 'foo', False),
('select bar from foo; drop database bazz', 'bazz', True),
('-- dropping database \n '
'drop -- really dropping \n '
'schema abc -- now it is dropped',
'abc',
True)
]
("select bar from foo", "foo", False),
('drop database "foo";', "`foo`", True),
("drop schema foo", "foo", True),
("drop schema foo", "bar", False),
("drop database bar", "foo", False),
("drop database foo", None, False),
("drop database foo; create database foo", "foo", False),
("drop database foo; create database bar", "foo", True),
("select bar from foo; drop database bazz", "foo", False),
("select bar from foo; drop database bazz", "bazz", True),
("-- dropping database \n " "drop -- really dropping \n " "schema abc -- now it is dropped", "abc", True),
],
)
def test_is_dropping_database(sql, dbname, is_dropping):
assert is_dropping_database(sql, dbname) == is_dropping

View file

@@ -4,8 +4,8 @@ from mycli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
stdin = click.get_text_stream('stdin')
stdin = click.get_text_stream("stdin")
assert stdin.isatty() is False
sql = 'drop database foo;'
sql = "drop database foo;"
assert confirm_destructive_query(sql) is None

View file

@@ -43,49 +43,35 @@ def complete_event():
def test_special_name_completion(completer, complete_event):
text = "\\d"
position = len("\\d")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert result == [Completion(text="\\dt", start_position=-2)]
def test_empty_string_completion(completer, complete_event):
text = ""
position = 0
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert (
list(map(Completion, completer.keywords + completer.special_commands)) == result
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert list(map(Completion, completer.keywords + completer.special_commands)) == result
def test_select_keyword_completion(completer, complete_event):
text = "SEL"
position = len("SEL")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list([Completion(text="SELECT", start_position=-3)])
def test_select_star(completer, complete_event):
text = "SELECT * "
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(map(Completion, completer.keywords))
def test_table_completion(completer, complete_event):
text = "SELECT * FROM "
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="users", start_position=0),
@@ -99,9 +85,7 @@ def test_table_completion(completer, complete_event):
def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="MAX", start_position=-2),
@ -127,11 +111,7 @@ def test_suggested_column_names(completer, complete_event):
"""
text = "SELECT from users"
position = len("SELECT ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -157,9 +137,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
"""
text = "SELECT MAX( from users"
position = len("SELECT MAX(")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="*", start_position=0),
@ -181,11 +159,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event):
"""
text = "SELECT users. from users"
position = len("SELECT users.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -207,11 +181,7 @@ def test_suggested_column_names_with_alias(completer, complete_event):
"""
text = "SELECT u. from users u"
position = len("SELECT u.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -234,11 +204,7 @@ def test_suggested_multiple_column_names(completer, complete_event):
"""
text = "SELECT id, from users u"
position = len("SELECT id, ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -264,11 +230,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
"""
text = "SELECT u.id, u. from users u"
position = len("SELECT u.id, u.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -291,11 +253,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
"""
text = "SELECT users.id, users. from users u"
position = len("SELECT users.id, users.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -310,11 +268,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
def test_suggested_aliases_after_on(completer, complete_event):
text = "SELECT u.name, o.id FROM users u JOIN orders o ON "
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="u", start_position=0),
@ -326,11 +280,7 @@ def test_suggested_aliases_after_on(completer, complete_event):
def test_suggested_aliases_after_on_right_side(completer, complete_event):
text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = "
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="u", start_position=0),
@ -342,11 +292,7 @@ def test_suggested_aliases_after_on_right_side(completer, complete_event):
def test_suggested_tables_after_on(completer, complete_event):
text = "SELECT users.name, orders.id FROM users JOIN orders ON "
position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@ -357,14 +303,8 @@ def test_suggested_tables_after_on(completer, complete_event):
def test_suggested_tables_after_on_right_side(completer, complete_event):
text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
position = len(
"SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
)
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@ -376,11 +316,7 @@ def test_suggested_tables_after_on_right_side(completer, complete_event):
def test_table_names_after_from(completer, complete_event):
text = "SELECT * FROM "
position = len("SELECT * FROM ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@ -394,29 +330,21 @@ def test_table_names_after_from(completer, complete_event):
def test_auto_escaped_col_names(completer, complete_event):
text = "SELECT from `select`"
position = len("SELECT ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text="*", start_position=0),
Completion(text="id", start_position=0),
Completion(text="`insert`", start_position=0),
Completion(text="`ABC`", start_position=0),
] + list(map(Completion, completer.functions)) + [
Completion(text="select", start_position=0)
] + list(map(Completion, completer.keywords))
] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list(
map(Completion, completer.keywords)
)
def test_un_escaped_table_names(completer, complete_event):
text = "SELECT from réveillé"
position = len("SELECT ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -464,10 +392,6 @@ def dummy_list_path(dir_name):
)
def test_file_name_completion(completer, complete_event, text, expected):
position = len(text)
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
expected = list((Completion(txt, pos) for txt, pos in expected))
assert result == expected
@ -17,11 +17,11 @@ def test_set_get_pager():
assert mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager_enabled(False)
assert not mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager('less')
assert os.environ['PAGER'] == "less"
mycli.packages.special.set_pager("less")
assert os.environ["PAGER"] == "less"
mycli.packages.special.set_pager(False)
assert os.environ['PAGER'] == "less"
del os.environ['PAGER']
assert os.environ["PAGER"] == "less"
del os.environ["PAGER"]
mycli.packages.special.set_pager(False)
mycli.packages.special.disable_pager()
assert not mycli.packages.special.is_pager_enabled()
@ -42,45 +42,44 @@ def test_set_get_expanded_output():
def test_editor_command():
assert mycli.packages.special.editor_command(r'hello\e')
assert mycli.packages.special.editor_command(r'\ehello')
assert not mycli.packages.special.editor_command(r'hello')
assert mycli.packages.special.editor_command(r"hello\e")
assert mycli.packages.special.editor_command(r"\ehello")
assert not mycli.packages.special.editor_command(r"hello")
assert mycli.packages.special.get_filename(r'\e filename') == "filename"
assert mycli.packages.special.get_filename(r"\e filename") == "filename"
os.environ['EDITOR'] = 'true'
os.environ['VISUAL'] = 'true'
os.environ["EDITOR"] = "true"
os.environ["VISUAL"] = "true"
# Set the editor to Notepad on Windows
if os.name != 'nt':
mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
if os.name != "nt":
mycli.packages.special.open_external_editor(sql=r"select 1") == "select 1"
else:
pytest.skip('Skipping on Windows platform.')
pytest.skip("Skipping on Windows platform.")
def test_tee_command():
mycli.packages.special.write_tee(u"hello world") # write without file set
mycli.packages.special.write_tee("hello world") # write without file set
# keep Windows from locking the file with delete=False
with tempfile.NamedTemporaryFile(delete=False) as f:
mycli.packages.special.execute(None, u"tee " + f.name)
mycli.packages.special.write_tee(u"hello world")
if os.name=='nt':
mycli.packages.special.execute(None, "tee " + f.name)
mycli.packages.special.write_tee("hello world")
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"tee -o " + f.name)
mycli.packages.special.write_tee(u"hello world")
mycli.packages.special.execute(None, "tee -o " + f.name)
mycli.packages.special.write_tee("hello world")
f.seek(0)
if os.name=='nt':
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"notee")
mycli.packages.special.write_tee(u"hello world")
mycli.packages.special.execute(None, "notee")
mycli.packages.special.write_tee("hello world")
f.seek(0)
if os.name=='nt':
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
@ -92,52 +91,49 @@ def test_tee_command():
os.remove(f.name)
except Exception as e:
print(f"An error occurred while attempting to delete the file: {e}")
def test_tee_command_error():
with pytest.raises(TypeError):
mycli.packages.special.execute(None, 'tee')
mycli.packages.special.execute(None, "tee")
with pytest.raises(OSError):
with tempfile.NamedTemporaryFile() as f:
os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
mycli.packages.special.execute(None, 'tee {}'.format(f.name))
mycli.packages.special.execute(None, "tee {}".format(f.name))
@dbtest
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
def test_favorite_query():
with db_connection().cursor() as cur:
query = u'select ""'
mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
assert next(mycli.packages.special.execute(
cur, u'\\f check'))[0] == "> " + query
query = 'select ""'
mycli.packages.special.execute(cur, "\\fs check {0}".format(query))
assert next(mycli.packages.special.execute(cur, "\\f check"))[0] == "> " + query
def test_once_command():
with pytest.raises(TypeError):
mycli.packages.special.execute(None, u"\\once")
mycli.packages.special.execute(None, "\\once")
with pytest.raises(OSError):
mycli.packages.special.execute(None, u"\\once /proc/access-denied")
mycli.packages.special.execute(None, "\\once /proc/access-denied")
mycli.packages.special.write_once(u"hello world") # write without file set
mycli.packages.special.write_once("hello world") # write without file set
# keep Windows from locking the file with delete=False
with tempfile.NamedTemporaryFile(delete=False) as f:
mycli.packages.special.execute(None, u"\\once " + f.name)
mycli.packages.special.write_once(u"hello world")
if os.name=='nt':
mycli.packages.special.execute(None, "\\once " + f.name)
mycli.packages.special.write_once("hello world")
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"\\once -o " + f.name)
mycli.packages.special.write_once(u"hello world line 1")
mycli.packages.special.write_once(u"hello world line 2")
mycli.packages.special.execute(None, "\\once -o " + f.name)
mycli.packages.special.write_once("hello world line 1")
mycli.packages.special.write_once("hello world line 2")
f.seek(0)
if os.name=='nt':
if os.name == "nt":
assert f.read() == b"hello world line 1\r\nhello world line 2\r\n"
else:
assert f.read() == b"hello world line 1\nhello world line 2\n"
@ -151,52 +147,47 @@ def test_once_command():
def test_pipe_once_command():
with pytest.raises(IOError):
mycli.packages.special.execute(None, u"\\pipe_once")
mycli.packages.special.execute(None, "\\pipe_once")
with pytest.raises(OSError):
mycli.packages.special.execute(
None, u"\\pipe_once /proc/access-denied")
mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied")
if os.name == 'nt':
mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
mycli.packages.special.write_once(u"hello world")
if os.name == "nt":
mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
mycli.packages.special.write_once("hello world")
mycli.packages.special.unset_pipe_once_if_written()
else:
mycli.packages.special.execute(None, u"\\pipe_once wc")
mycli.packages.special.write_once(u"hello world")
mycli.packages.special.unset_pipe_once_if_written()
# how to assert on wc output?
with tempfile.NamedTemporaryFile() as f:
mycli.packages.special.execute(None, "\\pipe_once tee " + f.name)
mycli.packages.special.write_pipe_once("hello world")
mycli.packages.special.unset_pipe_once_if_written()
f.seek(0)
assert f.read() == b"hello world\n"
def test_parseargfile():
"""Test that parseargfile expands the user directory."""
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
'mode': 'a'}
if os.name=='nt':
assert expected == mycli.packages.special.iocommands.parseargfile(
'~\\filename')
else:
assert expected == mycli.packages.special.iocommands.parseargfile(
'~/filename')
expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "a"}
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
'mode': 'w'}
if os.name=='nt':
assert expected == mycli.packages.special.iocommands.parseargfile(
'-o ~\\filename')
if os.name == "nt":
assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename")
else:
assert expected == mycli.packages.special.iocommands.parseargfile(
'-o ~/filename')
assert expected == mycli.packages.special.iocommands.parseargfile("~/filename")
expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "w"}
if os.name == "nt":
assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename")
else:
assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename")
def test_parseargfile_no_file():
"""Test that parseargfile raises a TypeError if there is no filename."""
with pytest.raises(TypeError):
mycli.packages.special.iocommands.parseargfile('')
mycli.packages.special.iocommands.parseargfile("")
with pytest.raises(TypeError):
mycli.packages.special.iocommands.parseargfile('-o ')
mycli.packages.special.iocommands.parseargfile("-o ")
@dbtest
@ -205,11 +196,9 @@ def test_watch_query_iteration():
the desired query and returns the given results."""
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
expected_title = '> {0!s}'.format(query)
expected_title = "> {0!s}".format(query)
with db_connection().cursor() as cur:
result = next(mycli.packages.special.iocommands.watch_query(
arg=query, cur=cur
))
result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur))
assert result[0] == expected_title
assert result[2][0] == expected_value
@ -230,14 +219,12 @@ def test_watch_query_full():
wait_interval = 1
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
expected_title = '> {0!s}'.format(query)
expected_title = "> {0!s}".format(query)
expected_results = 4
ctrl_c_process = send_ctrl_c(wait_interval)
with db_connection().cursor() as cur:
results = list(
result for result in mycli.packages.special.iocommands.watch_query(
arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
)
result for result in mycli.packages.special.iocommands.watch_query(arg="{0!s} {1!s}".format(watch_seconds, query), cur=cur)
)
ctrl_c_process.join(1)
assert len(results) == expected_results
@ -247,14 +234,12 @@ def test_watch_query_full():
@dbtest
@patch('click.clear')
@patch("click.clear")
def test_watch_query_clear(clear_mock):
"""Test that the screen is cleared with the -c flag of `watch` command
before execute the query."""
with db_connection().cursor() as cur:
watch_gen = mycli.packages.special.iocommands.watch_query(
arg='0.1 -c select 1;', cur=cur
)
watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur)
assert not clear_mock.called
next(watch_gen)
assert clear_mock.called
@ -271,19 +256,20 @@ def test_watch_query_bad_arguments():
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
with pytest.raises(ProgrammingError):
next(watch_query('a select 1;', cur=cur))
next(watch_query("a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('-a select 1;', cur=cur))
next(watch_query("-a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('1 -a select 1;', cur=cur))
next(watch_query("1 -a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('-c -a select 1;', cur=cur))
next(watch_query("-c -a select 1;", cur=cur))
@dbtest
@patch('click.clear')
@patch("click.clear")
def test_watch_query_interval_clear(clear_mock):
"""Test `watch` command with interval and clear flag."""
def test_asserts(gen):
clear_mock.reset_mock()
start = time()
@ -296,46 +282,32 @@ def test_watch_query_interval_clear(clear_mock):
seconds = 1.0
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
cur=cur))
test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
cur=cur))
test_asserts(watch_query("{0!s} -c select 1;".format(seconds), cur=cur))
test_asserts(watch_query("-c {0!s} select 1;".format(seconds), cur=cur))
def test_split_sql_by_delimiter():
for delimiter_str in (';', '$', '😀'):
for delimiter_str in (";", "$", "😀"):
mycli.packages.special.set_delimiter(delimiter_str)
sql_input = "select 1{} select \ufffc2".format(delimiter_str)
queries = (
"select 1",
"select \ufffc2"
)
for query, parsed_query in zip(
queries, mycli.packages.special.split_queries(sql_input)):
assert(query == parsed_query)
queries = ("select 1", "select \ufffc2")
for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
assert query == parsed_query
def test_switch_delimiter_within_query():
mycli.packages.special.set_delimiter(';')
mycli.packages.special.set_delimiter(";")
sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
queries = (
"select 1",
"delimiter $$ select 2 $$ select 3 $$",
"select 2",
"select 3"
)
for query, parsed_query in zip(
queries,
mycli.packages.special.split_queries(sql_input)):
assert(query == parsed_query)
queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$", "select 2", "select 3")
for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
assert query == parsed_query
def test_set_delimiter():
for delim in ('foo', 'bar'):
for delim in ("foo", "bar"):
mycli.packages.special.set_delimiter(delim)
assert mycli.packages.special.get_current_delimiter() == delim
def teardown_function():
mycli.packages.special.set_delimiter(';')
mycli.packages.special.set_delimiter(";")
@ -7,14 +7,11 @@ from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output
def assert_result_equal(result, title=None, rows=None, headers=None,
status=None, auto_status=True, assert_contains=False):
def assert_result_equal(result, title=None, rows=None, headers=None, status=None, auto_status=True, assert_contains=False):
"""Assert that an sqlexecute.run() result matches the expected values."""
if status is None and auto_status and rows:
status = '{} row{} in set'.format(
len(rows), 's' if len(rows) > 1 else '')
fields = {'title': title, 'rows': rows, 'headers': headers,
'status': status}
status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "")
fields = {"title": title, "rows": rows, "headers": headers, "status": status}
if assert_contains:
# Do a loose match on the results using the *in* operator.
@ -28,34 +25,35 @@ def assert_result_equal(result, title=None, rows=None, headers=None,
@dbtest
def test_conn(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
results = run(executor, '''select * from test''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
results = run(executor, """select * from test""")
assert_result_equal(results, headers=['a'], rows=[('abc',)])
assert_result_equal(results, headers=["a"], rows=[("abc",)])
@dbtest
def test_bools(executor):
run(executor, '''create table test(a boolean)''')
run(executor, '''insert into test values(True)''')
results = run(executor, '''select * from test''')
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
results = run(executor, """select * from test""")
assert_result_equal(results, headers=['a'], rows=[(1,)])
assert_result_equal(results, headers=["a"], rows=[(1,)])
@dbtest
def test_binary(executor):
run(executor, '''create table bt(geom linestring NOT NULL)''')
run(executor, "INSERT INTO bt VALUES "
"(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
results = run(executor, '''select * from bt''')
run(executor, """create table bt(geom linestring NOT NULL)""")
run(executor, "INSERT INTO bt VALUES " "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
results = run(executor, """select * from bt""")
geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
b'\xac\xdeC@')
geom = (
b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n"
b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9"
b"\xac\xdeC@"
)
assert_result_equal(results, headers=['geom'], rows=[(geom,)])
assert_result_equal(results, headers=["geom"], rows=[(geom,)])
@dbtest
@ -63,49 +61,48 @@ def test_table_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
assert set(executor.tables()) == set([('a',), ('b',)])
assert set(executor.table_columns()) == set(
[('a', 'x'), ('a', 'y'), ('b', 'z')])
assert set(executor.tables()) == set([("a",), ("b",)])
assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")])
@dbtest
def test_database_list(executor):
databases = executor.databases()
assert 'mycli_test_db' in databases
assert "mycli_test_db" in databases
@dbtest
def test_invalid_syntax(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, 'invalid syntax!')
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
run(executor, "invalid syntax!")
assert "You have an error in your SQL syntax;" in str(excinfo.value)
@dbtest
def test_invalid_column_name(executor):
with pytest.raises(pymysql.err.OperationalError) as excinfo:
run(executor, 'select invalid command')
run(executor, "select invalid command")
assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
@dbtest
def test_unicode_support_in_output(executor):
run(executor, "create table unicodechars(t text)")
run(executor, u"insert into unicodechars (t) values ('é')")
run(executor, "insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
results = run(executor, u"select * from unicodechars")
assert_result_equal(results, headers=['t'], rows=[(u'é',)])
results = run(executor, "select * from unicodechars")
assert_result_equal(results, headers=["t"], rows=[("é",)])
@dbtest
def test_multiple_queries_same_line(executor):
results = run(executor, "select 'foo'; select 'bar'")
expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
'status': '1 row in set'},
{'title': None, 'headers': ['bar'], 'rows': [('bar',)],
'status': '1 row in set'}]
expected = [
{"title": None, "headers": ["foo"], "rows": [("foo",)], "status": "1 row in set"},
{"title": None, "headers": ["bar"], "rows": [("bar",)], "status": "1 row in set"},
]
assert expected == results
@ -113,7 +110,7 @@ def test_multiple_queries_same_line(executor):
def test_multiple_queries_same_line_syntaxerror(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, "select 'foo'; invalid syntax")
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
assert "You have an error in your SQL syntax;" in str(excinfo.value)
@dbtest
@ -125,15 +122,13 @@ def test_favorite_query(executor):
run(executor, "insert into test values('def')")
results = run(executor, "\\fs test-a select * from test where a like 'a%'")
assert_result_equal(results, status='Saved.')
assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-a")
assert_result_equal(results,
title="> select * from test where a like 'a%'",
headers=['a'], rows=[('abc',)], auto_status=False)
assert_result_equal(results, title="> select * from test where a like 'a%'", headers=["a"], rows=[("abc",)], auto_status=False)
results = run(executor, "\\fd test-a")
assert_result_equal(results, status='test-a: Deleted')
assert_result_equal(results, status="test-a: Deleted")
@dbtest
@ -144,158 +139,147 @@ def test_favorite_query_multiple_statement(executor):
run(executor, "insert into test values('abc')")
run(executor, "insert into test values('def')")
results = run(executor,
"\\fs test-ad select * from test where a like 'a%'; "
"select * from test where a like 'd%'")
assert_result_equal(results, status='Saved.')
results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " "select * from test where a like 'd%'")
assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-ad")
expected = [{'title': "> select * from test where a like 'a%'",
'headers': ['a'], 'rows': [('abc',)], 'status': None},
{'title': "> select * from test where a like 'd%'",
'headers': ['a'], 'rows': [('def',)], 'status': None}]
expected = [
{"title": "> select * from test where a like 'a%'", "headers": ["a"], "rows": [("abc",)], "status": None},
{"title": "> select * from test where a like 'd%'", "headers": ["a"], "rows": [("def",)], "status": None},
]
assert expected == results
results = run(executor, "\\fd test-ad")
assert_result_equal(results, status='test-ad: Deleted')
assert_result_equal(results, status="test-ad: Deleted")
@dbtest
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
def test_favorite_query_expanded_output(executor):
set_expanded_output(False)
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
results = run(executor, "\\fs test-ae select * from test")
assert_result_equal(results, status='Saved.')
assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-ae \\G")
assert is_expanded_output() is True
assert_result_equal(results, title='> select * from test',
headers=['a'], rows=[('abc',)], auto_status=False)
assert_result_equal(results, title="> select * from test", headers=["a"], rows=[("abc",)], auto_status=False)
set_expanded_output(False)
results = run(executor, "\\fd test-ae")
assert_result_equal(results, status='test-ae: Deleted')
assert_result_equal(results, status="test-ae: Deleted")
@dbtest
def test_special_command(executor):
results = run(executor, '\\?')
assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
headers='Command', assert_contains=True,
auto_status=False)
results = run(executor, "\\?")
assert_result_equal(results, rows=("quit", "\\q", "Quit."), headers="Command", assert_contains=True, auto_status=False)
@dbtest
def test_cd_command_without_a_folder_name(executor):
results = run(executor, 'system cd')
assert_result_equal(results, status='No folder name was provided.')
results = run(executor, "system cd")
assert_result_equal(results, status="No folder name was provided.")
@dbtest
def test_system_command_not_found(executor):
results = run(executor, 'system xyz')
if os.name=='nt':
assert_result_equal(results, status='OSError: The system cannot find the file specified',
assert_contains=True)
results = run(executor, "system xyz")
if os.name == "nt":
assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True)
else:
assert_result_equal(results, status='OSError: No such file or directory',
assert_contains=True)
assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True)
@dbtest
def test_system_command_output(executor):
eol = os.linesep
test_dir = os.path.abspath(os.path.dirname(__file__))
test_file_path = os.path.join(test_dir, 'test.txt')
results = run(executor, 'system cat {0}'.format(test_file_path))
assert_result_equal(results, status=f'mycli rocks!{eol}')
test_file_path = os.path.join(test_dir, "test.txt")
results = run(executor, "system cat {0}".format(test_file_path))
assert_result_equal(results, status=f"mycli rocks!{eol}")
@dbtest
def test_cd_command_current_dir(executor):
test_path = os.path.abspath(os.path.dirname(__file__))
run(executor, 'system cd {0}'.format(test_path))
run(executor, "system cd {0}".format(test_path))
assert os.getcwd() == test_path
@dbtest
def test_unicode_support(executor):
results = run(executor, u"SELECT '日本語' AS japanese;")
assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
results = run(executor, "SELECT '日本語' AS japanese;")
assert_result_equal(results, headers=["japanese"], rows=[("日本語",)])
@dbtest
def test_timestamp_null(executor):
run(executor, '''create table ts_null(a timestamp null)''')
run(executor, '''insert into ts_null values(null)''')
results = run(executor, '''select * from ts_null''')
assert_result_equal(results, headers=['a'],
rows=[(None,)])
run(executor, """create table ts_null(a timestamp null)""")
run(executor, """insert into ts_null values(null)""")
results = run(executor, """select * from ts_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_datetime_null(executor):
run(executor, '''create table dt_null(a datetime null)''')
run(executor, '''insert into dt_null values(null)''')
results = run(executor, '''select * from dt_null''')
assert_result_equal(results, headers=['a'],
rows=[(None,)])
run(executor, """create table dt_null(a datetime null)""")
run(executor, """insert into dt_null values(null)""")
results = run(executor, """select * from dt_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_date_null(executor):
run(executor, '''create table date_null(a date null)''')
run(executor, '''insert into date_null values(null)''')
results = run(executor, '''select * from date_null''')
assert_result_equal(results, headers=['a'], rows=[(None,)])
run(executor, """create table date_null(a date null)""")
run(executor, """insert into date_null values(null)""")
results = run(executor, """select * from date_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_time_null(executor):
run(executor, '''create table time_null(a time null)''')
run(executor, '''insert into time_null values(null)''')
results = run(executor, '''select * from time_null''')
assert_result_equal(results, headers=['a'], rows=[(None,)])
run(executor, """create table time_null(a time null)""")
run(executor, """insert into time_null values(null)""")
results = run(executor, """select * from time_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_multiple_results(executor):
query = '''CREATE PROCEDURE dmtest()
query = """CREATE PROCEDURE dmtest()
BEGIN
SELECT 1;
SELECT 2;
END'''
END"""
executor.conn.cursor().execute(query)
results = run(executor, 'call dmtest;')
results = run(executor, "call dmtest;")
expected = [
{'title': None, 'rows': [(1,)], 'headers': ['1'],
'status': '1 row in set'},
{'title': None, 'rows': [(2,)], 'headers': ['2'],
'status': '1 row in set'}
{"title": None, "rows": [(1,)], "headers": ["1"], "status": "1 row in set"},
{"title": None, "rows": [(2,)], "headers": ["2"], "status": "1 row in set"},
]
assert results == expected
@pytest.mark.parametrize(
'version_string, species, parsed_version_string, version',
"version_string, species, parsed_version_string, version",
(
('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100),
('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200),
('5.7.32-35', 'Percona', '5.7.32', 50732),
('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
('unexpected version string', None, '', 0),
('', None, '', 0),
(None, None, '', 0),
)
("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100),
("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200),
("5.7.32-35", "Percona", "5.7.32", 50732),
("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732),
("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016),
("5.1.5a-alpha", "MySQL", "5.1.5", 50105),
("unexpected version string", None, "", 0),
("", None, "", 0),
(None, None, "", 0),
),
)
def test_version_parsing(version_string, species, parsed_version_string, version):
server_info = ServerInfo.from_version_string(version_string)
@ -2,8 +2,6 @@
from textwrap import dedent
from mycli.packages.tabular_output import sql_format
from cli_helpers.tabular_output import TabularOutputFormatter
from .utils import USER, PASSWORD, HOST, PORT, dbtest
@ -23,20 +21,17 @@ def mycli():
@dbtest
def test_sql_output(mycli):
"""Test the sql output adapter."""
headers = ['letters', 'number', 'optional', 'float', 'binary']
headers = ["letters", "number", "optional", "float", "binary"]
class FakeCursor(object):
def __init__(self):
self.data = [
('abc', 1, None, 10.0, b'\xAA'),
('d', 456, '1', 0.5, b'\xAA\xBB')
]
self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")]
self.description = [
(None, FIELD_TYPE.VARCHAR),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.FLOAT),
(None, FIELD_TYPE.BLOB)
(None, FIELD_TYPE.BLOB),
]
def __iter__(self):
@ -52,12 +47,11 @@ def test_sql_output(mycli):
return self.description
# Test sql-update output format
assert list(mycli.change_table_format("sql-update")) == \
[(None, None, None, 'Changed table format to sql-update')]
assert list(mycli.change_table_format("sql-update")) == [(None, None, None, "Changed table format to sql-update")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
actual = "\n".join(output)
assert actual == dedent('''\
assert actual == dedent("""\
UPDATE `DUAL` SET
`number` = 1
, `optional` = NULL
@ -69,13 +63,12 @@ def test_sql_output(mycli):
, `optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
WHERE `letters` = 'd';''')
WHERE `letters` = 'd';""")
# Test sql-update-2 output format
assert list(mycli.change_table_format("sql-update-2")) == \
[(None, None, None, 'Changed table format to sql-update-2')]
assert list(mycli.change_table_format("sql-update-2")) == [(None, None, None, "Changed table format to sql-update-2")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
UPDATE `DUAL` SET
`optional` = NULL
, `float` = 10.0e0
@ -85,34 +78,31 @@ def test_sql_output(mycli):
`optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
WHERE `letters` = 'd' AND `number` = 456;''')
WHERE `letters` = 'd' AND `number` = 456;""")
# Test sql-insert output format (without table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
;""")
# Test sql-insert output format (with table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = "SELECT * FROM `table`"
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
;""")
# Test sql-insert output format (with database + table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = "SELECT * FROM `database`.`table`"
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
;""")

View file

@ -9,20 +9,18 @@ import pytest
from mycli.main import special
PASSWORD = os.getenv('PYTEST_PASSWORD')
USER = os.getenv('PYTEST_USER', 'root')
HOST = os.getenv('PYTEST_HOST', 'localhost')
PORT = int(os.getenv('PYTEST_PORT', 3306))
CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
SSH_USER = os.getenv('PYTEST_SSH_USER', None)
SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
PASSWORD = os.getenv("PYTEST_PASSWORD")
USER = os.getenv("PYTEST_USER", "root")
HOST = os.getenv("PYTEST_HOST", "localhost")
PORT = int(os.getenv("PYTEST_PORT", 3306))
CHARSET = os.getenv("PYTEST_CHARSET", "utf8")
SSH_USER = os.getenv("PYTEST_SSH_USER", None)
SSH_HOST = os.getenv("PYTEST_SSH_HOST", None)
SSH_PORT = os.getenv("PYTEST_SSH_PORT", 22)
def db_connection(dbname=None):
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
password=PASSWORD, charset=CHARSET,
local_infile=False)
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARSET, local_infile=False)
conn.autocommit = True
return conn
@ -30,20 +28,18 @@ def db_connection(dbname=None):
try:
db_connection()
CAN_CONNECT_TO_DB = True
except:
except Exception:
CAN_CONNECT_TO_DB = False
dbtest = pytest.mark.skipif(
not CAN_CONNECT_TO_DB,
reason="Need a mysql instance at localhost accessible by user 'root'")
dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Need a mysql instance at localhost accessible by user 'root'")
def create_db(dbname):
with db_connection().cursor() as cur:
try:
cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''')
cur.execute('''CREATE DATABASE mycli_test_db''')
except:
cur.execute("""DROP DATABASE IF EXISTS mycli_test_db""")
cur.execute("""CREATE DATABASE mycli_test_db""")
except Exception:
pass
@ -53,8 +49,7 @@ def run(executor, sql, rows_as_list=True):
for title, rows, headers, status in executor.run(sql):
rows = list(rows) if (rows_as_list and rows) else rows
result.append({'title': title, 'rows': rows, 'headers': headers,
'status': status})
result.append({"title": title, "rows": rows, "headers": headers, "status": status})
return result
@ -87,8 +82,6 @@ def send_ctrl_c(wait_seconds):
Returns the `multiprocessing.Process` created.
"""
ctrl_c_process = multiprocessing.Process(
target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
)
ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds))
ctrl_c_process.start()
return ctrl_c_process
tox.ini
@ -1,15 +1,21 @@
[tox]
envlist = py36, py37, py38
envlist = py
[testenv]
deps = pytest
mock
pexpect
behave
coverage
commands = python setup.py test
skip_install = true
deps = uv
passenv = PYTEST_HOST
PYTEST_USER
PYTEST_PASSWORD
PYTEST_PORT
PYTEST_CHARSET
commands = uv pip install -e .[dev,ssh]
coverage run -m pytest -v test
coverage report -m
behave test/features
[testenv:style]
skip_install = true
deps = ruff
commands = ruff check --fix
ruff format