Adding upstream version 1.29.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
5bd6a68e8c
commit
f9065f1bef
68 changed files with 3723 additions and 3336 deletions
|
@ -1,3 +1,2 @@
|
|||
[run]
|
||||
parallel = True
|
||||
source = mycli
|
||||
|
|
46
.github/workflows/ci.yml
vendored
46
.github/workflows/ci.yml
vendored
|
@ -4,34 +4,21 @@ on:
|
|||
pull_request:
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'AUTHORS'
|
||||
|
||||
jobs:
|
||||
linux:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [
|
||||
'3.8',
|
||||
'3.9',
|
||||
'3.10',
|
||||
'3.11',
|
||||
'3.12',
|
||||
]
|
||||
include:
|
||||
- python-version: '3.8'
|
||||
os: ubuntu-20.04 # MySQL 8.0.36
|
||||
- python-version: '3.9'
|
||||
os: ubuntu-20.04 # MySQL 8.0.36
|
||||
- python-version: '3.10'
|
||||
os: ubuntu-22.04 # MySQL 8.0.36
|
||||
- python-version: '3.11'
|
||||
os: ubuntu-22.04 # MySQL 8.0.36
|
||||
- python-version: '3.12'
|
||||
os: ubuntu-22.04 # MySQL 8.0.36
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v1
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
|
@ -43,10 +30,7 @@ jobs:
|
|||
sudo /etc/init.d/mysql start
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-dev.txt
|
||||
pip install --no-cache-dir -e .
|
||||
run: uv sync --all-extras -p ${{ matrix.python-version }}
|
||||
|
||||
- name: Wait for MySQL connection
|
||||
run: |
|
||||
|
@ -59,13 +43,7 @@ jobs:
|
|||
PYTEST_PASSWORD: root
|
||||
PYTEST_HOST: 127.0.0.1
|
||||
run: |
|
||||
./setup.py test --pytest-args="--cov-report= --cov=mycli"
|
||||
uv run tox -e py${{ matrix.python-version }}
|
||||
|
||||
- name: Lint
|
||||
run: |
|
||||
./setup.py lint --branch=HEAD
|
||||
|
||||
- name: Coverage
|
||||
run: |
|
||||
coverage combine
|
||||
coverage report
|
||||
- name: Run Style Checks
|
||||
run: uv run tox -e style
|
||||
|
|
94
.github/workflows/publish.yml
vendored
Normal file
94
.github/workflows/publish.yml
vendored
Normal file
|
@ -0,0 +1,94 @@
|
|||
name: Publish Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [created]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v1
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Start MySQL
|
||||
run: |
|
||||
sudo /etc/init.d/mysql start
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras -p ${{ matrix.python-version }}
|
||||
|
||||
- name: Wait for MySQL connection
|
||||
run: |
|
||||
while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Pytest / behave
|
||||
env:
|
||||
PYTEST_PASSWORD: root
|
||||
PYTEST_HOST: 127.0.0.1
|
||||
run: |
|
||||
uv run tox -e py${{ matrix.python-version }}
|
||||
|
||||
- name: Run Style Checks
|
||||
run: uv run tox -e style
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v1
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras -p 3.13
|
||||
|
||||
- name: Build
|
||||
run: uv build
|
||||
|
||||
- name: Store the distribution packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: python-packages
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
name: Publish to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
needs: [build]
|
||||
environment: release
|
||||
permissions:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Download distribution packages
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: python-packages
|
||||
path: dist/
|
||||
- name: Publish to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
31
changelog.md
31
changelog.md
|
@ -1,3 +1,34 @@
|
|||
1.29.2 (2024/12/11)
|
||||
===================
|
||||
|
||||
Internal
|
||||
--------
|
||||
|
||||
* Exclude tests from the python package.
|
||||
|
||||
1.29.1 (2024/12/11)
|
||||
===================
|
||||
|
||||
Internal
|
||||
--------
|
||||
|
||||
* Fix the GH actions to publish a new version.
|
||||
|
||||
1.29.0 (NEVER RELEASED)
|
||||
=======================
|
||||
|
||||
Bug Fixes
|
||||
----------
|
||||
|
||||
* fix SSL through SSH jump host by using a true python socket for a tunnel
|
||||
* Fix mycli crash when connecting to Vitess
|
||||
|
||||
Internal
|
||||
---------
|
||||
|
||||
* Modernize to use PEP-621. Use `uv` instead of `pip` in GH actions.
|
||||
* Remove Python 3.8 and add Python 3.13 in test matrix.
|
||||
|
||||
1.28.0 (2024/11/10)
|
||||
======================
|
||||
|
||||
|
|
|
@ -98,6 +98,7 @@ Contributors:
|
|||
* Houston Wong
|
||||
* Mohamed Rezk
|
||||
* Ryosuke Kazami
|
||||
* Cornel Cruceru
|
||||
|
||||
|
||||
Created by:
|
||||
|
|
|
@ -1 +1,3 @@
|
|||
__version__ = "1.28.0"
|
||||
import importlib.metadata
|
||||
|
||||
__version__ = importlib.metadata.version("mycli")
|
||||
|
|
|
@ -13,6 +13,7 @@ def cli_is_multiline(mycli):
|
|||
return False
|
||||
else:
|
||||
return not _multiline_exception(doc.text)
|
||||
|
||||
return cond
|
||||
|
||||
|
||||
|
@ -23,33 +24,32 @@ def _multiline_exception(text):
|
|||
# Multi-statement favorite query is a special case. Because there will
|
||||
# be a semicolon separating statements, we can't consider semicolon an
|
||||
# EOL. Let's consider an empty line an EOL instead.
|
||||
if text.startswith('\\fs'):
|
||||
return orig.endswith('\n')
|
||||
if text.startswith("\\fs"):
|
||||
return orig.endswith("\n")
|
||||
|
||||
return (
|
||||
# Special Command
|
||||
text.startswith('\\') or
|
||||
|
||||
text.startswith("\\")
|
||||
or
|
||||
# Delimiter declaration
|
||||
text.lower().startswith('delimiter') or
|
||||
|
||||
text.lower().startswith("delimiter")
|
||||
or
|
||||
# Ended with the current delimiter (usually a semi-column)
|
||||
text.endswith(special.get_current_delimiter()) or
|
||||
|
||||
text.endswith('\\g') or
|
||||
text.endswith('\\G') or
|
||||
text.endswith(r'\e') or
|
||||
text.endswith(r'\clip') or
|
||||
|
||||
text.endswith(special.get_current_delimiter())
|
||||
or text.endswith("\\g")
|
||||
or text.endswith("\\G")
|
||||
or text.endswith(r"\e")
|
||||
or text.endswith(r"\clip")
|
||||
or
|
||||
# Exit doesn't need semi-column`
|
||||
(text == 'exit') or
|
||||
|
||||
(text == "exit")
|
||||
or
|
||||
# Quit doesn't need semi-column
|
||||
(text == 'quit') or
|
||||
|
||||
(text == "quit")
|
||||
or
|
||||
# To all teh vim fans out there
|
||||
(text == ':q') or
|
||||
|
||||
(text == ":q")
|
||||
or
|
||||
# just a plain enter without any text
|
||||
(text == '')
|
||||
(text == "")
|
||||
)
|
||||
|
|
|
@ -11,70 +11,69 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
|
||||
TOKEN_TO_PROMPT_STYLE = {
|
||||
Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current',
|
||||
Token.Menu.Completions.Completion: 'completion-menu.completion',
|
||||
Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current',
|
||||
Token.Menu.Completions.Meta: 'completion-menu.meta.completion',
|
||||
Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta',
|
||||
Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess
|
||||
Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess
|
||||
Token.SelectedText: 'selected',
|
||||
Token.SearchMatch: 'search',
|
||||
Token.SearchMatch.Current: 'search.current',
|
||||
Token.Toolbar: 'bottom-toolbar',
|
||||
Token.Toolbar.Off: 'bottom-toolbar.off',
|
||||
Token.Toolbar.On: 'bottom-toolbar.on',
|
||||
Token.Toolbar.Search: 'search-toolbar',
|
||||
Token.Toolbar.Search.Text: 'search-toolbar.text',
|
||||
Token.Toolbar.System: 'system-toolbar',
|
||||
Token.Toolbar.Arg: 'arg-toolbar',
|
||||
Token.Toolbar.Arg.Text: 'arg-toolbar.text',
|
||||
Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid',
|
||||
Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed',
|
||||
Token.Output.Header: 'output.header',
|
||||
Token.Output.OddRow: 'output.odd-row',
|
||||
Token.Output.EvenRow: 'output.even-row',
|
||||
Token.Output.Null: 'output.null',
|
||||
Token.Prompt: 'prompt',
|
||||
Token.Continuation: 'continuation',
|
||||
Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
|
||||
Token.Menu.Completions.Completion: "completion-menu.completion",
|
||||
Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
|
||||
Token.Menu.Completions.Meta: "completion-menu.meta.completion",
|
||||
Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
|
||||
Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
|
||||
Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
|
||||
Token.SelectedText: "selected",
|
||||
Token.SearchMatch: "search",
|
||||
Token.SearchMatch.Current: "search.current",
|
||||
Token.Toolbar: "bottom-toolbar",
|
||||
Token.Toolbar.Off: "bottom-toolbar.off",
|
||||
Token.Toolbar.On: "bottom-toolbar.on",
|
||||
Token.Toolbar.Search: "search-toolbar",
|
||||
Token.Toolbar.Search.Text: "search-toolbar.text",
|
||||
Token.Toolbar.System: "system-toolbar",
|
||||
Token.Toolbar.Arg: "arg-toolbar",
|
||||
Token.Toolbar.Arg.Text: "arg-toolbar.text",
|
||||
Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
|
||||
Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
|
||||
Token.Output.Header: "output.header",
|
||||
Token.Output.OddRow: "output.odd-row",
|
||||
Token.Output.EvenRow: "output.even-row",
|
||||
Token.Output.Null: "output.null",
|
||||
Token.Prompt: "prompt",
|
||||
Token.Continuation: "continuation",
|
||||
}
|
||||
|
||||
# reverse dict for cli_helpers, because they still expect Pygments tokens.
|
||||
PROMPT_STYLE_TO_TOKEN = {
|
||||
v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()
|
||||
}
|
||||
PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
|
||||
|
||||
# all tokens that the Pygments MySQL lexer can produce
|
||||
OVERRIDE_STYLE_TO_TOKEN = {
|
||||
'sql.comment': Token.Comment,
|
||||
'sql.comment.multi-line': Token.Comment.Multiline,
|
||||
'sql.comment.single-line': Token.Comment.Single,
|
||||
'sql.comment.optimizer-hint': Token.Comment.Special,
|
||||
'sql.escape': Token.Error,
|
||||
'sql.keyword': Token.Keyword,
|
||||
'sql.datatype': Token.Keyword.Type,
|
||||
'sql.literal': Token.Literal,
|
||||
'sql.literal.date': Token.Literal.Date,
|
||||
'sql.symbol': Token.Name,
|
||||
'sql.quoted-schema-object': Token.Name.Quoted,
|
||||
'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape,
|
||||
'sql.constant': Token.Name.Constant,
|
||||
'sql.function': Token.Name.Function,
|
||||
'sql.variable': Token.Name.Variable,
|
||||
'sql.number': Token.Number,
|
||||
'sql.number.binary': Token.Number.Bin,
|
||||
'sql.number.float': Token.Number.Float,
|
||||
'sql.number.hex': Token.Number.Hex,
|
||||
'sql.number.integer': Token.Number.Integer,
|
||||
'sql.operator': Token.Operator,
|
||||
'sql.punctuation': Token.Punctuation,
|
||||
'sql.string': Token.String,
|
||||
'sql.string.double-quouted': Token.String.Double,
|
||||
'sql.string.escape': Token.String.Escape,
|
||||
'sql.string.single-quoted': Token.String.Single,
|
||||
'sql.whitespace': Token.Text,
|
||||
"sql.comment": Token.Comment,
|
||||
"sql.comment.multi-line": Token.Comment.Multiline,
|
||||
"sql.comment.single-line": Token.Comment.Single,
|
||||
"sql.comment.optimizer-hint": Token.Comment.Special,
|
||||
"sql.escape": Token.Error,
|
||||
"sql.keyword": Token.Keyword,
|
||||
"sql.datatype": Token.Keyword.Type,
|
||||
"sql.literal": Token.Literal,
|
||||
"sql.literal.date": Token.Literal.Date,
|
||||
"sql.symbol": Token.Name,
|
||||
"sql.quoted-schema-object": Token.Name.Quoted,
|
||||
"sql.quoted-schema-object.escape": Token.Name.Quoted.Escape,
|
||||
"sql.constant": Token.Name.Constant,
|
||||
"sql.function": Token.Name.Function,
|
||||
"sql.variable": Token.Name.Variable,
|
||||
"sql.number": Token.Number,
|
||||
"sql.number.binary": Token.Number.Bin,
|
||||
"sql.number.float": Token.Number.Float,
|
||||
"sql.number.hex": Token.Number.Hex,
|
||||
"sql.number.integer": Token.Number.Integer,
|
||||
"sql.operator": Token.Operator,
|
||||
"sql.punctuation": Token.Punctuation,
|
||||
"sql.string": Token.String,
|
||||
"sql.string.double-quouted": Token.String.Double,
|
||||
"sql.string.escape": Token.String.Escape,
|
||||
"sql.string.single-quoted": Token.String.Single,
|
||||
"sql.whitespace": Token.Text,
|
||||
}
|
||||
|
||||
|
||||
def parse_pygments_style(token_name, style_object, style_dict):
|
||||
"""Parse token type and style string.
|
||||
|
||||
|
@ -87,7 +86,7 @@ def parse_pygments_style(token_name, style_object, style_dict):
|
|||
try:
|
||||
other_token_type = string_to_tokentype(style_dict[token_name])
|
||||
return token_type, style_object.styles[other_token_type]
|
||||
except AttributeError as err:
|
||||
except AttributeError:
|
||||
return token_type, style_dict[token_name]
|
||||
|
||||
|
||||
|
@ -95,45 +94,39 @@ def style_factory(name, cli_style):
|
|||
try:
|
||||
style = pygments.styles.get_style_by_name(name)
|
||||
except ClassNotFound:
|
||||
style = pygments.styles.get_style_by_name('native')
|
||||
style = pygments.styles.get_style_by_name("native")
|
||||
|
||||
prompt_styles = []
|
||||
# prompt-toolkit used pygments tokens for styling before, switched to style
|
||||
# names in 2.0. Convert old token types to new style names, for backwards compatibility.
|
||||
for token in cli_style:
|
||||
if token.startswith('Token.'):
|
||||
if token.startswith("Token."):
|
||||
# treat as pygments token (1.0)
|
||||
token_type, style_value = parse_pygments_style(
|
||||
token, style, cli_style)
|
||||
token_type, style_value = parse_pygments_style(token, style, cli_style)
|
||||
if token_type in TOKEN_TO_PROMPT_STYLE:
|
||||
prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
|
||||
prompt_styles.append((prompt_style, style_value))
|
||||
else:
|
||||
# we don't want to support tokens anymore
|
||||
logger.error('Unhandled style / class name: %s', token)
|
||||
logger.error("Unhandled style / class name: %s", token)
|
||||
else:
|
||||
# treat as prompt style name (2.0). See default style names here:
|
||||
# https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
|
||||
prompt_styles.append((token, cli_style[token]))
|
||||
|
||||
override_style = Style([('bottom-toolbar', 'noreverse')])
|
||||
return merge_styles([
|
||||
style_from_pygments_cls(style),
|
||||
override_style,
|
||||
Style(prompt_styles)
|
||||
])
|
||||
override_style = Style([("bottom-toolbar", "noreverse")])
|
||||
return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)])
|
||||
|
||||
|
||||
def style_factory_output(name, cli_style):
|
||||
try:
|
||||
style = pygments.styles.get_style_by_name(name).styles
|
||||
except ClassNotFound:
|
||||
style = pygments.styles.get_style_by_name('native').styles
|
||||
style = pygments.styles.get_style_by_name("native").styles
|
||||
|
||||
for token in cli_style:
|
||||
if token.startswith('Token.'):
|
||||
token_type, style_value = parse_pygments_style(
|
||||
token, style, cli_style)
|
||||
if token.startswith("Token."):
|
||||
token_type, style_value = parse_pygments_style(token, style, cli_style)
|
||||
style.update({token_type: style_value})
|
||||
elif token in PROMPT_STYLE_TO_TOKEN:
|
||||
token_type = PROMPT_STYLE_TO_TOKEN[token]
|
||||
|
@ -143,7 +136,7 @@ def style_factory_output(name, cli_style):
|
|||
style.update({token_type: cli_style[token]})
|
||||
else:
|
||||
# TODO: cli helpers will have to switch to ptk.Style
|
||||
logger.error('Unhandled style / class name: %s', token)
|
||||
logger.error("Unhandled style / class name: %s", token)
|
||||
|
||||
class OutputStyle(PygmentsStyle):
|
||||
default_style = ""
|
||||
|
|
|
@ -6,52 +6,47 @@ from .packages import special
|
|||
|
||||
def create_toolbar_tokens_func(mycli, show_fish_help):
|
||||
"""Return a function that generates the toolbar tokens."""
|
||||
|
||||
def get_toolbar_tokens():
|
||||
result = [('class:bottom-toolbar', ' ')]
|
||||
result = [("class:bottom-toolbar", " ")]
|
||||
|
||||
if mycli.multi_line:
|
||||
delimiter = special.get_current_delimiter()
|
||||
result.append(
|
||||
(
|
||||
'class:bottom-toolbar',
|
||||
' ({} [{}] will end the line) '.format(
|
||||
'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter)
|
||||
))
|
||||
"class:bottom-toolbar",
|
||||
" ({} [{}] will end the line) ".format("Semi-colon" if delimiter == ";" else "Delimiter", delimiter),
|
||||
)
|
||||
)
|
||||
|
||||
if mycli.multi_line:
|
||||
result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON '))
|
||||
result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
|
||||
else:
|
||||
result.append(('class:bottom-toolbar.off',
|
||||
'[F3] Multiline: OFF '))
|
||||
result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
|
||||
if mycli.prompt_app.editing_mode == EditingMode.VI:
|
||||
result.append((
|
||||
'class:bottom-toolbar.on',
|
||||
'Vi-mode ({})'.format(_get_vi_mode())
|
||||
))
|
||||
result.append(("class:bottom-toolbar.on", "Vi-mode ({})".format(_get_vi_mode())))
|
||||
|
||||
if mycli.toolbar_error_message:
|
||||
result.append(
|
||||
('class:bottom-toolbar', ' ' + mycli.toolbar_error_message))
|
||||
result.append(("class:bottom-toolbar", " " + mycli.toolbar_error_message))
|
||||
mycli.toolbar_error_message = None
|
||||
|
||||
if show_fish_help():
|
||||
result.append(
|
||||
('class:bottom-toolbar', ' Right-arrow to complete suggestion'))
|
||||
result.append(("class:bottom-toolbar", " Right-arrow to complete suggestion"))
|
||||
|
||||
if mycli.completion_refresher.is_refreshing():
|
||||
result.append(
|
||||
('class:bottom-toolbar', ' Refreshing completions...'))
|
||||
result.append(("class:bottom-toolbar", " Refreshing completions..."))
|
||||
|
||||
return result
|
||||
|
||||
return get_toolbar_tokens
|
||||
|
||||
|
||||
def _get_vi_mode():
|
||||
"""Get the current vi mode for display."""
|
||||
return {
|
||||
InputMode.INSERT: 'I',
|
||||
InputMode.NAVIGATION: 'N',
|
||||
InputMode.REPLACE: 'R',
|
||||
InputMode.REPLACE_SINGLE: 'R',
|
||||
InputMode.INSERT_MULTIPLE: 'M',
|
||||
InputMode.INSERT: "I",
|
||||
InputMode.NAVIGATION: "N",
|
||||
InputMode.REPLACE: "R",
|
||||
InputMode.REPLACE_SINGLE: "R",
|
||||
InputMode.INSERT_MULTIPLE: "M",
|
||||
}[get_app().vi_state.input_mode]
|
||||
|
|
|
@ -3,4 +3,4 @@
|
|||
import sys
|
||||
|
||||
|
||||
WIN = sys.platform in ('win32', 'cygwin')
|
||||
WIN = sys.platform in ("win32", "cygwin")
|
||||
|
|
|
@ -5,8 +5,8 @@ from collections import OrderedDict
|
|||
from .sqlcompleter import SQLCompleter
|
||||
from .sqlexecute import SQLExecute, ServerSpecies
|
||||
|
||||
class CompletionRefresher(object):
|
||||
|
||||
class CompletionRefresher(object):
|
||||
refreshers = OrderedDict()
|
||||
|
||||
def __init__(self):
|
||||
|
@ -30,16 +30,14 @@ class CompletionRefresher(object):
|
|||
|
||||
if self.is_refreshing():
|
||||
self._restart_refresh.set()
|
||||
return [(None, None, None, 'Auto-completion refresh restarted.')]
|
||||
return [(None, None, None, "Auto-completion refresh restarted.")]
|
||||
else:
|
||||
self._completer_thread = threading.Thread(
|
||||
target=self._bg_refresh,
|
||||
args=(executor, callbacks, completer_options),
|
||||
name='completion_refresh')
|
||||
target=self._bg_refresh, args=(executor, callbacks, completer_options), name="completion_refresh"
|
||||
)
|
||||
self._completer_thread.daemon = True
|
||||
self._completer_thread.start()
|
||||
return [(None, None, None,
|
||||
'Auto-completion refresh started in the background.')]
|
||||
return [(None, None, None, "Auto-completion refresh started in the background.")]
|
||||
|
||||
def is_refreshing(self):
|
||||
return self._completer_thread and self._completer_thread.is_alive()
|
||||
|
@ -49,10 +47,22 @@ class CompletionRefresher(object):
|
|||
|
||||
# Create a new pgexecute method to populate the completions.
|
||||
e = sqlexecute
|
||||
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
|
||||
e.socket, e.charset, e.local_infile, e.ssl,
|
||||
e.ssh_user, e.ssh_host, e.ssh_port,
|
||||
e.ssh_password, e.ssh_key_filename)
|
||||
executor = SQLExecute(
|
||||
e.dbname,
|
||||
e.user,
|
||||
e.password,
|
||||
e.host,
|
||||
e.port,
|
||||
e.socket,
|
||||
e.charset,
|
||||
e.local_infile,
|
||||
e.ssl,
|
||||
e.ssh_user,
|
||||
e.ssh_host,
|
||||
e.ssh_port,
|
||||
e.ssh_password,
|
||||
e.ssh_key_filename,
|
||||
)
|
||||
|
||||
# If callbacks is a single function then push it into a list.
|
||||
if callable(callbacks):
|
||||
|
@ -76,55 +86,68 @@ class CompletionRefresher(object):
|
|||
for callback in callbacks:
|
||||
callback(completer)
|
||||
|
||||
|
||||
def refresher(name, refreshers=CompletionRefresher.refreshers):
|
||||
"""Decorator to add the decorated function to the dictionary of
|
||||
refreshers. Any function decorated with a @refresher will be executed as
|
||||
part of the completion refresh routine."""
|
||||
|
||||
def wrapper(wrapped):
|
||||
refreshers[name] = wrapped
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
@refresher('databases')
|
||||
|
||||
@refresher("databases")
|
||||
def refresh_databases(completer, executor):
|
||||
completer.extend_database_names(executor.databases())
|
||||
|
||||
@refresher('schemata')
|
||||
|
||||
@refresher("schemata")
|
||||
def refresh_schemata(completer, executor):
|
||||
# schemata - In MySQL Schema is the same as database. But for mycli
|
||||
# schemata will be the name of the current database.
|
||||
completer.extend_schemata(executor.dbname)
|
||||
completer.set_dbname(executor.dbname)
|
||||
|
||||
@refresher('tables')
|
||||
def refresh_tables(completer, executor):
|
||||
completer.extend_relations(executor.tables(), kind='tables')
|
||||
completer.extend_columns(executor.table_columns(), kind='tables')
|
||||
|
||||
@refresher('users')
|
||||
@refresher("tables")
|
||||
def refresh_tables(completer, executor):
|
||||
table_columns_dbresult = list(executor.table_columns())
|
||||
completer.extend_relations(table_columns_dbresult, kind="tables")
|
||||
completer.extend_columns(table_columns_dbresult, kind="tables")
|
||||
|
||||
|
||||
@refresher("users")
|
||||
def refresh_users(completer, executor):
|
||||
completer.extend_users(executor.users())
|
||||
|
||||
|
||||
# @refresher('views')
|
||||
# def refresh_views(completer, executor):
|
||||
# completer.extend_relations(executor.views(), kind='views')
|
||||
# completer.extend_columns(executor.view_columns(), kind='views')
|
||||
|
||||
@refresher('functions')
|
||||
|
||||
@refresher("functions")
|
||||
def refresh_functions(completer, executor):
|
||||
completer.extend_functions(executor.functions())
|
||||
if executor.server_info.species == ServerSpecies.TiDB:
|
||||
completer.extend_functions(completer.tidb_functions, builtin=True)
|
||||
|
||||
@refresher('special_commands')
|
||||
|
||||
@refresher("special_commands")
|
||||
def refresh_special(completer, executor):
|
||||
completer.extend_special_commands(COMMANDS.keys())
|
||||
|
||||
@refresher('show_commands')
|
||||
|
||||
@refresher("show_commands")
|
||||
def refresh_show_commands(completer, executor):
|
||||
completer.extend_show_items(executor.show_candidates())
|
||||
|
||||
@refresher('keywords')
|
||||
|
||||
@refresher("keywords")
|
||||
def refresh_keywords(completer, executor):
|
||||
if executor.server_info.species == ServerSpecies.TiDB:
|
||||
completer.extend_keywords(completer.tidb_keywords, replace=True)
|
||||
|
|
|
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
|||
def log(logger, level, message):
|
||||
"""Logs message to stderr if logging isn't initialized."""
|
||||
|
||||
if logger.parent.name != 'root':
|
||||
if logger.parent.name != "root":
|
||||
logger.log(level, message)
|
||||
else:
|
||||
print(message, file=sys.stderr)
|
||||
|
@ -49,16 +49,13 @@ def read_config_file(f, list_values=True):
|
|||
f = os.path.expanduser(f)
|
||||
|
||||
try:
|
||||
config = ConfigObj(f, interpolation=False, encoding='utf8',
|
||||
list_values=list_values)
|
||||
config = ConfigObj(f, interpolation=False, encoding="utf8", list_values=list_values)
|
||||
except ConfigObjError as e:
|
||||
log(logger, logging.WARNING, "Unable to parse line {0} of config file "
|
||||
"'{1}'.".format(e.line_number, f))
|
||||
log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f))
|
||||
log(logger, logging.WARNING, "Using successfully parsed config values.")
|
||||
return e.config
|
||||
except (IOError, OSError) as e:
|
||||
log(logger, logging.WARNING, "You don't have permission to read "
|
||||
"config file '{0}'.".format(e.filename))
|
||||
log(logger, logging.WARNING, "You don't have permission to read " "config file '{0}'.".format(e.filename))
|
||||
return None
|
||||
|
||||
return config
|
||||
|
@ -80,15 +77,12 @@ def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
|
|||
|
||||
try:
|
||||
with open(config_file) as f:
|
||||
include_directives = filter(
|
||||
lambda s: s.startswith('!includedir'),
|
||||
f
|
||||
)
|
||||
include_directives = filter(lambda s: s.startswith("!includedir"), f)
|
||||
dirs = map(lambda s: s.strip().split()[-1], include_directives)
|
||||
dirs = filter(os.path.isdir, dirs)
|
||||
for dir in dirs:
|
||||
for filename in os.listdir(dir):
|
||||
if filename.endswith('.cnf'):
|
||||
if filename.endswith(".cnf"):
|
||||
included_configs.append(os.path.join(dir, filename))
|
||||
except (PermissionError, UnicodeDecodeError):
|
||||
pass
|
||||
|
@ -117,29 +111,31 @@ def read_config_files(files, list_values=True):
|
|||
|
||||
def create_default_config(list_values=True):
|
||||
import mycli
|
||||
default_config_file = resources.open_text(mycli, 'myclirc')
|
||||
|
||||
default_config_file = resources.open_text(mycli, "myclirc")
|
||||
return read_config_file(default_config_file, list_values=list_values)
|
||||
|
||||
|
||||
def write_default_config(destination, overwrite=False):
|
||||
import mycli
|
||||
default_config = resources.read_text(mycli, 'myclirc')
|
||||
|
||||
default_config = resources.read_text(mycli, "myclirc")
|
||||
destination = os.path.expanduser(destination)
|
||||
if not overwrite and exists(destination):
|
||||
return
|
||||
|
||||
with open(destination, 'w') as f:
|
||||
with open(destination, "w") as f:
|
||||
f.write(default_config)
|
||||
|
||||
|
||||
def get_mylogin_cnf_path():
|
||||
"""Return the path to the login path file or None if it doesn't exist."""
|
||||
mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE')
|
||||
mylogin_cnf_path = os.getenv("MYSQL_TEST_LOGIN_FILE")
|
||||
|
||||
if mylogin_cnf_path is None:
|
||||
app_data = os.getenv('APPDATA')
|
||||
default_dir = os.path.join(app_data, 'MySQL') if app_data else '~'
|
||||
mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf')
|
||||
app_data = os.getenv("APPDATA")
|
||||
default_dir = os.path.join(app_data, "MySQL") if app_data else "~"
|
||||
mylogin_cnf_path = os.path.join(default_dir, ".mylogin.cnf")
|
||||
|
||||
mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path)
|
||||
|
||||
|
@ -159,14 +155,14 @@ def open_mylogin_cnf(name):
|
|||
"""
|
||||
|
||||
try:
|
||||
with open(name, 'rb') as f:
|
||||
with open(name, "rb") as f:
|
||||
plaintext = read_and_decrypt_mylogin_cnf(f)
|
||||
except (OSError, IOError, ValueError):
|
||||
logger.error('Unable to open login path file.')
|
||||
logger.error("Unable to open login path file.")
|
||||
return None
|
||||
|
||||
if not isinstance(plaintext, BytesIO):
|
||||
logger.error('Unable to read login path file.')
|
||||
logger.error("Unable to read login path file.")
|
||||
return None
|
||||
|
||||
return TextIOWrapper(plaintext)
|
||||
|
@ -181,6 +177,7 @@ def encrypt_mylogin_cnf(plaintext: IO[str]):
|
|||
https://github.com/isotopp/mysql-config-coder
|
||||
|
||||
"""
|
||||
|
||||
def realkey(key):
|
||||
"""Create the AES key from the login key."""
|
||||
rkey = bytearray(16)
|
||||
|
@ -194,10 +191,7 @@ def encrypt_mylogin_cnf(plaintext: IO[str]):
|
|||
pad_len = buf_len - text_len
|
||||
pad_chr = bytes(chr(pad_len), "utf8")
|
||||
plaintext = plaintext.encode() + pad_chr * pad_len
|
||||
encrypted_text = b''.join(
|
||||
[aes.encrypt(plaintext[i: i + 16])
|
||||
for i in range(0, len(plaintext), 16)]
|
||||
)
|
||||
encrypted_text = b"".join([aes.encrypt(plaintext[i : i + 16]) for i in range(0, len(plaintext), 16)])
|
||||
return encrypted_text
|
||||
|
||||
LOGIN_KEY_LENGTH = 20
|
||||
|
@ -248,7 +242,7 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
buf = f.read(4)
|
||||
|
||||
if not buf or len(buf) != 4:
|
||||
logger.error('Login path file is blank or incomplete.')
|
||||
logger.error("Login path file is blank or incomplete.")
|
||||
return None
|
||||
|
||||
# Read the login key.
|
||||
|
@ -258,12 +252,12 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
rkey = [0] * 16
|
||||
for i in range(LOGIN_KEY_LEN):
|
||||
try:
|
||||
rkey[i % 16] ^= ord(key[i:i+1])
|
||||
rkey[i % 16] ^= ord(key[i : i + 1])
|
||||
except TypeError:
|
||||
# ord() was unable to get the value of the byte.
|
||||
logger.error('Unable to generate login path AES key.')
|
||||
logger.error("Unable to generate login path AES key.")
|
||||
return None
|
||||
rkey = struct.pack('16B', *rkey)
|
||||
rkey = struct.pack("16B", *rkey)
|
||||
|
||||
# Create a bytes buffer to hold the plaintext.
|
||||
plaintext = BytesIO()
|
||||
|
@ -274,20 +268,17 @@ def read_and_decrypt_mylogin_cnf(f):
|
|||
len_buf = f.read(MAX_CIPHER_STORE_LEN)
|
||||
if len(len_buf) < MAX_CIPHER_STORE_LEN:
|
||||
break
|
||||
cipher_len, = struct.unpack("<i", len_buf)
|
||||
(cipher_len,) = struct.unpack("<i", len_buf)
|
||||
|
||||
# Read cipher_len bytes from the file and decrypt.
|
||||
cipher = f.read(cipher_len)
|
||||
plain = _remove_pad(
|
||||
b''.join([aes.decrypt(cipher[i: i + 16])
|
||||
for i in range(0, cipher_len, 16)])
|
||||
)
|
||||
plain = _remove_pad(b"".join([aes.decrypt(cipher[i : i + 16]) for i in range(0, cipher_len, 16)]))
|
||||
if plain is False:
|
||||
continue
|
||||
plaintext.write(plain)
|
||||
|
||||
if plaintext.tell() == 0:
|
||||
logger.error('No data successfully decrypted from login path file.')
|
||||
logger.error("No data successfully decrypted from login path file.")
|
||||
return None
|
||||
|
||||
plaintext.seek(0)
|
||||
|
@ -299,17 +290,17 @@ def str_to_bool(s):
|
|||
if isinstance(s, bool):
|
||||
return s
|
||||
elif not isinstance(s, basestring):
|
||||
raise TypeError('argument must be a string')
|
||||
raise TypeError("argument must be a string")
|
||||
|
||||
true_values = ('true', 'on', '1')
|
||||
false_values = ('false', 'off', '0')
|
||||
true_values = ("true", "on", "1")
|
||||
false_values = ("false", "off", "0")
|
||||
|
||||
if s.lower() in true_values:
|
||||
return True
|
||||
elif s.lower() in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError('not a recognized boolean value: {0}'.format(s))
|
||||
raise ValueError("not a recognized boolean value: {0}".format(s))
|
||||
|
||||
|
||||
def strip_matching_quotes(s):
|
||||
|
@ -319,8 +310,7 @@ def strip_matching_quotes(s):
|
|||
values.
|
||||
|
||||
"""
|
||||
if (isinstance(s, basestring) and len(s) >= 2 and
|
||||
s[0] == s[-1] and s[0] in ('"', "'")):
|
||||
if isinstance(s, basestring) and len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
|
||||
s = s[1:-1]
|
||||
return s
|
||||
|
||||
|
@ -332,13 +322,13 @@ def _remove_pad(line):
|
|||
pad_length = ord(line[-1:])
|
||||
except TypeError:
|
||||
# ord() was unable to get the value of the byte.
|
||||
logger.warning('Unable to remove pad.')
|
||||
logger.warning("Unable to remove pad.")
|
||||
return False
|
||||
|
||||
if pad_length > len(line) or len(set(line[-pad_length:])) != 1:
|
||||
# Pad length should be less than or equal to the length of the
|
||||
# plaintext. The pad should have a single unique byte.
|
||||
logger.warning('Invalid pad found in login path file.')
|
||||
logger.warning("Invalid pad found in login path file.")
|
||||
return False
|
||||
|
||||
return line[:-pad_length]
|
||||
|
|
|
@ -12,22 +12,22 @@ def mycli_bindings(mycli):
|
|||
"""Custom key bindings for mycli."""
|
||||
kb = KeyBindings()
|
||||
|
||||
@kb.add('f2')
|
||||
@kb.add("f2")
|
||||
def _(event):
|
||||
"""Enable/Disable SmartCompletion Mode."""
|
||||
_logger.debug('Detected F2 key.')
|
||||
_logger.debug("Detected F2 key.")
|
||||
mycli.completer.smart_completion = not mycli.completer.smart_completion
|
||||
|
||||
@kb.add('f3')
|
||||
@kb.add("f3")
|
||||
def _(event):
|
||||
"""Enable/Disable Multiline Mode."""
|
||||
_logger.debug('Detected F3 key.')
|
||||
_logger.debug("Detected F3 key.")
|
||||
mycli.multi_line = not mycli.multi_line
|
||||
|
||||
@kb.add('f4')
|
||||
@kb.add("f4")
|
||||
def _(event):
|
||||
"""Toggle between Vi and Emacs mode."""
|
||||
_logger.debug('Detected F4 key.')
|
||||
_logger.debug("Detected F4 key.")
|
||||
if mycli.key_bindings == "vi":
|
||||
event.app.editing_mode = EditingMode.EMACS
|
||||
mycli.key_bindings = "emacs"
|
||||
|
@ -35,17 +35,17 @@ def mycli_bindings(mycli):
|
|||
event.app.editing_mode = EditingMode.VI
|
||||
mycli.key_bindings = "vi"
|
||||
|
||||
@kb.add('tab')
|
||||
@kb.add("tab")
|
||||
def _(event):
|
||||
"""Force autocompletion at cursor."""
|
||||
_logger.debug('Detected <Tab> key.')
|
||||
_logger.debug("Detected <Tab> key.")
|
||||
b = event.app.current_buffer
|
||||
if b.complete_state:
|
||||
b.complete_next()
|
||||
else:
|
||||
b.start_completion(select_first=True)
|
||||
|
||||
@kb.add('c-space')
|
||||
@kb.add("c-space")
|
||||
def _(event):
|
||||
"""
|
||||
Initialize autocompletion at cursor.
|
||||
|
@ -55,7 +55,7 @@ def mycli_bindings(mycli):
|
|||
|
||||
If the menu is showing, select the next completion.
|
||||
"""
|
||||
_logger.debug('Detected <C-Space> key.')
|
||||
_logger.debug("Detected <C-Space> key.")
|
||||
|
||||
b = event.app.current_buffer
|
||||
if b.complete_state:
|
||||
|
@ -63,14 +63,14 @@ def mycli_bindings(mycli):
|
|||
else:
|
||||
b.start_completion(select_first=False)
|
||||
|
||||
@kb.add('c-x', 'p', filter=emacs_mode)
|
||||
@kb.add("c-x", "p", filter=emacs_mode)
|
||||
def _(event):
|
||||
"""
|
||||
Prettify and indent current statement, usually into multiple lines.
|
||||
|
||||
Only accepts buffers containing single SQL statements.
|
||||
"""
|
||||
_logger.debug('Detected <C-x p>/> key.')
|
||||
_logger.debug("Detected <C-x p>/> key.")
|
||||
|
||||
b = event.app.current_buffer
|
||||
cursorpos_relative = b.cursor_position / max(1, len(b.text))
|
||||
|
@ -78,19 +78,18 @@ def mycli_bindings(mycli):
|
|||
if len(pretty_text) > 0:
|
||||
b.text = pretty_text
|
||||
cursorpos_abs = int(round(cursorpos_relative * len(b.text)))
|
||||
while 0 < cursorpos_abs < len(b.text) \
|
||||
and b.text[cursorpos_abs] in (' ', '\n'):
|
||||
while 0 < cursorpos_abs < len(b.text) and b.text[cursorpos_abs] in (" ", "\n"):
|
||||
cursorpos_abs -= 1
|
||||
b.cursor_position = min(cursorpos_abs, len(b.text))
|
||||
|
||||
@kb.add('c-x', 'u', filter=emacs_mode)
|
||||
@kb.add("c-x", "u", filter=emacs_mode)
|
||||
def _(event):
|
||||
"""
|
||||
Unprettify and dedent current statement, usually into one line.
|
||||
|
||||
Only accepts buffers containing single SQL statements.
|
||||
"""
|
||||
_logger.debug('Detected <C-x u>/< key.')
|
||||
_logger.debug("Detected <C-x u>/< key.")
|
||||
|
||||
b = event.app.current_buffer
|
||||
cursorpos_relative = b.cursor_position / max(1, len(b.text))
|
||||
|
@ -98,18 +97,17 @@ def mycli_bindings(mycli):
|
|||
if len(unpretty_text) > 0:
|
||||
b.text = unpretty_text
|
||||
cursorpos_abs = int(round(cursorpos_relative * len(b.text)))
|
||||
while 0 < cursorpos_abs < len(b.text) \
|
||||
and b.text[cursorpos_abs] in (' ', '\n'):
|
||||
while 0 < cursorpos_abs < len(b.text) and b.text[cursorpos_abs] in (" ", "\n"):
|
||||
cursorpos_abs -= 1
|
||||
b.cursor_position = min(cursorpos_abs, len(b.text))
|
||||
|
||||
@kb.add('c-r', filter=emacs_mode)
|
||||
@kb.add("c-r", filter=emacs_mode)
|
||||
def _(event):
|
||||
"""Search history using fzf or default reverse incremental search."""
|
||||
_logger.debug('Detected <C-r> key.')
|
||||
_logger.debug("Detected <C-r> key.")
|
||||
search_history(event)
|
||||
|
||||
@kb.add('enter', filter=completion_is_selected)
|
||||
@kb.add("enter", filter=completion_is_selected)
|
||||
def _(event):
|
||||
"""Makes the enter key work as the tab key only when showing the menu.
|
||||
|
||||
|
@ -118,20 +116,20 @@ def mycli_bindings(mycli):
|
|||
(accept current selection).
|
||||
|
||||
"""
|
||||
_logger.debug('Detected enter key.')
|
||||
_logger.debug("Detected enter key.")
|
||||
|
||||
event.current_buffer.complete_state = None
|
||||
b = event.app.current_buffer
|
||||
b.complete_state = None
|
||||
|
||||
@kb.add('escape', 'enter')
|
||||
@kb.add("escape", "enter")
|
||||
def _(event):
|
||||
"""Introduces a line break in multi-line mode, or dispatches the
|
||||
command in single-line mode."""
|
||||
_logger.debug('Detected alt-enter key.')
|
||||
_logger.debug("Detected alt-enter key.")
|
||||
if mycli.multi_line:
|
||||
event.app.current_buffer.validate_and_handle()
|
||||
else:
|
||||
event.app.current_buffer.insert_text('\n')
|
||||
event.app.current_buffer.insert_text("\n")
|
||||
|
||||
return kb
|
||||
|
|
|
@ -7,6 +7,5 @@ class MyCliLexer(MySqlLexer):
|
|||
"""Extends MySQL lexer to add keywords."""
|
||||
|
||||
tokens = {
|
||||
'root': [(r'\brepair\b', Keyword),
|
||||
(r'\boffset\b', Keyword), inherit],
|
||||
"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit],
|
||||
}
|
||||
|
|
|
@ -5,19 +5,20 @@ import logging
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def load_ipython_extension(ipython):
|
||||
|
||||
def load_ipython_extension(ipython):
|
||||
# This is called via the ipython command '%load_ext mycli.magic'.
|
||||
|
||||
# First, load the sql magic if it isn't already loaded.
|
||||
if not ipython.find_line_magic('sql'):
|
||||
ipython.run_line_magic('load_ext', 'sql')
|
||||
if not ipython.find_line_magic("sql"):
|
||||
ipython.run_line_magic("load_ext", "sql")
|
||||
|
||||
# Register our own magic.
|
||||
ipython.register_magic_function(mycli_line_magic, 'line', 'mycli')
|
||||
ipython.register_magic_function(mycli_line_magic, "line", "mycli")
|
||||
|
||||
|
||||
def mycli_line_magic(line):
|
||||
_logger.debug('mycli magic called: %r', line)
|
||||
_logger.debug("mycli magic called: %r", line)
|
||||
parsed = sql.parse.parse(line, {})
|
||||
# "get" was renamed to "set" in ipython-sql:
|
||||
# https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
|
||||
|
@ -32,17 +33,17 @@ def mycli_line_magic(line):
|
|||
try:
|
||||
# A corresponding mycli object already exists
|
||||
mycli = conn._mycli
|
||||
_logger.debug('Reusing existing mycli')
|
||||
_logger.debug("Reusing existing mycli")
|
||||
except AttributeError:
|
||||
mycli = MyCli()
|
||||
u = conn.session.engine.url
|
||||
_logger.debug('New mycli: %r', str(u))
|
||||
_logger.debug("New mycli: %r", str(u))
|
||||
|
||||
mycli.connect(host=u.host, port=u.port, passwd=u.password, database=u.database, user=u.username, init_command=None)
|
||||
conn._mycli = mycli
|
||||
|
||||
# For convenience, print the connection alias
|
||||
print('Connected: {}'.format(conn.name))
|
||||
print("Connected: {}".format(conn.name))
|
||||
|
||||
try:
|
||||
mycli.run_cli()
|
||||
|
@ -54,9 +55,9 @@ def mycli_line_magic(line):
|
|||
|
||||
q = mycli.query_history[-1]
|
||||
if q.mutating:
|
||||
_logger.debug('Mutating query detected -- ignoring')
|
||||
_logger.debug("Mutating query detected -- ignoring")
|
||||
return
|
||||
|
||||
if q.successful:
|
||||
ipython = get_ipython()
|
||||
return ipython.run_cell_magic('sql', line, q.query)
|
||||
ipython = get_ipython() # noqa: F821
|
||||
return ipython.run_cell_magic("sql", line, q.query)
|
||||
|
|
971
mycli/main.py
971
mycli/main.py
File diff suppressed because it is too large
Load diff
|
@ -12,8 +12,7 @@ def suggest_type(full_text, text_before_cursor):
|
|||
A scope for a column category will be a list of tables.
|
||||
"""
|
||||
|
||||
word_before_cursor = last_word(text_before_cursor,
|
||||
include='many_punctuations')
|
||||
word_before_cursor = last_word(text_before_cursor, include="many_punctuations")
|
||||
|
||||
identifier = None
|
||||
|
||||
|
@ -25,12 +24,10 @@ def suggest_type(full_text, text_before_cursor):
|
|||
# partially typed string which renders the smart completion useless because
|
||||
# it will always return the list of keywords as completion.
|
||||
if word_before_cursor:
|
||||
if word_before_cursor.endswith(
|
||||
'(') or word_before_cursor.startswith('\\'):
|
||||
if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
else:
|
||||
parsed = sqlparse.parse(
|
||||
text_before_cursor[:-len(word_before_cursor)])
|
||||
parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])
|
||||
|
||||
# word_before_cursor may include a schema qualification, like
|
||||
# "schema_name.partial_name" or "schema_name.", so parse it
|
||||
|
@ -42,7 +39,7 @@ def suggest_type(full_text, text_before_cursor):
|
|||
else:
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
except (TypeError, AttributeError):
|
||||
return [{'type': 'keyword'}]
|
||||
return [{"type": "keyword"}]
|
||||
|
||||
if len(parsed) > 1:
|
||||
# Multiple statements being edited -- isolate the current one by
|
||||
|
@ -72,13 +69,12 @@ def suggest_type(full_text, text_before_cursor):
|
|||
# Be careful here because trivial whitespace is parsed as a statement,
|
||||
# but the statement won't have a first token
|
||||
tok1 = statement.token_first()
|
||||
if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')):
|
||||
if tok1 and (tok1.value == "source" or tok1.value.startswith("\\")):
|
||||
return suggest_special(text_before_cursor)
|
||||
|
||||
last_token = statement and statement.token_prev(len(statement.tokens))[1] or ''
|
||||
last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""
|
||||
|
||||
return suggest_based_on_last_token(last_token, text_before_cursor,
|
||||
full_text, identifier)
|
||||
return suggest_based_on_last_token(last_token, text_before_cursor, full_text, identifier)
|
||||
|
||||
|
||||
def suggest_special(text):
|
||||
|
@ -87,27 +83,27 @@ def suggest_special(text):
|
|||
|
||||
if cmd == text:
|
||||
# Trying to complete the special command itself
|
||||
return [{'type': 'special'}]
|
||||
return [{"type": "special"}]
|
||||
|
||||
if cmd in ('\\u', '\\r'):
|
||||
return [{'type': 'database'}]
|
||||
if cmd in ("\\u", "\\r"):
|
||||
return [{"type": "database"}]
|
||||
|
||||
if cmd in ('\\T'):
|
||||
return [{'type': 'table_format'}]
|
||||
if cmd in ("\\T"):
|
||||
return [{"type": "table_format"}]
|
||||
|
||||
if cmd in ['\\f', '\\fs', '\\fd']:
|
||||
return [{'type': 'favoritequery'}]
|
||||
if cmd in ["\\f", "\\fs", "\\fd"]:
|
||||
return [{"type": "favoritequery"}]
|
||||
|
||||
if cmd in ['\\dt', '\\dt+']:
|
||||
if cmd in ["\\dt", "\\dt+"]:
|
||||
return [
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'},
|
||||
{"type": "table", "schema": []},
|
||||
{"type": "view", "schema": []},
|
||||
{"type": "schema"},
|
||||
]
|
||||
elif cmd in ['\\.', 'source']:
|
||||
return[{'type': 'file_name'}]
|
||||
elif cmd in ["\\.", "source"]:
|
||||
return [{"type": "file_name"}]
|
||||
|
||||
return [{'type': 'keyword'}, {'type': 'special'}]
|
||||
return [{"type": "keyword"}, {"type": "special"}]
|
||||
|
||||
|
||||
def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
|
||||
|
@ -127,20 +123,19 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier
|
|||
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
|
||||
# suggestions in complicated where clauses correctly
|
||||
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
|
||||
return suggest_based_on_last_token(prev_keyword, text_before_cursor,
|
||||
full_text, identifier)
|
||||
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
|
||||
elif token is None:
|
||||
return [{'type': 'keyword'}]
|
||||
return [{"type": "keyword"}]
|
||||
else:
|
||||
token_v = token.value.lower()
|
||||
|
||||
is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']])
|
||||
is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]]) # noqa: E731
|
||||
|
||||
if not token:
|
||||
return [{'type': 'keyword'}, {'type': 'special'}]
|
||||
return [{"type": "keyword"}, {"type": "special"}]
|
||||
elif token_v == "*":
|
||||
return [{'type': 'keyword'}]
|
||||
elif token_v.endswith('('):
|
||||
return [{"type": "keyword"}]
|
||||
elif token_v.endswith("("):
|
||||
p = sqlparse.parse(text_before_cursor)[0]
|
||||
|
||||
if p.tokens and isinstance(p.tokens[-1], Where):
|
||||
|
@ -155,8 +150,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier
|
|||
# Suggest columns/functions AND keywords. (If we wanted to be
|
||||
# really fancy, we could suggest only array-typed columns)
|
||||
|
||||
column_suggestions = suggest_based_on_last_token('where',
|
||||
text_before_cursor, full_text, identifier)
|
||||
column_suggestions = suggest_based_on_last_token("where", text_before_cursor, full_text, identifier)
|
||||
|
||||
# Check for a subquery expression (cases 3 & 4)
|
||||
where = p.tokens[-1]
|
||||
|
@ -167,130 +161,133 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier
|
|||
prev_tok = prev_tok.tokens[-1]
|
||||
|
||||
prev_tok = prev_tok.value.lower()
|
||||
if prev_tok == 'exists':
|
||||
return [{'type': 'keyword'}]
|
||||
if prev_tok == "exists":
|
||||
return [{"type": "keyword"}]
|
||||
else:
|
||||
return column_suggestions
|
||||
|
||||
# Get the token before the parens
|
||||
idx, prev_tok = p.token_prev(len(p.tokens) - 1)
|
||||
if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
|
||||
if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
|
||||
# tbl1 INNER JOIN tbl2 USING (col1, col2)
|
||||
tables = extract_tables(full_text)
|
||||
|
||||
# suggest columns that are present in more than one table
|
||||
return [{'type': 'column', 'tables': tables, 'drop_unique': True}]
|
||||
elif p.token_first().value.lower() == 'select':
|
||||
return [{"type": "column", "tables": tables, "drop_unique": True}]
|
||||
elif p.token_first().value.lower() == "select":
|
||||
# If the lparen is preceeded by a space chances are we're about to
|
||||
# do a sub-select.
|
||||
if last_word(text_before_cursor,
|
||||
'all_punctuations').startswith('('):
|
||||
return [{'type': 'keyword'}]
|
||||
elif p.token_first().value.lower() == 'show':
|
||||
return [{'type': 'show'}]
|
||||
if last_word(text_before_cursor, "all_punctuations").startswith("("):
|
||||
return [{"type": "keyword"}]
|
||||
elif p.token_first().value.lower() == "show":
|
||||
return [{"type": "show"}]
|
||||
|
||||
# We're probably in a function argument list
|
||||
return [{'type': 'column', 'tables': extract_tables(full_text)}]
|
||||
elif token_v in ('set', 'order by', 'distinct'):
|
||||
return [{'type': 'column', 'tables': extract_tables(full_text)}]
|
||||
elif token_v == 'as':
|
||||
return [{"type": "column", "tables": extract_tables(full_text)}]
|
||||
elif token_v in ("set", "order by", "distinct"):
|
||||
return [{"type": "column", "tables": extract_tables(full_text)}]
|
||||
elif token_v == "as":
|
||||
# Don't suggest anything for an alias
|
||||
return []
|
||||
elif token_v in ('show'):
|
||||
return [{'type': 'show'}]
|
||||
elif token_v in ('to',):
|
||||
elif token_v in ("show"):
|
||||
return [{"type": "show"}]
|
||||
elif token_v in ("to",):
|
||||
p = sqlparse.parse(text_before_cursor)[0]
|
||||
if p.token_first().value.lower() == 'change':
|
||||
return [{'type': 'change'}]
|
||||
if p.token_first().value.lower() == "change":
|
||||
return [{"type": "change"}]
|
||||
else:
|
||||
return [{'type': 'user'}]
|
||||
elif token_v in ('user', 'for'):
|
||||
return [{'type': 'user'}]
|
||||
elif token_v in ('select', 'where', 'having'):
|
||||
return [{"type": "user"}]
|
||||
elif token_v in ("user", "for"):
|
||||
return [{"type": "user"}]
|
||||
elif token_v in ("select", "where", "having"):
|
||||
# Check for a table alias or schema qualification
|
||||
parent = (identifier and identifier.get_parent_name()) or []
|
||||
|
||||
tables = extract_tables(full_text)
|
||||
if parent:
|
||||
tables = [t for t in tables if identifies(parent, *t)]
|
||||
return [{'type': 'column', 'tables': tables},
|
||||
{'type': 'table', 'schema': parent},
|
||||
{'type': 'view', 'schema': parent},
|
||||
{'type': 'function', 'schema': parent}]
|
||||
return [
|
||||
{"type": "column", "tables": tables},
|
||||
{"type": "table", "schema": parent},
|
||||
{"type": "view", "schema": parent},
|
||||
{"type": "function", "schema": parent},
|
||||
]
|
||||
else:
|
||||
aliases = [alias or table for (schema, table, alias) in tables]
|
||||
return [{'type': 'column', 'tables': tables},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'alias', 'aliases': aliases},
|
||||
{'type': 'keyword'}]
|
||||
elif (token_v.endswith('join') and token.is_keyword) or (token_v in
|
||||
('copy', 'from', 'update', 'into', 'describe', 'truncate',
|
||||
'desc', 'explain')):
|
||||
return [
|
||||
{"type": "column", "tables": tables},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "alias", "aliases": aliases},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
elif (token_v.endswith("join") and token.is_keyword) or (
|
||||
token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
|
||||
):
|
||||
schema = (identifier and identifier.get_parent_name()) or []
|
||||
|
||||
# Suggest tables from either the currently-selected schema or the
|
||||
# public schema if no schema has been specified
|
||||
suggest = [{'type': 'table', 'schema': schema}]
|
||||
suggest = [{"type": "table", "schema": schema}]
|
||||
|
||||
if not schema:
|
||||
# Suggest schemas
|
||||
suggest.insert(0, {'type': 'schema'})
|
||||
suggest.insert(0, {"type": "schema"})
|
||||
|
||||
# Only tables can be TRUNCATED, otherwise suggest views
|
||||
if token_v != 'truncate':
|
||||
suggest.append({'type': 'view', 'schema': schema})
|
||||
if token_v != "truncate":
|
||||
suggest.append({"type": "view", "schema": schema})
|
||||
|
||||
return suggest
|
||||
|
||||
elif token_v in ('table', 'view', 'function'):
|
||||
elif token_v in ("table", "view", "function"):
|
||||
# E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
|
||||
rel_type = token_v
|
||||
schema = (identifier and identifier.get_parent_name()) or []
|
||||
if schema:
|
||||
return [{'type': rel_type, 'schema': schema}]
|
||||
return [{"type": rel_type, "schema": schema}]
|
||||
else:
|
||||
return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
|
||||
elif token_v == 'on':
|
||||
return [{"type": "schema"}, {"type": rel_type, "schema": []}]
|
||||
elif token_v == "on":
|
||||
tables = extract_tables(full_text) # [(schema, table, alias), ...]
|
||||
parent = (identifier and identifier.get_parent_name()) or []
|
||||
if parent:
|
||||
# "ON parent.<suggestion>"
|
||||
# parent can be either a schema name or table alias
|
||||
tables = [t for t in tables if identifies(parent, *t)]
|
||||
return [{'type': 'column', 'tables': tables},
|
||||
{'type': 'table', 'schema': parent},
|
||||
{'type': 'view', 'schema': parent},
|
||||
{'type': 'function', 'schema': parent}]
|
||||
return [
|
||||
{"type": "column", "tables": tables},
|
||||
{"type": "table", "schema": parent},
|
||||
{"type": "view", "schema": parent},
|
||||
{"type": "function", "schema": parent},
|
||||
]
|
||||
else:
|
||||
# ON <suggestion>
|
||||
# Use table alias if there is one, otherwise the table name
|
||||
aliases = [alias or table for (schema, table, alias) in tables]
|
||||
suggest = [{'type': 'alias', 'aliases': aliases}]
|
||||
suggest = [{"type": "alias", "aliases": aliases}]
|
||||
|
||||
# The lists of 'aliases' could be empty if we're trying to complete
|
||||
# a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
|
||||
# In that case we just suggest all tables.
|
||||
if not aliases:
|
||||
suggest.append({'type': 'table', 'schema': parent})
|
||||
suggest.append({"type": "table", "schema": parent})
|
||||
return suggest
|
||||
|
||||
elif token_v in ('use', 'database', 'template', 'connect'):
|
||||
elif token_v in ("use", "database", "template", "connect"):
|
||||
# "\c <db", "use <db>", "DROP DATABASE <db>",
|
||||
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
|
||||
return [{'type': 'database'}]
|
||||
elif token_v == 'tableformat':
|
||||
return [{'type': 'table_format'}]
|
||||
elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']:
|
||||
return [{"type": "database"}]
|
||||
elif token_v == "tableformat":
|
||||
return [{"type": "table_format"}]
|
||||
elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
|
||||
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
|
||||
if prev_keyword:
|
||||
return suggest_based_on_last_token(
|
||||
prev_keyword, text_before_cursor, full_text, identifier)
|
||||
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
return [{'type': 'keyword'}]
|
||||
return [{"type": "keyword"}]
|
||||
|
||||
|
||||
def identifies(id, schema, table, alias):
|
||||
return id == alias or id == table or (
|
||||
schema and (id == schema + '.' + table))
|
||||
return id == alias or id == table or (schema and (id == schema + "." + table))
|
||||
|
|
|
@ -38,7 +38,7 @@ def complete_path(curr_dir, last_dir):
|
|||
"""
|
||||
if not last_dir or curr_dir.startswith(last_dir):
|
||||
return curr_dir
|
||||
elif last_dir == '~':
|
||||
elif last_dir == "~":
|
||||
return os.path.join(last_dir, curr_dir)
|
||||
|
||||
|
||||
|
@ -51,7 +51,7 @@ def parse_path(root_dir):
|
|||
:return: tuple of (string, string, int)
|
||||
|
||||
"""
|
||||
base_dir, last_dir, position = '', '', 0
|
||||
base_dir, last_dir, position = "", "", 0
|
||||
if root_dir:
|
||||
base_dir, last_dir = os.path.split(root_dir)
|
||||
position = -len(last_dir) if last_dir else 0
|
||||
|
@ -69,9 +69,9 @@ def suggest_path(root_dir):
|
|||
|
||||
"""
|
||||
if not root_dir:
|
||||
return [os.path.abspath(os.sep), '~', os.curdir, os.pardir]
|
||||
return [os.path.abspath(os.sep), "~", os.curdir, os.pardir]
|
||||
|
||||
if '~' in root_dir:
|
||||
if "~" in root_dir:
|
||||
root_dir = os.path.expanduser(root_dir)
|
||||
|
||||
if not os.path.exists(root_dir):
|
||||
|
@ -100,7 +100,7 @@ def guess_socket_location():
|
|||
for r, dirs, files in os.walk(directory, topdown=True):
|
||||
for filename in files:
|
||||
name, ext = os.path.splitext(filename)
|
||||
if name.startswith("mysql") and name != "mysqlx" and ext in ('.socket', '.sock'):
|
||||
if name.startswith("mysql") and name != "mysqlx" and ext in (".socket", ".sock"):
|
||||
return os.path.join(r, filename)
|
||||
dirs[:] = [d for d in dirs if d.startswith("mysql")]
|
||||
return None
|
||||
|
|
|
@ -12,16 +12,19 @@ class Paramiko:
|
|||
def __getattr__(self, name):
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
print(dedent("""
|
||||
To enable certain SSH features you need to install paramiko:
|
||||
|
||||
print(
|
||||
dedent("""
|
||||
To enable certain SSH features you need to install paramiko and sshtunnel:
|
||||
|
||||
pip install paramiko
|
||||
pip install paramiko sshtunnel
|
||||
|
||||
It is required for the following configuration options:
|
||||
--list-ssh-config
|
||||
--ssh-config-host
|
||||
--ssh-host
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
|
|
@ -4,18 +4,18 @@ from sqlparse.sql import IdentifierList, Identifier, Function
|
|||
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||
|
||||
cleanup_regex = {
|
||||
# This matches only alphanumerics and underscores.
|
||||
'alphanum_underscore': re.compile(r'(\w+)$'),
|
||||
# This matches everything except spaces, parens, colon, and comma
|
||||
'many_punctuations': re.compile(r'([^():,\s]+)$'),
|
||||
# This matches everything except spaces, parens, colon, comma, and period
|
||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||
# This matches everything except a space.
|
||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||
# This matches only alphanumerics and underscores.
|
||||
"alphanum_underscore": re.compile(r"(\w+)$"),
|
||||
# This matches everything except spaces, parens, colon, and comma
|
||||
"many_punctuations": re.compile(r"([^():,\s]+)$"),
|
||||
# This matches everything except spaces, parens, colon, comma, and period
|
||||
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
|
||||
# This matches everything except a space.
|
||||
"all_punctuations": re.compile(r"([^\s]+)$"),
|
||||
}
|
||||
|
||||
|
||||
def last_word(text, include='alphanum_underscore'):
|
||||
def last_word(text, include="alphanum_underscore"):
|
||||
r"""
|
||||
Find the last word in a sentence.
|
||||
|
||||
|
@ -47,18 +47,18 @@ def last_word(text, include='alphanum_underscore'):
|
|||
'def'
|
||||
"""
|
||||
|
||||
if not text: # Empty string
|
||||
return ''
|
||||
if not text: # Empty string
|
||||
return ""
|
||||
|
||||
if text[-1].isspace():
|
||||
return ''
|
||||
return ""
|
||||
else:
|
||||
regex = cleanup_regex[include]
|
||||
matches = regex.search(text)
|
||||
if matches:
|
||||
return matches.group(0)
|
||||
else:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
|
||||
# This code is borrowed from sqlparse example script.
|
||||
|
@ -67,11 +67,11 @@ def is_subselect(parsed):
|
|||
if not parsed.is_group:
|
||||
return False
|
||||
for item in parsed.tokens:
|
||||
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
|
||||
'UPDATE', 'CREATE', 'DELETE'):
|
||||
if item.ttype is DML and item.value.upper() in ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def extract_from_part(parsed, stop_at_punctuation=True):
|
||||
tbl_prefix_seen = False
|
||||
for item in parsed.tokens:
|
||||
|
@ -85,7 +85,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
|||
# "ON" is a keyword and will trigger the next elif condition.
|
||||
# So instead of stooping the loop when finding an "ON" skip it
|
||||
# eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi'
|
||||
elif item.ttype is Keyword and item.value.upper() == 'ON':
|
||||
elif item.ttype is Keyword and item.value.upper() == "ON":
|
||||
tbl_prefix_seen = False
|
||||
continue
|
||||
# An incomplete nested select won't be recognized correctly as a
|
||||
|
@ -96,24 +96,28 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
|||
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
||||
# condition. So we need to ignore the keyword JOIN and its variants
|
||||
# INNER JOIN, FULL OUTER JOIN, etc.
|
||||
elif item.ttype is Keyword and (
|
||||
not item.value.upper() == 'FROM') and (
|
||||
not item.value.upper().endswith('JOIN')):
|
||||
elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")):
|
||||
return
|
||||
else:
|
||||
yield item
|
||||
elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
|
||||
item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
|
||||
elif (item.ttype is Keyword or item.ttype is Keyword.DML) and item.value.upper() in (
|
||||
"COPY",
|
||||
"FROM",
|
||||
"INTO",
|
||||
"UPDATE",
|
||||
"TABLE",
|
||||
"JOIN",
|
||||
):
|
||||
tbl_prefix_seen = True
|
||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
||||
# So this check here is necessary.
|
||||
elif isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
if (identifier.ttype is Keyword and
|
||||
identifier.value.upper() == 'FROM'):
|
||||
if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
|
||||
tbl_prefix_seen = True
|
||||
break
|
||||
|
||||
|
||||
def extract_table_identifiers(token_stream):
|
||||
"""yields tuples of (schema_name, table_name, table_alias)"""
|
||||
|
||||
|
@ -141,6 +145,7 @@ def extract_table_identifiers(token_stream):
|
|||
elif isinstance(item, Function):
|
||||
yield (None, item.get_name(), item.get_name())
|
||||
|
||||
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
|
||||
def extract_tables(sql):
|
||||
"""Extract the table names from an SQL statement.
|
||||
|
@ -156,27 +161,27 @@ def extract_tables(sql):
|
|||
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
||||
# abc is the table name, but if we don't stop at the first lparen, then
|
||||
# we'll identify abc, col1 and col2 as table names.
|
||||
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
|
||||
insert_stmt = parsed[0].token_first().value.lower() == "insert"
|
||||
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
||||
return list(extract_table_identifiers(stream))
|
||||
|
||||
|
||||
def find_prev_keyword(sql):
|
||||
""" Find the last sql keyword in an SQL statement
|
||||
"""Find the last sql keyword in an SQL statement
|
||||
|
||||
Returns the value of the last keyword, and the text of the query with
|
||||
everything after the last keyword stripped
|
||||
"""
|
||||
if not sql.strip():
|
||||
return None, ''
|
||||
return None, ""
|
||||
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
flattened = list(parsed.flatten())
|
||||
|
||||
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
|
||||
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
|
||||
|
||||
for t in reversed(flattened):
|
||||
if t.value == '(' or (t.is_keyword and (
|
||||
t.value.upper() not in logical_operators)):
|
||||
if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)):
|
||||
# Find the location of token t in the original parsed statement
|
||||
# We can't use parsed.token_index(t) because t may be a child token
|
||||
# inside a TokenList, in which case token_index thows an error
|
||||
|
@ -189,10 +194,10 @@ def find_prev_keyword(sql):
|
|||
# Combine the string values of all tokens in the original list
|
||||
# up to and including the target keyword token t, to produce a
|
||||
# query string with everything after the keyword token removed
|
||||
text = ''.join(tok.value for tok in flattened[:idx+1])
|
||||
text = "".join(tok.value for tok in flattened[: idx + 1])
|
||||
return t, text
|
||||
|
||||
return None, ''
|
||||
return None, ""
|
||||
|
||||
|
||||
def query_starts_with(query, prefixes):
|
||||
|
@ -212,31 +217,25 @@ def queries_start_with(queries, prefixes):
|
|||
|
||||
def query_has_where_clause(query):
|
||||
"""Check if the query contains a where-clause."""
|
||||
return any(
|
||||
isinstance(token, sqlparse.sql.Where)
|
||||
for token_list in sqlparse.parse(query)
|
||||
for token in token_list
|
||||
)
|
||||
return any(isinstance(token, sqlparse.sql.Where) for token_list in sqlparse.parse(query) for token in token_list)
|
||||
|
||||
|
||||
def is_destructive(queries):
|
||||
"""Returns if any of the queries in *queries* is destructive."""
|
||||
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
|
||||
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
|
||||
for query in sqlparse.split(queries):
|
||||
if query:
|
||||
if query_starts_with(query, keywords) is True:
|
||||
return True
|
||||
elif query_starts_with(
|
||||
query, ['update']
|
||||
) is True and not query_has_where_clause(query):
|
||||
elif query_starts_with(query, ["update"]) is True and not query_has_where_clause(query):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sql = 'select * from (select t. from tabl t'
|
||||
print (extract_tables(sql))
|
||||
if __name__ == "__main__":
|
||||
sql = "select * from (select t. from tabl t"
|
||||
print(extract_tables(sql))
|
||||
|
||||
|
||||
def is_dropping_database(queries, dbname):
|
||||
|
@ -258,9 +257,7 @@ def is_dropping_database(queries, dbname):
|
|||
"database",
|
||||
"schema",
|
||||
):
|
||||
database_token = next(
|
||||
(t for t in query.tokens if isinstance(t, Identifier)), None
|
||||
)
|
||||
database_token = next((t for t in query.tokens if isinstance(t, Identifier)), None)
|
||||
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
|
||||
result = keywords[0].normalized == "DROP"
|
||||
return result
|
||||
|
|
|
@ -4,20 +4,20 @@ from .parseutils import is_destructive
|
|||
|
||||
|
||||
class ConfirmBoolParamType(click.ParamType):
|
||||
name = 'confirmation'
|
||||
name = "confirmation"
|
||||
|
||||
def convert(self, value, param, ctx):
|
||||
if isinstance(value, bool):
|
||||
return bool(value)
|
||||
value = value.lower()
|
||||
if value in ('yes', 'y'):
|
||||
if value in ("yes", "y"):
|
||||
return True
|
||||
elif value in ('no', 'n'):
|
||||
elif value in ("no", "n"):
|
||||
return False
|
||||
self.fail('%s is not a valid boolean' % value, param, ctx)
|
||||
self.fail("%s is not a valid boolean" % value, param, ctx)
|
||||
|
||||
def __repr__(self):
|
||||
return 'BOOL'
|
||||
return "BOOL"
|
||||
|
||||
|
||||
BOOLEAN_TYPE = ConfirmBoolParamType()
|
||||
|
@ -32,8 +32,7 @@ def confirm_destructive_query(queries):
|
|||
* False if the query is destructive and the user doesn't want to proceed.
|
||||
|
||||
"""
|
||||
prompt_text = ("You're about to run a destructive command.\n"
|
||||
"Do you want to proceed? (y/n)")
|
||||
prompt_text = "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
|
||||
if is_destructive(queries) and sys.stdin.isatty():
|
||||
return prompt(prompt_text, type=BOOLEAN_TYPE)
|
||||
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
__all__ = []
|
||||
|
||||
|
||||
def export(defn):
|
||||
"""Decorator to explicitly mark functions that are exposed in a lib."""
|
||||
globals()[defn.__name__] = defn
|
||||
__all__.append(defn.__name__)
|
||||
return defn
|
||||
|
||||
from . import dbcommands
|
||||
from . import iocommands
|
||||
|
||||
from . import dbcommands # noqa: E402 F401
|
||||
from . import iocommands # noqa: E402 F401
|
||||
|
|
|
@ -10,24 +10,23 @@ from pymysql import ProgrammingError
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.',
|
||||
arg_type=PARSED_QUERY, case_sensitive=True)
|
||||
@special_command("\\dt", "\\dt[+] [table]", "List or describe tables.", arg_type=PARSED_QUERY, case_sensitive=True)
|
||||
def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
|
||||
if arg:
|
||||
query = 'SHOW FIELDS FROM {0}'.format(arg)
|
||||
query = "SHOW FIELDS FROM {0}".format(arg)
|
||||
else:
|
||||
query = 'SHOW TABLES'
|
||||
query = "SHOW TABLES"
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
tables = cur.fetchall()
|
||||
status = ''
|
||||
status = ""
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
else:
|
||||
return [(None, None, None, '')]
|
||||
return [(None, None, None, "")]
|
||||
|
||||
if verbose and arg:
|
||||
query = 'SHOW CREATE TABLE {0}'.format(arg)
|
||||
query = "SHOW CREATE TABLE {0}".format(arg)
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
status = cur.fetchone()[1]
|
||||
|
@ -35,128 +34,121 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
|
|||
return [(None, tables, headers, status)]
|
||||
|
||||
|
||||
@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True)
|
||||
@special_command("\\l", "\\l", "List databases.", arg_type=RAW_QUERY, case_sensitive=True)
|
||||
def list_databases(cur, **_):
|
||||
query = 'SHOW DATABASES'
|
||||
query = "SHOW DATABASES"
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
return [(None, cur, headers, '')]
|
||||
return [(None, cur, headers, "")]
|
||||
else:
|
||||
return [(None, None, None, '')]
|
||||
return [(None, None, None, "")]
|
||||
|
||||
|
||||
@special_command('status', '\\s', 'Get status information from the server.',
|
||||
arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True)
|
||||
@special_command("status", "\\s", "Get status information from the server.", arg_type=RAW_QUERY, aliases=("\\s",), case_sensitive=True)
|
||||
def status(cur, **_):
|
||||
query = 'SHOW GLOBAL STATUS;'
|
||||
query = "SHOW GLOBAL STATUS;"
|
||||
log.debug(query)
|
||||
try:
|
||||
cur.execute(query)
|
||||
except ProgrammingError:
|
||||
# Fallback in case query fail, as it does with Mysql 4
|
||||
query = 'SHOW STATUS;'
|
||||
query = "SHOW STATUS;"
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
status = dict(cur.fetchall())
|
||||
|
||||
query = 'SHOW GLOBAL VARIABLES;'
|
||||
query = "SHOW GLOBAL VARIABLES;"
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
variables = dict(cur.fetchall())
|
||||
|
||||
# prepare in case keys are bytes, as with Python 3 and Mysql 4
|
||||
if (isinstance(list(variables)[0], bytes) and
|
||||
isinstance(list(status)[0], bytes)):
|
||||
variables = {k.decode('utf-8'): v.decode('utf-8') for k, v
|
||||
in variables.items()}
|
||||
status = {k.decode('utf-8'): v.decode('utf-8') for k, v
|
||||
in status.items()}
|
||||
if isinstance(list(variables)[0], bytes) and isinstance(list(status)[0], bytes):
|
||||
variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in variables.items()}
|
||||
status = {k.decode("utf-8"): v.decode("utf-8") for k, v in status.items()}
|
||||
|
||||
# Create output buffers.
|
||||
title = []
|
||||
output = []
|
||||
footer = []
|
||||
|
||||
title.append('--------------')
|
||||
title.append("--------------")
|
||||
|
||||
# Output the mycli client information.
|
||||
implementation = platform.python_implementation()
|
||||
version = platform.python_version()
|
||||
client_info = []
|
||||
client_info.append('mycli {0},'.format(__version__))
|
||||
client_info.append('running on {0} {1}'.format(implementation, version))
|
||||
title.append(' '.join(client_info) + '\n')
|
||||
client_info.append("mycli {0},".format(__version__))
|
||||
client_info.append("running on {0} {1}".format(implementation, version))
|
||||
title.append(" ".join(client_info) + "\n")
|
||||
|
||||
# Build the output that will be displayed as a table.
|
||||
output.append(('Connection id:', cur.connection.thread_id()))
|
||||
output.append(("Connection id:", cur.connection.thread_id()))
|
||||
|
||||
query = 'SELECT DATABASE(), USER();'
|
||||
query = "SELECT DATABASE(), USER();"
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
db, user = cur.fetchone()
|
||||
if db is None:
|
||||
db = ''
|
||||
db = ""
|
||||
|
||||
output.append(('Current database:', db))
|
||||
output.append(('Current user:', user))
|
||||
output.append(("Current database:", db))
|
||||
output.append(("Current user:", user))
|
||||
|
||||
if iocommands.is_pager_enabled():
|
||||
if 'PAGER' in os.environ:
|
||||
pager = os.environ['PAGER']
|
||||
if "PAGER" in os.environ:
|
||||
pager = os.environ["PAGER"]
|
||||
else:
|
||||
pager = 'System default'
|
||||
pager = "System default"
|
||||
else:
|
||||
pager = 'stdout'
|
||||
output.append(('Current pager:', pager))
|
||||
pager = "stdout"
|
||||
output.append(("Current pager:", pager))
|
||||
|
||||
output.append(('Server version:', '{0} {1}'.format(
|
||||
variables['version'], variables['version_comment'])))
|
||||
output.append(('Protocol version:', variables['protocol_version']))
|
||||
output.append(("Server version:", "{0} {1}".format(variables["version"], variables["version_comment"])))
|
||||
output.append(("Protocol version:", variables["protocol_version"]))
|
||||
|
||||
if 'unix' in cur.connection.host_info.lower():
|
||||
if "unix" in cur.connection.host_info.lower():
|
||||
host_info = cur.connection.host_info
|
||||
else:
|
||||
host_info = '{0} via TCP/IP'.format(cur.connection.host)
|
||||
host_info = "{0} via TCP/IP".format(cur.connection.host)
|
||||
|
||||
output.append(('Connection:', host_info))
|
||||
output.append(("Connection:", host_info))
|
||||
|
||||
query = ('SELECT @@character_set_server, @@character_set_database, '
|
||||
'@@character_set_client, @@character_set_connection LIMIT 1;')
|
||||
query = "SELECT @@character_set_server, @@character_set_database, " "@@character_set_client, @@character_set_connection LIMIT 1;"
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
charset = cur.fetchone()
|
||||
output.append(('Server characterset:', charset[0]))
|
||||
output.append(('Db characterset:', charset[1]))
|
||||
output.append(('Client characterset:', charset[2]))
|
||||
output.append(('Conn. characterset:', charset[3]))
|
||||
output.append(("Server characterset:", charset[0]))
|
||||
output.append(("Db characterset:", charset[1]))
|
||||
output.append(("Client characterset:", charset[2]))
|
||||
output.append(("Conn. characterset:", charset[3]))
|
||||
|
||||
if 'TCP/IP' in host_info:
|
||||
output.append(('TCP port:', cur.connection.port))
|
||||
if "TCP/IP" in host_info:
|
||||
output.append(("TCP port:", cur.connection.port))
|
||||
else:
|
||||
output.append(('UNIX socket:', variables['socket']))
|
||||
output.append(("UNIX socket:", variables["socket"]))
|
||||
|
||||
if 'Uptime' in status:
|
||||
output.append(('Uptime:', format_uptime(status['Uptime'])))
|
||||
if "Uptime" in status:
|
||||
output.append(("Uptime:", format_uptime(status["Uptime"])))
|
||||
|
||||
if 'Threads_connected' in status:
|
||||
if "Threads_connected" in status:
|
||||
# Print the current server statistics.
|
||||
stats = []
|
||||
stats.append('Connections: {0}'.format(status['Threads_connected']))
|
||||
if 'Queries' in status:
|
||||
stats.append('Queries: {0}'.format(status['Queries']))
|
||||
stats.append('Slow queries: {0}'.format(status['Slow_queries']))
|
||||
stats.append('Opens: {0}'.format(status['Opened_tables']))
|
||||
if 'Flush_commands' in status:
|
||||
stats.append('Flush tables: {0}'.format(status['Flush_commands']))
|
||||
stats.append('Open tables: {0}'.format(status['Open_tables']))
|
||||
if 'Queries' in status:
|
||||
queries_per_second = int(status['Queries']) / int(status['Uptime'])
|
||||
stats.append('Queries per second avg: {:.3f}'.format(
|
||||
queries_per_second))
|
||||
stats = ' '.join(stats)
|
||||
footer.append('\n' + stats)
|
||||
stats.append("Connections: {0}".format(status["Threads_connected"]))
|
||||
if "Queries" in status:
|
||||
stats.append("Queries: {0}".format(status["Queries"]))
|
||||
stats.append("Slow queries: {0}".format(status["Slow_queries"]))
|
||||
stats.append("Opens: {0}".format(status["Opened_tables"]))
|
||||
if "Flush_commands" in status:
|
||||
stats.append("Flush tables: {0}".format(status["Flush_commands"]))
|
||||
stats.append("Open tables: {0}".format(status["Open_tables"]))
|
||||
if "Queries" in status:
|
||||
queries_per_second = int(status["Queries"]) / int(status["Uptime"])
|
||||
stats.append("Queries per second avg: {:.3f}".format(queries_per_second))
|
||||
stats = " ".join(stats)
|
||||
footer.append("\n" + stats)
|
||||
|
||||
footer.append('--------------')
|
||||
return [('\n'.join(title), output, '', '\n'.join(footer))]
|
||||
footer.append("--------------")
|
||||
return [("\n".join(title), output, "", "\n".join(footer))]
|
||||
|
|
|
@ -4,7 +4,7 @@ import sqlparse
|
|||
|
||||
class DelimiterCommand(object):
|
||||
def __init__(self):
|
||||
self._delimiter = ';'
|
||||
self._delimiter = ";"
|
||||
|
||||
def _split(self, sql):
|
||||
"""Temporary workaround until sqlparse.split() learns about custom
|
||||
|
@ -12,22 +12,19 @@ class DelimiterCommand(object):
|
|||
|
||||
placeholder = "\ufffc" # unicode object replacement character
|
||||
|
||||
if self._delimiter == ';':
|
||||
if self._delimiter == ";":
|
||||
return sqlparse.split(sql)
|
||||
|
||||
# We must find a string that original sql does not contain.
|
||||
# Most likely, our placeholder is enough, but if not, keep looking
|
||||
while placeholder in sql:
|
||||
placeholder += placeholder[0]
|
||||
sql = sql.replace(';', placeholder)
|
||||
sql = sql.replace(self._delimiter, ';')
|
||||
sql = sql.replace(";", placeholder)
|
||||
sql = sql.replace(self._delimiter, ";")
|
||||
|
||||
split = sqlparse.split(sql)
|
||||
|
||||
return [
|
||||
stmt.replace(';', self._delimiter).replace(placeholder, ';')
|
||||
for stmt in split
|
||||
]
|
||||
return [stmt.replace(";", self._delimiter).replace(placeholder, ";") for stmt in split]
|
||||
|
||||
def queries_iter(self, input):
|
||||
"""Iterate over queries in the input string."""
|
||||
|
@ -49,7 +46,7 @@ class DelimiterCommand(object):
|
|||
# re-split everything, and if we previously stripped
|
||||
# the delimiter, append it to the end
|
||||
if self._delimiter != delimiter:
|
||||
combined_statement = ' '.join([sql] + queries)
|
||||
combined_statement = " ".join([sql] + queries)
|
||||
if trailing_delimiter:
|
||||
combined_statement += delimiter
|
||||
queries = self._split(combined_statement)[1:]
|
||||
|
@ -63,13 +60,13 @@ class DelimiterCommand(object):
|
|||
word of it.
|
||||
|
||||
"""
|
||||
match = arg and re.search(r'[^\s]+', arg)
|
||||
match = arg and re.search(r"[^\s]+", arg)
|
||||
if not match:
|
||||
message = 'Missing required argument, delimiter'
|
||||
message = "Missing required argument, delimiter"
|
||||
return [(None, None, None, message)]
|
||||
|
||||
delimiter = match.group()
|
||||
if delimiter.lower() == 'delimiter':
|
||||
if delimiter.lower() == "delimiter":
|
||||
return [(None, None, None, 'Invalid delimiter "delimiter"')]
|
||||
|
||||
self._delimiter = delimiter
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
class FavoriteQueries(object):
|
||||
section_name = "favorite_queries"
|
||||
|
||||
section_name = 'favorite_queries'
|
||||
|
||||
usage = '''
|
||||
usage = """
|
||||
Favorite Queries are a way to save frequently used queries
|
||||
with a short name.
|
||||
Examples:
|
||||
|
@ -29,7 +28,7 @@ Examples:
|
|||
# Delete a favorite query.
|
||||
> \\fd simple
|
||||
simple: Deleted
|
||||
'''
|
||||
"""
|
||||
|
||||
# Class-level variable, for convenience to use as a singleton.
|
||||
instance = None
|
||||
|
@ -48,7 +47,7 @@ Examples:
|
|||
return self.config.get(self.section_name, {}).get(name, None)
|
||||
|
||||
def save(self, name, query):
|
||||
self.config.encoding = 'utf-8'
|
||||
self.config.encoding = "utf-8"
|
||||
if self.section_name not in self.config:
|
||||
self.config[self.section_name] = {}
|
||||
self.config[self.section_name][name] = query
|
||||
|
@ -58,6 +57,6 @@ Examples:
|
|||
try:
|
||||
del self.config[self.section_name][name]
|
||||
except KeyError:
|
||||
return '%s: Not Found.' % name
|
||||
return "%s: Not Found." % name
|
||||
self.config.write()
|
||||
return '%s: Deleted' % name
|
||||
return "%s: Deleted" % name
|
||||
|
|
|
@ -34,6 +34,7 @@ def set_timing_enabled(val):
|
|||
global TIMING_ENABLED
|
||||
TIMING_ENABLED = val
|
||||
|
||||
|
||||
@export
|
||||
def set_pager_enabled(val):
|
||||
global PAGER_ENABLED
|
||||
|
@ -44,33 +45,35 @@ def set_pager_enabled(val):
|
|||
def is_pager_enabled():
|
||||
return PAGER_ENABLED
|
||||
|
||||
|
||||
@export
|
||||
@special_command('pager', '\\P [command]',
|
||||
'Set PAGER. Print the query results via PAGER.',
|
||||
arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
|
||||
@special_command(
|
||||
"pager", "\\P [command]", "Set PAGER. Print the query results via PAGER.", arg_type=PARSED_QUERY, aliases=("\\P",), case_sensitive=True
|
||||
)
|
||||
def set_pager(arg, **_):
|
||||
if arg:
|
||||
os.environ['PAGER'] = arg
|
||||
msg = 'PAGER set to %s.' % arg
|
||||
os.environ["PAGER"] = arg
|
||||
msg = "PAGER set to %s." % arg
|
||||
set_pager_enabled(True)
|
||||
else:
|
||||
if 'PAGER' in os.environ:
|
||||
msg = 'PAGER set to %s.' % os.environ['PAGER']
|
||||
if "PAGER" in os.environ:
|
||||
msg = "PAGER set to %s." % os.environ["PAGER"]
|
||||
else:
|
||||
# This uses click's default per echo_via_pager.
|
||||
msg = 'Pager enabled.'
|
||||
msg = "Pager enabled."
|
||||
set_pager_enabled(True)
|
||||
|
||||
return [(None, None, None, msg)]
|
||||
|
||||
|
||||
@export
|
||||
@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
|
||||
arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
|
||||
@special_command("nopager", "\\n", "Disable pager, print to stdout.", arg_type=NO_QUERY, aliases=("\\n",), case_sensitive=True)
|
||||
def disable_pager():
|
||||
set_pager_enabled(False)
|
||||
return [(None, None, None, 'Pager disabled.')]
|
||||
return [(None, None, None, "Pager disabled.")]
|
||||
|
||||
@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
|
||||
|
||||
@special_command("\\timing", "\\t", "Toggle timing of commands.", arg_type=NO_QUERY, aliases=("\\t",), case_sensitive=True)
|
||||
def toggle_timing():
|
||||
global TIMING_ENABLED
|
||||
TIMING_ENABLED = not TIMING_ENABLED
|
||||
|
@ -78,21 +81,26 @@ def toggle_timing():
|
|||
message += "on." if TIMING_ENABLED else "off."
|
||||
return [(None, None, None, message)]
|
||||
|
||||
|
||||
@export
|
||||
def is_timing_enabled():
|
||||
return TIMING_ENABLED
|
||||
|
||||
|
||||
@export
|
||||
def set_expanded_output(val):
|
||||
global use_expanded_output
|
||||
use_expanded_output = val
|
||||
|
||||
|
||||
@export
|
||||
def is_expanded_output():
|
||||
return use_expanded_output
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@export
|
||||
def editor_command(command):
|
||||
"""
|
||||
|
@ -101,12 +109,13 @@ def editor_command(command):
|
|||
"""
|
||||
# It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
|
||||
# for both conditions.
|
||||
return command.strip().endswith('\\e') or command.strip().startswith('\\e')
|
||||
return command.strip().endswith("\\e") or command.strip().startswith("\\e")
|
||||
|
||||
|
||||
@export
|
||||
def get_filename(sql):
|
||||
if sql.strip().startswith('\\e'):
|
||||
command, _, filename = sql.partition(' ')
|
||||
if sql.strip().startswith("\\e"):
|
||||
command, _, filename = sql.partition(" ")
|
||||
return filename.strip() or None
|
||||
|
||||
|
||||
|
@ -118,9 +127,9 @@ def get_editor_query(sql):
|
|||
# The reason we can't simply do .strip('\e') is that it strips characters,
|
||||
# not a substring. So it'll strip "e" in the end of the sql also!
|
||||
# Ex: "select * from style\e" -> "select * from styl".
|
||||
pattern = re.compile(r'(^\\e|\\e$)')
|
||||
pattern = re.compile(r"(^\\e|\\e$)")
|
||||
while pattern.search(sql):
|
||||
sql = pattern.sub('', sql)
|
||||
sql = pattern.sub("", sql)
|
||||
|
||||
return sql
|
||||
|
||||
|
@ -135,25 +144,24 @@ def open_external_editor(filename=None, sql=None):
|
|||
"""
|
||||
|
||||
message = None
|
||||
filename = filename.strip().split(' ', 1)[0] if filename else None
|
||||
filename = filename.strip().split(" ", 1)[0] if filename else None
|
||||
|
||||
sql = sql or ''
|
||||
MARKER = '# Type your query above this line.\n'
|
||||
sql = sql or ""
|
||||
MARKER = "# Type your query above this line.\n"
|
||||
|
||||
# Populate the editor buffer with the partial sql (if available) and a
|
||||
# placeholder comment.
|
||||
query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
|
||||
filename=filename, extension='.sql')
|
||||
query = click.edit("{sql}\n\n{marker}".format(sql=sql, marker=MARKER), filename=filename, extension=".sql")
|
||||
|
||||
if filename:
|
||||
try:
|
||||
with open(filename) as f:
|
||||
query = f.read()
|
||||
except IOError:
|
||||
message = 'Error reading file: %s.' % filename
|
||||
message = "Error reading file: %s." % filename
|
||||
|
||||
if query is not None:
|
||||
query = query.split(MARKER, 1)[0].rstrip('\n')
|
||||
query = query.split(MARKER, 1)[0].rstrip("\n")
|
||||
else:
|
||||
# Don't return None for the caller to deal with.
|
||||
# Empty string is ok.
|
||||
|
@ -171,7 +179,7 @@ def clip_command(command):
|
|||
"""
|
||||
# It is possible to have `\clip` or `SELECT * FROM \clip`. So we check
|
||||
# for both conditions.
|
||||
return command.strip().endswith('\\clip') or command.strip().startswith('\\clip')
|
||||
return command.strip().endswith("\\clip") or command.strip().startswith("\\clip")
|
||||
|
||||
|
||||
@export
|
||||
|
@ -181,9 +189,9 @@ def get_clip_query(sql):
|
|||
|
||||
# The reason we can't simply do .strip('\clip') is that it strips characters,
|
||||
# not a substring. So it'll strip "c" in the end of the sql also!
|
||||
pattern = re.compile(r'(^\\clip|\\clip$)')
|
||||
pattern = re.compile(r"(^\\clip|\\clip$)")
|
||||
while pattern.search(sql):
|
||||
sql = pattern.sub('', sql)
|
||||
sql = pattern.sub("", sql)
|
||||
|
||||
return sql
|
||||
|
||||
|
@ -192,26 +200,26 @@ def get_clip_query(sql):
|
|||
def copy_query_to_clipboard(sql=None):
|
||||
"""Send query to the clipboard."""
|
||||
|
||||
sql = sql or ''
|
||||
sql = sql or ""
|
||||
message = None
|
||||
|
||||
try:
|
||||
pyperclip.copy(u'{sql}'.format(sql=sql))
|
||||
pyperclip.copy("{sql}".format(sql=sql))
|
||||
except RuntimeError as e:
|
||||
message = 'Error clipping query: %s.' % e.strerror
|
||||
message = "Error clipping query: %s." % e.strerror
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
|
||||
@special_command("\\f", "\\f [name [args..]]", "List or execute favorite queries.", arg_type=PARSED_QUERY, case_sensitive=True)
|
||||
def execute_favorite_query(cur, arg, **_):
|
||||
"""Returns (title, rows, headers, status)"""
|
||||
if arg == '':
|
||||
if arg == "":
|
||||
for result in list_favorite_queries():
|
||||
yield result
|
||||
|
||||
"""Parse out favorite name and optional substitution parameters"""
|
||||
name, _, arg_str = arg.partition(' ')
|
||||
name, _, arg_str = arg.partition(" ")
|
||||
args = shlex.split(arg_str)
|
||||
|
||||
query = FavoriteQueries.instance.get(name)
|
||||
|
@ -224,8 +232,8 @@ def execute_favorite_query(cur, arg, **_):
|
|||
yield (None, None, None, arg_error)
|
||||
else:
|
||||
for sql in sqlparse.split(query):
|
||||
sql = sql.rstrip(';')
|
||||
title = '> %s' % (sql)
|
||||
sql = sql.rstrip(";")
|
||||
title = "> %s" % (sql)
|
||||
cur.execute(sql)
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
|
@ -233,60 +241,60 @@ def execute_favorite_query(cur, arg, **_):
|
|||
else:
|
||||
yield (title, None, None, None)
|
||||
|
||||
|
||||
def list_favorite_queries():
|
||||
"""List of all favorite queries.
|
||||
Returns (title, rows, headers, status)"""
|
||||
|
||||
headers = ["Name", "Query"]
|
||||
rows = [(r, FavoriteQueries.instance.get(r))
|
||||
for r in FavoriteQueries.instance.list()]
|
||||
rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()]
|
||||
|
||||
if not rows:
|
||||
status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
|
||||
status = "\nNo favorite queries found." + FavoriteQueries.instance.usage
|
||||
else:
|
||||
status = ''
|
||||
return [('', rows, headers, status)]
|
||||
status = ""
|
||||
return [("", rows, headers, status)]
|
||||
|
||||
|
||||
def subst_favorite_query_args(query, args):
|
||||
"""replace positional parameters ($1...$N) in query."""
|
||||
for idx, val in enumerate(args):
|
||||
subst_var = '$' + str(idx + 1)
|
||||
subst_var = "$" + str(idx + 1)
|
||||
if subst_var not in query:
|
||||
return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query]
|
||||
return [None, "query does not have substitution parameter " + subst_var + ":\n " + query]
|
||||
|
||||
query = query.replace(subst_var, val)
|
||||
|
||||
match = re.search(r'\$\d+', query)
|
||||
match = re.search(r"\$\d+", query)
|
||||
if match:
|
||||
return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query]
|
||||
return [None, "missing substitution for " + match.group(0) + " in query:\n " + query]
|
||||
|
||||
return [query, None]
|
||||
|
||||
@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
|
||||
|
||||
@special_command("\\fs", "\\fs name query", "Save a favorite query.")
|
||||
def save_favorite_query(arg, **_):
|
||||
"""Save a new favorite query.
|
||||
Returns (title, rows, headers, status)"""
|
||||
|
||||
usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
|
||||
usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
name, _, query = arg.partition(' ')
|
||||
name, _, query = arg.partition(" ")
|
||||
|
||||
# If either name or query is missing then print the usage and complain.
|
||||
if (not name) or (not query):
|
||||
return [(None, None, None,
|
||||
usage + 'Err: Both name and query are required.')]
|
||||
return [(None, None, None, usage + "Err: Both name and query are required.")]
|
||||
|
||||
FavoriteQueries.instance.save(name, query)
|
||||
return [(None, None, None, "Saved.")]
|
||||
|
||||
|
||||
@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
|
||||
@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
|
||||
def delete_favorite_query(arg, **_):
|
||||
"""Delete an existing favorite query."""
|
||||
usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
|
||||
usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
|
@ -295,8 +303,7 @@ def delete_favorite_query(arg, **_):
|
|||
return [(None, None, None, status)]
|
||||
|
||||
|
||||
@special_command('system', 'system [command]',
|
||||
'Execute a system shell commmand.')
|
||||
@special_command("system", "system [command]", "Execute a system shell commmand.")
|
||||
def execute_system_command(arg, **_):
|
||||
"""Execute a system shell command."""
|
||||
usage = "Syntax: system [command].\n"
|
||||
|
@ -306,13 +313,13 @@ def execute_system_command(arg, **_):
|
|||
|
||||
try:
|
||||
command = arg.strip()
|
||||
if command.startswith('cd'):
|
||||
if command.startswith("cd"):
|
||||
ok, error_message = handle_cd_command(arg)
|
||||
if not ok:
|
||||
return [(None, None, None, error_message)]
|
||||
return [(None, None, None, '')]
|
||||
return [(None, None, None, "")]
|
||||
|
||||
args = arg.split(' ')
|
||||
args = arg.split(" ")
|
||||
process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
response = output if not error else error
|
||||
|
@ -324,25 +331,24 @@ def execute_system_command(arg, **_):
|
|||
|
||||
return [(None, None, None, response)]
|
||||
except OSError as e:
|
||||
return [(None, None, None, 'OSError: %s' % e.strerror)]
|
||||
return [(None, None, None, "OSError: %s" % e.strerror)]
|
||||
|
||||
|
||||
def parseargfile(arg):
|
||||
if arg.startswith('-o '):
|
||||
if arg.startswith("-o "):
|
||||
mode = "w"
|
||||
filename = arg[3:]
|
||||
else:
|
||||
mode = 'a'
|
||||
mode = "a"
|
||||
filename = arg
|
||||
|
||||
if not filename:
|
||||
raise TypeError('You must provide a filename.')
|
||||
raise TypeError("You must provide a filename.")
|
||||
|
||||
return {'file': os.path.expanduser(filename), 'mode': mode}
|
||||
return {"file": os.path.expanduser(filename), "mode": mode}
|
||||
|
||||
|
||||
@special_command('tee', 'tee [-o] filename',
|
||||
'Append all results to an output file (overwrite using -o).')
|
||||
@special_command("tee", "tee [-o] filename", "Append all results to an output file (overwrite using -o).")
|
||||
def set_tee(arg, **_):
|
||||
global tee_file
|
||||
|
||||
|
@ -353,6 +359,7 @@ def set_tee(arg, **_):
|
|||
|
||||
return [(None, None, None, "")]
|
||||
|
||||
|
||||
@export
|
||||
def close_tee():
|
||||
global tee_file
|
||||
|
@ -361,31 +368,29 @@ def close_tee():
|
|||
tee_file = None
|
||||
|
||||
|
||||
@special_command('notee', 'notee', 'Stop writing results to an output file.')
|
||||
@special_command("notee", "notee", "Stop writing results to an output file.")
|
||||
def no_tee(arg, **_):
|
||||
close_tee()
|
||||
return [(None, None, None, "")]
|
||||
|
||||
|
||||
@export
|
||||
def write_tee(output):
|
||||
global tee_file
|
||||
if tee_file:
|
||||
click.echo(output, file=tee_file, nl=False)
|
||||
click.echo(u'\n', file=tee_file, nl=False)
|
||||
click.echo("\n", file=tee_file, nl=False)
|
||||
tee_file.flush()
|
||||
|
||||
|
||||
@special_command('\\once', '\\o [-o] filename',
|
||||
'Append next result to an output file (overwrite using -o).',
|
||||
aliases=('\\o', ))
|
||||
@special_command("\\once", "\\o [-o] filename", "Append next result to an output file (overwrite using -o).", aliases=("\\o",))
|
||||
def set_once(arg, **_):
|
||||
global once_file, written_to_once_file
|
||||
|
||||
try:
|
||||
once_file = open(**parseargfile(arg))
|
||||
except (IOError, OSError) as e:
|
||||
raise OSError("Cannot write to file '{}': {}".format(
|
||||
e.filename, e.strerror))
|
||||
raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
|
||||
written_to_once_file = False
|
||||
|
||||
return [(None, None, None, "")]
|
||||
|
@ -396,7 +401,7 @@ def write_once(output):
|
|||
global once_file, written_to_once_file
|
||||
if output and once_file:
|
||||
click.echo(output, file=once_file, nl=False)
|
||||
click.echo(u"\n", file=once_file, nl=False)
|
||||
click.echo("\n", file=once_file, nl=False)
|
||||
once_file.flush()
|
||||
written_to_once_file = True
|
||||
|
||||
|
@ -410,22 +415,22 @@ def unset_once_if_written():
|
|||
once_file = None
|
||||
|
||||
|
||||
@special_command('\\pipe_once', '\\| command',
|
||||
'Send next result to a subprocess.',
|
||||
aliases=('\\|', ))
|
||||
@special_command("\\pipe_once", "\\| command", "Send next result to a subprocess.", aliases=("\\|",))
|
||||
def set_pipe_once(arg, **_):
|
||||
global pipe_once_process, written_to_pipe_once_process
|
||||
pipe_once_cmd = shlex.split(arg)
|
||||
if len(pipe_once_cmd) == 0:
|
||||
raise OSError("pipe_once requires a command")
|
||||
written_to_pipe_once_process = False
|
||||
pipe_once_process = subprocess.Popen(pipe_once_cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=1,
|
||||
encoding='UTF-8',
|
||||
universal_newlines=True)
|
||||
pipe_once_process = subprocess.Popen(
|
||||
pipe_once_cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=1,
|
||||
encoding="UTF-8",
|
||||
universal_newlines=True,
|
||||
)
|
||||
return [(None, None, None, "")]
|
||||
|
||||
|
||||
|
@ -435,11 +440,10 @@ def write_pipe_once(output):
|
|||
if output and pipe_once_process:
|
||||
try:
|
||||
click.echo(output, file=pipe_once_process.stdin, nl=False)
|
||||
click.echo(u"\n", file=pipe_once_process.stdin, nl=False)
|
||||
click.echo("\n", file=pipe_once_process.stdin, nl=False)
|
||||
except (IOError, OSError) as e:
|
||||
pipe_once_process.terminate()
|
||||
raise OSError(
|
||||
"Failed writing to pipe_once subprocess: {}".format(e.strerror))
|
||||
raise OSError("Failed writing to pipe_once subprocess: {}".format(e.strerror))
|
||||
written_to_pipe_once_process = True
|
||||
|
||||
|
||||
|
@ -450,18 +454,14 @@ def unset_pipe_once_if_written():
|
|||
if written_to_pipe_once_process:
|
||||
(stdout_data, stderr_data) = pipe_once_process.communicate()
|
||||
if len(stdout_data) > 0:
|
||||
print(stdout_data.rstrip(u"\n"))
|
||||
print(stdout_data.rstrip("\n"))
|
||||
if len(stderr_data) > 0:
|
||||
print(stderr_data.rstrip(u"\n"))
|
||||
print(stderr_data.rstrip("\n"))
|
||||
pipe_once_process = None
|
||||
written_to_pipe_once_process = False
|
||||
|
||||
|
||||
@special_command(
|
||||
'watch',
|
||||
'watch [seconds] [-c] query',
|
||||
'Executes the query every [seconds] seconds (by default 5).'
|
||||
)
|
||||
@special_command("watch", "watch [seconds] [-c] query", "Executes the query every [seconds] seconds (by default 5).")
|
||||
def watch_query(arg, **kwargs):
|
||||
usage = """Syntax: watch [seconds] [-c] query.
|
||||
* seconds: The interval at the query will be repeated, in seconds.
|
||||
|
@ -480,27 +480,24 @@ def watch_query(arg, **kwargs):
|
|||
# Oops, we parsed all the arguments without finding a statement
|
||||
yield (None, None, None, usage)
|
||||
return
|
||||
(current_arg, _, arg) = arg.partition(' ')
|
||||
(current_arg, _, arg) = arg.partition(" ")
|
||||
try:
|
||||
seconds = float(current_arg)
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
if current_arg == '-c':
|
||||
if current_arg == "-c":
|
||||
clear_screen = True
|
||||
continue
|
||||
statement = '{0!s} {1!s}'.format(current_arg, arg)
|
||||
statement = "{0!s} {1!s}".format(current_arg, arg)
|
||||
destructive_prompt = confirm_destructive_query(statement)
|
||||
if destructive_prompt is False:
|
||||
click.secho("Wise choice!")
|
||||
return
|
||||
elif destructive_prompt is True:
|
||||
click.secho("Your call!")
|
||||
cur = kwargs['cur']
|
||||
sql_list = [
|
||||
(sql.rstrip(';'), "> {0!s}".format(sql))
|
||||
for sql in sqlparse.split(statement)
|
||||
]
|
||||
cur = kwargs["cur"]
|
||||
sql_list = [(sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)]
|
||||
old_pager_enabled = is_pager_enabled()
|
||||
while True:
|
||||
if clear_screen:
|
||||
|
@ -509,7 +506,7 @@ def watch_query(arg, **kwargs):
|
|||
# Somewhere in the code the pager its activated after every yield,
|
||||
# so we disable it in every iteration
|
||||
set_pager_enabled(False)
|
||||
for (sql, title) in sql_list:
|
||||
for sql, title in sql_list:
|
||||
cur.execute(sql)
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
|
@ -527,7 +524,7 @@ def watch_query(arg, **kwargs):
|
|||
|
||||
|
||||
@export
|
||||
@special_command('delimiter', None, 'Change SQL delimiter.')
|
||||
@special_command("delimiter", None, "Change SQL delimiter.")
|
||||
def set_delimiter(arg, **_):
|
||||
return delimiter_command.set(arg)
|
||||
|
||||
|
|
|
@ -9,43 +9,43 @@ NO_QUERY = 0
|
|||
PARSED_QUERY = 1
|
||||
RAW_QUERY = 2
|
||||
|
||||
SpecialCommand = namedtuple('SpecialCommand',
|
||||
['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
|
||||
'case_sensitive'])
|
||||
SpecialCommand = namedtuple("SpecialCommand", ["handler", "command", "shortcut", "description", "arg_type", "hidden", "case_sensitive"])
|
||||
|
||||
COMMANDS = {}
|
||||
|
||||
|
||||
@export
|
||||
class CommandNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@export
|
||||
def parse_special_command(sql):
|
||||
command, _, arg = sql.partition(' ')
|
||||
verbose = '+' in command
|
||||
command = command.strip().replace('+', '')
|
||||
command, _, arg = sql.partition(" ")
|
||||
verbose = "+" in command
|
||||
command = command.strip().replace("+", "")
|
||||
return (command, verbose, arg.strip())
|
||||
|
||||
@export
|
||||
def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
|
||||
hidden=False, case_sensitive=False, aliases=()):
|
||||
def wrapper(wrapped):
|
||||
register_special_command(wrapped, command, shortcut, description,
|
||||
arg_type, hidden, case_sensitive, aliases)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
@export
|
||||
def register_special_command(handler, command, shortcut, description,
|
||||
arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
|
||||
def special_command(command, shortcut, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
|
||||
def wrapper(wrapped):
|
||||
register_special_command(wrapped, command, shortcut, description, arg_type, hidden, case_sensitive, aliases)
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@export
|
||||
def register_special_command(
|
||||
handler, command, shortcut, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()
|
||||
):
|
||||
cmd = command.lower() if not case_sensitive else command
|
||||
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
|
||||
arg_type, hidden, case_sensitive)
|
||||
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, hidden, case_sensitive)
|
||||
for alias in aliases:
|
||||
cmd = alias.lower() if not case_sensitive else alias
|
||||
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
|
||||
arg_type, case_sensitive=case_sensitive,
|
||||
hidden=True)
|
||||
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, case_sensitive=case_sensitive, hidden=True)
|
||||
|
||||
|
||||
@export
|
||||
def execute(cur, sql):
|
||||
|
@ -62,11 +62,11 @@ def execute(cur, sql):
|
|||
except KeyError:
|
||||
special_cmd = COMMANDS[command.lower()]
|
||||
if special_cmd.case_sensitive:
|
||||
raise CommandNotFound('Command not found: %s' % command)
|
||||
raise CommandNotFound("Command not found: %s" % command)
|
||||
|
||||
# "help <SQL KEYWORD> is a special case. We want built-in help, not
|
||||
# mycli help here.
|
||||
if command == 'help' and arg:
|
||||
if command == "help" and arg:
|
||||
return show_keyword_help(cur=cur, arg=arg)
|
||||
|
||||
if special_cmd.arg_type == NO_QUERY:
|
||||
|
@ -76,9 +76,10 @@ def execute(cur, sql):
|
|||
elif special_cmd.arg_type == RAW_QUERY:
|
||||
return special_cmd.handler(cur=cur, query=sql)
|
||||
|
||||
@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
|
||||
|
||||
@special_command("help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?"))
|
||||
def show_help(): # All the parameters are ignored.
|
||||
headers = ['Command', 'Shortcut', 'Description']
|
||||
headers = ["Command", "Shortcut", "Description"]
|
||||
result = []
|
||||
|
||||
for _, value in sorted(COMMANDS.items()):
|
||||
|
@ -86,6 +87,7 @@ def show_help(): # All the parameters are ignored.
|
|||
result.append((value.command, value.shortcut, value.description))
|
||||
return [(None, result, headers, None)]
|
||||
|
||||
|
||||
def show_keyword_help(cur, arg):
|
||||
"""
|
||||
Call the built-in "show <command>", to display help for an SQL keyword.
|
||||
|
@ -99,22 +101,19 @@ def show_keyword_help(cur, arg):
|
|||
cur.execute(query)
|
||||
if cur.description and cur.rowcount > 0:
|
||||
headers = [x[0] for x in cur.description]
|
||||
return [(None, cur, headers, '')]
|
||||
return [(None, cur, headers, "")]
|
||||
else:
|
||||
return [(None, None, None, 'No help found for {0}.'.format(keyword))]
|
||||
return [(None, None, None, "No help found for {0}.".format(keyword))]
|
||||
|
||||
|
||||
@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', ))
|
||||
@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY)
|
||||
@special_command("exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q",))
|
||||
@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY)
|
||||
def quit(*_args):
|
||||
raise EOFError
|
||||
|
||||
|
||||
@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).',
|
||||
arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.',
|
||||
arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command('\\G', '\\G', 'Display current query results vertically.',
|
||||
arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command("\\e", "\\e", "Edit command with editor (uses $EDITOR).", arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command("\\clip", "\\clip", "Copy query to the system clipboard.", arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command("\\G", "\\G", "Display current query results vertically.", arg_type=NO_QUERY, case_sensitive=True)
|
||||
def stub():
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def handle_cd_command(arg):
|
||||
"""Handles a `cd` shell command by calling python's os.chdir."""
|
||||
CD_CMD = 'cd'
|
||||
tokens = arg.split(CD_CMD + ' ')
|
||||
CD_CMD = "cd"
|
||||
tokens = arg.split(CD_CMD + " ")
|
||||
directory = tokens[-1] if len(tokens) > 1 else None
|
||||
if not directory:
|
||||
return False, "No folder name was provided."
|
||||
try:
|
||||
os.chdir(directory)
|
||||
subprocess.call(['pwd'])
|
||||
subprocess.call(["pwd"])
|
||||
return True, None
|
||||
except OSError as e:
|
||||
return False, e.strerror
|
||||
|
||||
|
||||
def format_uptime(uptime_in_seconds):
|
||||
"""Format number of seconds into human-readable string.
|
||||
|
||||
|
@ -32,15 +34,15 @@ def format_uptime(uptime_in_seconds):
|
|||
|
||||
uptime_values = []
|
||||
|
||||
for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')):
|
||||
for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")):
|
||||
if value == 0 and not uptime_values:
|
||||
# Don't include a value/unit if the unit isn't applicable to
|
||||
# the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
|
||||
continue
|
||||
elif value == 1 and unit.endswith('s'):
|
||||
elif value == 1 and unit.endswith("s"):
|
||||
# Remove the "s" if the unit is singular.
|
||||
unit = unit[:-1]
|
||||
uptime_values.append('{0} {1}'.format(value, unit))
|
||||
uptime_values.append("{0} {1}".format(value, unit))
|
||||
|
||||
uptime = ' '.join(uptime_values)
|
||||
uptime = " ".join(uptime_values)
|
||||
return uptime
|
||||
|
|
|
@ -2,8 +2,12 @@
|
|||
|
||||
from mycli.packages.parseutils import extract_tables
|
||||
|
||||
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
|
||||
'sql-update-2', )
|
||||
supported_formats = (
|
||||
"sql-insert",
|
||||
"sql-update",
|
||||
"sql-update-1",
|
||||
"sql-update-2",
|
||||
)
|
||||
|
||||
preprocessors = ()
|
||||
|
||||
|
@ -25,19 +29,18 @@ def adapter(data, headers, table_format=None, **kwargs):
|
|||
table_name = table[1]
|
||||
else:
|
||||
table_name = "`DUAL`"
|
||||
if table_format == 'sql-insert':
|
||||
if table_format == "sql-insert":
|
||||
h = "`, `".join(headers)
|
||||
yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h)
|
||||
prefix = " "
|
||||
for d in data:
|
||||
values = ", ".join(escape_for_sql_statement(v)
|
||||
for i, v in enumerate(d))
|
||||
values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d))
|
||||
yield "{}({})".format(prefix, values)
|
||||
if prefix == " ":
|
||||
prefix = ", "
|
||||
yield ";"
|
||||
if table_format.startswith('sql-update'):
|
||||
s = table_format.split('-')
|
||||
if table_format.startswith("sql-update"):
|
||||
s = table_format.split("-")
|
||||
keys = 1
|
||||
if len(s) > 2:
|
||||
keys = int(s[-1])
|
||||
|
@ -49,8 +52,7 @@ def adapter(data, headers, table_format=None, **kwargs):
|
|||
if prefix == " ":
|
||||
prefix = ", "
|
||||
f = "`{}` = {}"
|
||||
where = (f.format(headers[i], escape_for_sql_statement(
|
||||
d[i])) for i in range(keys))
|
||||
where = (f.format(headers[i], escape_for_sql_statement(d[i])) for i in range(keys))
|
||||
yield "WHERE {};".format(" AND ".join(where))
|
||||
|
||||
|
||||
|
@ -58,5 +60,4 @@ def register_new_formatter(TabularOutputFormatter):
|
|||
global formatter
|
||||
formatter = TabularOutputFormatter
|
||||
for sql_format in supported_formats:
|
||||
TabularOutputFormatter.register_new_formatter(
|
||||
sql_format, adapter, preprocessors, {'table_format': sql_format})
|
||||
TabularOutputFormatter.register_new_formatter(sql_format, adapter, preprocessors, {"table_format": sql_format})
|
||||
|
|
|
@ -29,7 +29,7 @@ def search_history(event: KeyPressEvent):
|
|||
formatted_history_items = []
|
||||
original_history_items = []
|
||||
for item, timestamp in history_items_with_timestamp:
|
||||
formatted_item = item.replace('\n', ' ')
|
||||
formatted_item = item.replace("\n", " ")
|
||||
timestamp = timestamp.split(".")[0] if "." in timestamp else timestamp
|
||||
formatted_history_items.append(f"{timestamp} {formatted_item}")
|
||||
original_history_items.append(item)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import Iterable, Union, List, Tuple
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from prompt_toolkit.history import FileHistory
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -5,31 +5,29 @@ import re
|
|||
import pymysql
|
||||
from .packages import special
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
from pymysql.converters import (convert_datetime,
|
||||
convert_timedelta, convert_date, conversions,
|
||||
decoders)
|
||||
from pymysql.converters import convert_datetime, convert_timedelta, convert_date, conversions, decoders
|
||||
|
||||
try:
|
||||
import paramiko
|
||||
import paramiko # noqa: F401
|
||||
import sshtunnel
|
||||
except ImportError:
|
||||
from mycli.packages.paramiko_stub import paramiko
|
||||
pass
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
FIELD_TYPES = decoders.copy()
|
||||
FIELD_TYPES.update({
|
||||
FIELD_TYPE.NULL: type(None)
|
||||
})
|
||||
FIELD_TYPES.update({FIELD_TYPE.NULL: type(None)})
|
||||
|
||||
|
||||
ERROR_CODE_ACCESS_DENIED = 1045
|
||||
|
||||
|
||||
class ServerSpecies(enum.Enum):
|
||||
MySQL = 'MySQL'
|
||||
MariaDB = 'MariaDB'
|
||||
Percona = 'Percona'
|
||||
TiDB = 'TiDB'
|
||||
Unknown = 'MySQL'
|
||||
MySQL = "MySQL"
|
||||
MariaDB = "MariaDB"
|
||||
Percona = "Percona"
|
||||
TiDB = "TiDB"
|
||||
Unknown = "MySQL"
|
||||
|
||||
|
||||
class ServerInfo:
|
||||
|
@ -43,7 +41,7 @@ class ServerInfo:
|
|||
if not version_str or not isinstance(version_str, str):
|
||||
return 0
|
||||
try:
|
||||
major, minor, patch = version_str.split('.')
|
||||
major, minor, patch = version_str.split(".")
|
||||
except ValueError:
|
||||
return 0
|
||||
else:
|
||||
|
@ -52,55 +50,67 @@ class ServerInfo:
|
|||
@classmethod
|
||||
def from_version_string(cls, version_string):
|
||||
if not version_string:
|
||||
return cls(ServerSpecies.Unknown, '')
|
||||
return cls(ServerSpecies.Unknown, "")
|
||||
|
||||
re_species = (
|
||||
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
|
||||
(r'[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)', ServerSpecies.TiDB),
|
||||
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
|
||||
ServerSpecies.Percona),
|
||||
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
|
||||
ServerSpecies.MySQL),
|
||||
(r"(?P<version>[0-9\.]+)-MariaDB", ServerSpecies.MariaDB),
|
||||
(r"[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)", ServerSpecies.TiDB),
|
||||
(r"(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)", ServerSpecies.Percona),
|
||||
(r"(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)", ServerSpecies.MySQL),
|
||||
)
|
||||
for regexp, species in re_species:
|
||||
match = re.search(regexp, version_string)
|
||||
if match is not None:
|
||||
parsed_version = match.group('version')
|
||||
parsed_version = match.group("version")
|
||||
detected_species = species
|
||||
break
|
||||
else:
|
||||
detected_species = ServerSpecies.Unknown
|
||||
parsed_version = ''
|
||||
parsed_version = ""
|
||||
|
||||
return cls(detected_species, parsed_version)
|
||||
|
||||
def __str__(self):
|
||||
if self.species:
|
||||
return f'{self.species.value} {self.version_str}'
|
||||
return f"{self.species.value} {self.version_str}"
|
||||
else:
|
||||
return self.version_str
|
||||
|
||||
|
||||
class SQLExecute(object):
|
||||
databases_query = """SHOW DATABASES"""
|
||||
|
||||
databases_query = '''SHOW DATABASES'''
|
||||
|
||||
tables_query = '''SHOW TABLES'''
|
||||
tables_query = """SHOW TABLES"""
|
||||
|
||||
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
|
||||
|
||||
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
|
||||
users_query = """SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user"""
|
||||
|
||||
functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
|
||||
WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
|
||||
|
||||
table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns
|
||||
table_columns_query = """select TABLE_NAME, COLUMN_NAME from information_schema.columns
|
||||
where table_schema = '%s'
|
||||
order by table_name,ordinal_position'''
|
||||
order by table_name,ordinal_position"""
|
||||
|
||||
def __init__(self, database, user, password, host, port, socket, charset,
|
||||
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
|
||||
ssh_key_filename, init_command=None):
|
||||
def __init__(
|
||||
self,
|
||||
database,
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
port,
|
||||
socket,
|
||||
charset,
|
||||
local_infile,
|
||||
ssl,
|
||||
ssh_user,
|
||||
ssh_host,
|
||||
ssh_port,
|
||||
ssh_password,
|
||||
ssh_key_filename,
|
||||
init_command=None,
|
||||
):
|
||||
self.dbname = database
|
||||
self.user = user
|
||||
self.password = password
|
||||
|
@ -120,52 +130,79 @@ class SQLExecute(object):
|
|||
self.init_command = init_command
|
||||
self.connect()
|
||||
|
||||
def connect(self, database=None, user=None, password=None, host=None,
|
||||
port=None, socket=None, charset=None, local_infile=None,
|
||||
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
|
||||
ssh_password=None, ssh_key_filename=None, init_command=None):
|
||||
db = (database or self.dbname)
|
||||
user = (user or self.user)
|
||||
password = (password or self.password)
|
||||
host = (host or self.host)
|
||||
port = (port or self.port)
|
||||
socket = (socket or self.socket)
|
||||
charset = (charset or self.charset)
|
||||
local_infile = (local_infile or self.local_infile)
|
||||
ssl = (ssl or self.ssl)
|
||||
ssh_user = (ssh_user or self.ssh_user)
|
||||
ssh_host = (ssh_host or self.ssh_host)
|
||||
ssh_port = (ssh_port or self.ssh_port)
|
||||
ssh_password = (ssh_password or self.ssh_password)
|
||||
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
|
||||
init_command = (init_command or self.init_command)
|
||||
def connect(
|
||||
self,
|
||||
database=None,
|
||||
user=None,
|
||||
password=None,
|
||||
host=None,
|
||||
port=None,
|
||||
socket=None,
|
||||
charset=None,
|
||||
local_infile=None,
|
||||
ssl=None,
|
||||
ssh_host=None,
|
||||
ssh_port=None,
|
||||
ssh_user=None,
|
||||
ssh_password=None,
|
||||
ssh_key_filename=None,
|
||||
init_command=None,
|
||||
):
|
||||
db = database or self.dbname
|
||||
user = user or self.user
|
||||
password = password or self.password
|
||||
host = host or self.host
|
||||
port = port or self.port
|
||||
socket = socket or self.socket
|
||||
charset = charset or self.charset
|
||||
local_infile = local_infile or self.local_infile
|
||||
ssl = ssl or self.ssl
|
||||
ssh_user = ssh_user or self.ssh_user
|
||||
ssh_host = ssh_host or self.ssh_host
|
||||
ssh_port = ssh_port or self.ssh_port
|
||||
ssh_password = ssh_password or self.ssh_password
|
||||
ssh_key_filename = ssh_key_filename or self.ssh_key_filename
|
||||
init_command = init_command or self.init_command
|
||||
_logger.debug(
|
||||
'Connection DB Params: \n'
|
||||
'\tdatabase: %r'
|
||||
'\tuser: %r'
|
||||
'\thost: %r'
|
||||
'\tport: %r'
|
||||
'\tsocket: %r'
|
||||
'\tcharset: %r'
|
||||
'\tlocal_infile: %r'
|
||||
'\tssl: %r'
|
||||
'\tssh_user: %r'
|
||||
'\tssh_host: %r'
|
||||
'\tssh_port: %r'
|
||||
'\tssh_password: %r'
|
||||
'\tssh_key_filename: %r'
|
||||
'\tinit_command: %r',
|
||||
db, user, host, port, socket, charset, local_infile, ssl,
|
||||
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
|
||||
init_command
|
||||
"Connection DB Params: \n"
|
||||
"\tdatabase: %r"
|
||||
"\tuser: %r"
|
||||
"\thost: %r"
|
||||
"\tport: %r"
|
||||
"\tsocket: %r"
|
||||
"\tcharset: %r"
|
||||
"\tlocal_infile: %r"
|
||||
"\tssl: %r"
|
||||
"\tssh_user: %r"
|
||||
"\tssh_host: %r"
|
||||
"\tssh_port: %r"
|
||||
"\tssh_password: %r"
|
||||
"\tssh_key_filename: %r"
|
||||
"\tinit_command: %r",
|
||||
db,
|
||||
user,
|
||||
host,
|
||||
port,
|
||||
socket,
|
||||
charset,
|
||||
local_infile,
|
||||
ssl,
|
||||
ssh_user,
|
||||
ssh_host,
|
||||
ssh_port,
|
||||
ssh_password,
|
||||
ssh_key_filename,
|
||||
init_command,
|
||||
)
|
||||
conv = conversions.copy()
|
||||
conv.update({
|
||||
FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
|
||||
FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
|
||||
FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
|
||||
FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
|
||||
})
|
||||
conv.update(
|
||||
{
|
||||
FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
|
||||
FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
|
||||
FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
|
||||
FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
|
||||
}
|
||||
)
|
||||
|
||||
defer_connect = False
|
||||
|
||||
|
@ -181,29 +218,45 @@ class SQLExecute(object):
|
|||
ssl_context = self._create_ssl_ctx(ssl)
|
||||
|
||||
conn = pymysql.connect(
|
||||
database=db, user=user, password=password, host=host, port=port,
|
||||
unix_socket=socket, use_unicode=True, charset=charset,
|
||||
autocommit=True, client_flag=client_flag,
|
||||
local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli",
|
||||
defer_connect=defer_connect, init_command=init_command
|
||||
database=db,
|
||||
user=user,
|
||||
password=password,
|
||||
host=host,
|
||||
port=port,
|
||||
unix_socket=socket,
|
||||
use_unicode=True,
|
||||
charset=charset,
|
||||
autocommit=True,
|
||||
client_flag=client_flag,
|
||||
local_infile=local_infile,
|
||||
conv=conv,
|
||||
ssl=ssl_context,
|
||||
program_name="mycli",
|
||||
defer_connect=defer_connect,
|
||||
init_command=init_command,
|
||||
)
|
||||
|
||||
if ssh_host:
|
||||
client = paramiko.SSHClient()
|
||||
client.load_system_host_keys()
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
client.connect(
|
||||
ssh_host, ssh_port, ssh_user, ssh_password,
|
||||
key_filename=ssh_key_filename
|
||||
)
|
||||
chan = client.get_transport().open_channel(
|
||||
'direct-tcpip',
|
||||
(host, port),
|
||||
('0.0.0.0', 0),
|
||||
)
|
||||
conn.connect(chan)
|
||||
##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel
|
||||
#####
|
||||
# instead let's open a tunnel and rewrite host:port to local bind
|
||||
try:
|
||||
chan = sshtunnel.SSHTunnelForwarder(
|
||||
(ssh_host, ssh_port),
|
||||
ssh_username=ssh_user,
|
||||
ssh_pkey=ssh_key_filename,
|
||||
ssh_password=ssh_password,
|
||||
remote_bind_address=(host, port),
|
||||
)
|
||||
chan.start()
|
||||
|
||||
if hasattr(self, 'conn'):
|
||||
conn.host = chan.local_bind_host
|
||||
conn.port = chan.local_bind_port
|
||||
conn.connect()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if hasattr(self, "conn"):
|
||||
self.conn.close()
|
||||
self.conn = conn
|
||||
# Update them after the connection is made to ensure that it was a
|
||||
|
@ -235,24 +288,24 @@ class SQLExecute(object):
|
|||
# Split the sql into separate queries and run each one.
|
||||
# Unless it's saving a favorite query, in which case we
|
||||
# want to save them all together.
|
||||
if statement.startswith('\\fs'):
|
||||
if statement.startswith("\\fs"):
|
||||
components = [statement]
|
||||
else:
|
||||
components = special.split_queries(statement)
|
||||
|
||||
for sql in components:
|
||||
# \G is treated specially since we have to set the expanded output.
|
||||
if sql.endswith('\\G'):
|
||||
if sql.endswith("\\G"):
|
||||
special.set_expanded_output(True)
|
||||
sql = sql[:-2].strip()
|
||||
|
||||
cur = self.conn.cursor()
|
||||
try: # Special command
|
||||
_logger.debug('Trying a dbspecial command. sql: %r', sql)
|
||||
try: # Special command
|
||||
_logger.debug("Trying a dbspecial command. sql: %r", sql)
|
||||
for result in special.execute(cur, sql):
|
||||
yield result
|
||||
except special.CommandNotFound: # Regular SQL
|
||||
_logger.debug('Regular sql statement. sql: %r', sql)
|
||||
_logger.debug("Regular sql statement. sql: %r", sql)
|
||||
cur.execute(sql)
|
||||
while True:
|
||||
yield self.get_result(cur)
|
||||
|
@ -271,12 +324,11 @@ class SQLExecute(object):
|
|||
# e.g. SELECT or SHOW.
|
||||
if cursor.description is not None:
|
||||
headers = [x[0] for x in cursor.description]
|
||||
status = '{0} row{1} in set'
|
||||
status = "{0} row{1} in set"
|
||||
else:
|
||||
_logger.debug('No rows in result.')
|
||||
status = 'Query OK, {0} row{1} affected'
|
||||
status = status.format(cursor.rowcount,
|
||||
'' if cursor.rowcount == 1 else 's')
|
||||
_logger.debug("No rows in result.")
|
||||
status = "Query OK, {0} row{1} affected"
|
||||
status = status.format(cursor.rowcount, "" if cursor.rowcount == 1 else "s")
|
||||
|
||||
return (title, cursor if cursor.description else None, headers, status)
|
||||
|
||||
|
@ -284,7 +336,7 @@ class SQLExecute(object):
|
|||
"""Yields table names"""
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Tables Query. sql: %r', self.tables_query)
|
||||
_logger.debug("Tables Query. sql: %r", self.tables_query)
|
||||
cur.execute(self.tables_query)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
@ -292,14 +344,14 @@ class SQLExecute(object):
|
|||
def table_columns(self):
|
||||
"""Yields (table name, column name) pairs"""
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Columns Query. sql: %r', self.table_columns_query)
|
||||
_logger.debug("Columns Query. sql: %r", self.table_columns_query)
|
||||
cur.execute(self.table_columns_query % self.dbname)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def databases(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Databases Query. sql: %r', self.databases_query)
|
||||
_logger.debug("Databases Query. sql: %r", self.databases_query)
|
||||
cur.execute(self.databases_query)
|
||||
return [x[0] for x in cur.fetchall()]
|
||||
|
||||
|
@ -307,31 +359,31 @@ class SQLExecute(object):
|
|||
"""Yields tuples of (schema_name, function_name)"""
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Functions Query. sql: %r', self.functions_query)
|
||||
_logger.debug("Functions Query. sql: %r", self.functions_query)
|
||||
cur.execute(self.functions_query % self.dbname)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def show_candidates(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Show Query. sql: %r', self.show_candidates_query)
|
||||
_logger.debug("Show Query. sql: %r", self.show_candidates_query)
|
||||
try:
|
||||
cur.execute(self.show_candidates_query)
|
||||
except pymysql.DatabaseError as e:
|
||||
_logger.error('No show completions due to %r', e)
|
||||
yield ''
|
||||
_logger.error("No show completions due to %r", e)
|
||||
yield ""
|
||||
else:
|
||||
for row in cur:
|
||||
yield (row[0].split(None, 1)[-1], )
|
||||
yield (row[0].split(None, 1)[-1],)
|
||||
|
||||
def users(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Users Query. sql: %r', self.users_query)
|
||||
_logger.debug("Users Query. sql: %r", self.users_query)
|
||||
try:
|
||||
cur.execute(self.users_query)
|
||||
except pymysql.DatabaseError as e:
|
||||
_logger.error('No user completions due to %r', e)
|
||||
yield ''
|
||||
_logger.error("No user completions due to %r", e)
|
||||
yield ""
|
||||
else:
|
||||
for row in cur:
|
||||
yield row
|
||||
|
@ -343,17 +395,17 @@ class SQLExecute(object):
|
|||
|
||||
def reset_connection_id(self):
|
||||
# Remember current connection id
|
||||
_logger.debug('Get current connection id')
|
||||
_logger.debug("Get current connection id")
|
||||
try:
|
||||
res = self.run('select connection_id()')
|
||||
res = self.run("select connection_id()")
|
||||
for title, cur, headers, status in res:
|
||||
self.connection_id = cur.fetchone()[0]
|
||||
except Exception as e:
|
||||
# See #1054
|
||||
self.connection_id = -1
|
||||
_logger.error('Failed to get connection id: %s', e)
|
||||
_logger.error("Failed to get connection id: %s", e)
|
||||
else:
|
||||
_logger.debug('Current connection id: %s', self.connection_id)
|
||||
_logger.debug("Current connection id: %s", self.connection_id)
|
||||
|
||||
def change_db(self, db):
|
||||
self.conn.select_db(db)
|
||||
|
@ -392,6 +444,6 @@ class SQLExecute(object):
|
|||
ctx.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
ctx.maximum_version = ssl.TLSVersion.TLSv1_3
|
||||
else:
|
||||
_logger.error('Invalid tls version: %s', tls_version)
|
||||
_logger.error("Invalid tls version: %s", tls_version)
|
||||
|
||||
return ctx
|
||||
|
|
59
pyproject.toml
Normal file
59
pyproject.toml
Normal file
|
@ -0,0 +1,59 @@
|
|||
[project]
|
||||
name = "mycli"
|
||||
dynamic = ["version"]
|
||||
description = "CLI for MySQL Database. With auto-completion and syntax highlighting."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.7"
|
||||
license = { text = "BSD" }
|
||||
authors = [{ name = "Mycli Core Team", email = "mycli-dev@googlegroups.com" }]
|
||||
urls = { homepage = "http://mycli.net" }
|
||||
|
||||
dependencies = [
|
||||
"click >= 7.0",
|
||||
"cryptography >= 1.0.0",
|
||||
"Pygments>=1.6",
|
||||
"prompt_toolkit>=3.0.6,<4.0.0",
|
||||
"PyMySQL >= 0.9.2",
|
||||
"sqlparse>=0.3.0,<0.5.0",
|
||||
"sqlglot>=5.1.3",
|
||||
"configobj >= 5.0.5",
|
||||
"cli_helpers[styles] >= 2.2.1",
|
||||
"pyperclip >= 1.8.1",
|
||||
"pyaes >= 1.6.1",
|
||||
"pyfzf >= 0.3.1",
|
||||
"importlib_resources >= 5.0.0; python_version<'3.9'",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=64.0",
|
||||
"setuptools-scm>=8;python_version>='3.8'",
|
||||
"setuptools-scm<8;python_version<'3.8'",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
|
||||
[project.optional-dependencies]
|
||||
ssh = ["paramiko", "sshtunnel"]
|
||||
dev = [
|
||||
"behave>=1.2.6",
|
||||
"coverage>=7.2.7",
|
||||
"pexpect>=4.9.0",
|
||||
"pytest>=7.4.4",
|
||||
"pytest-cov>=4.1.0",
|
||||
"tox>=4.8.0",
|
||||
"pdbpp>=0.10.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
mycli = "mycli.main:cli"
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
mycli = ["myclirc", "AUTHORS", "SPONSORS"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["mycli*"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 140
|
119
release.py
119
release.py
|
@ -1,119 +0,0 @@
|
|||
"""A script to publish a release of mycli to PyPI."""
|
||||
|
||||
from optparse import OptionParser
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import click
|
||||
|
||||
DEBUG = False
|
||||
CONFIRM_STEPS = False
|
||||
DRY_RUN = False
|
||||
|
||||
|
||||
def skip_step():
|
||||
"""
|
||||
Asks for user's response whether to run a step. Default is yes.
|
||||
:return: boolean
|
||||
"""
|
||||
global CONFIRM_STEPS
|
||||
|
||||
if CONFIRM_STEPS:
|
||||
return not click.confirm('--- Run this step?', default=True)
|
||||
return False
|
||||
|
||||
|
||||
def run_step(*args):
|
||||
"""
|
||||
Prints out the command and asks if it should be run.
|
||||
If yes (default), runs it.
|
||||
:param args: list of strings (command and args)
|
||||
"""
|
||||
global DRY_RUN
|
||||
|
||||
cmd = args
|
||||
print(' '.join(cmd))
|
||||
if skip_step():
|
||||
print('--- Skipping...')
|
||||
elif DRY_RUN:
|
||||
print('--- Pretending to run...')
|
||||
else:
|
||||
subprocess.check_output(cmd)
|
||||
|
||||
|
||||
def version(version_file):
|
||||
_version_re = re.compile(
|
||||
r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)')
|
||||
|
||||
with open(version_file) as f:
|
||||
ver = _version_re.search(f.read()).group('version')
|
||||
|
||||
return ver
|
||||
|
||||
|
||||
def commit_for_release(version_file, ver):
|
||||
run_step('git', 'reset')
|
||||
run_step('git', 'add', version_file)
|
||||
run_step('git', 'commit', '--message',
|
||||
'Releasing version {}'.format(ver))
|
||||
|
||||
|
||||
def create_git_tag(tag_name):
|
||||
run_step('git', 'tag', tag_name)
|
||||
|
||||
|
||||
def create_distribution_files():
|
||||
run_step('python', 'setup.py', 'sdist', 'bdist_wheel')
|
||||
|
||||
|
||||
def upload_distribution_files():
|
||||
run_step('twine', 'upload', 'dist/*')
|
||||
|
||||
|
||||
def push_to_github():
|
||||
run_step('git', 'push', 'origin', 'main')
|
||||
|
||||
|
||||
def push_tags_to_github():
|
||||
run_step('git', 'push', '--tags', 'origin')
|
||||
|
||||
|
||||
def checklist(questions):
|
||||
for question in questions:
|
||||
if not click.confirm('--- {}'.format(question), default=False):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if DEBUG:
|
||||
subprocess.check_output = lambda x: x
|
||||
|
||||
ver = version('mycli/__init__.py')
|
||||
|
||||
parser = OptionParser()
|
||||
parser.add_option(
|
||||
"-c", "--confirm-steps", action="store_true", dest="confirm_steps",
|
||||
default=False, help=("Confirm every step. If the step is not "
|
||||
"confirmed, it will be skipped.")
|
||||
)
|
||||
parser.add_option(
|
||||
"-d", "--dry-run", action="store_true", dest="dry_run",
|
||||
default=False, help="Print out, but not actually run any steps."
|
||||
)
|
||||
|
||||
popts, pargs = parser.parse_args()
|
||||
CONFIRM_STEPS = popts.confirm_steps
|
||||
DRY_RUN = popts.dry_run
|
||||
|
||||
print('Releasing Version:', ver)
|
||||
|
||||
if not click.confirm('Are you sure?', default=False):
|
||||
sys.exit(1)
|
||||
|
||||
commit_for_release('mycli/__init__.py', ver)
|
||||
create_git_tag('v{}'.format(ver))
|
||||
create_distribution_files()
|
||||
push_to_github()
|
||||
push_tags_to_github()
|
||||
upload_distribution_files()
|
|
@ -10,6 +10,7 @@ colorama>=0.4.1
|
|||
git+https://github.com/hayd/pep8radius.git # --error-status option not released
|
||||
click>=7.0
|
||||
paramiko==2.11.0
|
||||
sshtunnel==0.4.0
|
||||
pyperclip>=1.8.1
|
||||
importlib_resources>=5.0.0
|
||||
pyaes>=1.6.1
|
||||
|
|
18
setup.cfg
18
setup.cfg
|
@ -1,18 +0,0 @@
|
|||
[bdist_wheel]
|
||||
universal = 1
|
||||
|
||||
[tool:pytest]
|
||||
addopts = --capture=sys
|
||||
--showlocals
|
||||
--doctest-modules
|
||||
--doctest-ignore-import-errors
|
||||
--ignore=setup.py
|
||||
--ignore=mycli/magic.py
|
||||
--ignore=mycli/packages/parseutils.py
|
||||
--ignore=test/features
|
||||
|
||||
[pep8]
|
||||
rev = master
|
||||
docformatter = True
|
||||
diff = True
|
||||
error-status = True
|
127
setup.py
127
setup.py
|
@ -1,127 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import ast
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from setuptools import Command, find_packages, setup
|
||||
from setuptools.command.test import test as TestCommand
|
||||
|
||||
_version_re = re.compile(r'__version__\s+=\s+(.*)')
|
||||
|
||||
with open('mycli/__init__.py') as f:
|
||||
version = ast.literal_eval(_version_re.search(
|
||||
f.read()).group(1))
|
||||
|
||||
description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.'
|
||||
|
||||
install_requirements = [
|
||||
'click >= 7.0',
|
||||
# Pinning cryptography is not needed after paramiko 2.11.0. Correct it
|
||||
'cryptography >= 1.0.0',
|
||||
# 'Pygments>=1.6,<=2.11.1',
|
||||
'Pygments>=1.6',
|
||||
'prompt_toolkit>=3.0.6,<4.0.0',
|
||||
'PyMySQL >= 0.9.2',
|
||||
'sqlparse>=0.3.0,<0.5.0',
|
||||
'sqlglot>=5.1.3',
|
||||
'configobj >= 5.0.5',
|
||||
'cli_helpers[styles] >= 2.2.1',
|
||||
'pyperclip >= 1.8.1',
|
||||
'pyaes >= 1.6.1',
|
||||
'pyfzf >= 0.3.1',
|
||||
]
|
||||
|
||||
if sys.version_info.minor < 9:
|
||||
install_requirements.append('importlib_resources >= 5.0.0')
|
||||
|
||||
|
||||
class lint(Command):
|
||||
description = 'check code against PEP 8 (and fix violations)'
|
||||
|
||||
user_options = [
|
||||
('branch=', 'b', 'branch/revision to compare against (e.g. main)'),
|
||||
('fix', 'f', 'fix the violations in place'),
|
||||
('error-status', 'e', 'return an error code on failed PEP check'),
|
||||
]
|
||||
|
||||
def initialize_options(self):
|
||||
"""Set the default options."""
|
||||
self.branch = 'main'
|
||||
self.fix = False
|
||||
self.error_status = True
|
||||
|
||||
def finalize_options(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
cmd = 'pep8radius {}'.format(self.branch)
|
||||
if self.fix:
|
||||
cmd += ' --in-place'
|
||||
if self.error_status:
|
||||
cmd += ' --error-status'
|
||||
sys.exit(subprocess.call(cmd, shell=True))
|
||||
|
||||
|
||||
class test(TestCommand):
|
||||
|
||||
user_options = [
|
||||
('pytest-args=', 'a', 'Arguments to pass to pytest'),
|
||||
('behave-args=', 'b', 'Arguments to pass to pytest')
|
||||
]
|
||||
|
||||
def initialize_options(self):
|
||||
TestCommand.initialize_options(self)
|
||||
self.pytest_args = ''
|
||||
self.behave_args = '--no-capture'
|
||||
|
||||
def run_tests(self):
|
||||
unit_test_errno = subprocess.call(
|
||||
'pytest test/ ' + self.pytest_args,
|
||||
shell=True
|
||||
)
|
||||
cli_errno = subprocess.call(
|
||||
'behave test/features ' + self.behave_args,
|
||||
shell=True
|
||||
)
|
||||
subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False)
|
||||
sys.exit(unit_test_errno or cli_errno)
|
||||
|
||||
|
||||
setup(
|
||||
name='mycli',
|
||||
author='Mycli Core Team',
|
||||
author_email='mycli-dev@googlegroups.com',
|
||||
version=version,
|
||||
url='http://mycli.net',
|
||||
packages=find_packages(exclude=['test*']),
|
||||
package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']},
|
||||
description=description,
|
||||
long_description=description,
|
||||
install_requires=install_requirements,
|
||||
entry_points={
|
||||
'console_scripts': ['mycli = mycli.main:cli'],
|
||||
},
|
||||
cmdclass={'lint': lint, 'test': test},
|
||||
python_requires=">=3.7",
|
||||
classifiers=[
|
||||
'Intended Audience :: Developers',
|
||||
'License :: OSI Approved :: BSD License',
|
||||
'Operating System :: Unix',
|
||||
'Programming Language :: Python',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Programming Language :: SQL',
|
||||
'Topic :: Database',
|
||||
'Topic :: Database :: Front-Ends',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
extras_require={
|
||||
'ssh': ['paramiko'],
|
||||
},
|
||||
)
|
|
@ -1,13 +1,12 @@
|
|||
import pytest
|
||||
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
|
||||
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
|
||||
from .utils import HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT
|
||||
import mycli.sqlexecute
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def connection():
|
||||
create_db('mycli_test_db')
|
||||
connection = db_connection('mycli_test_db')
|
||||
create_db("mycli_test_db")
|
||||
connection = db_connection("mycli_test_db")
|
||||
yield connection
|
||||
|
||||
connection.close()
|
||||
|
@ -22,8 +21,18 @@ def cursor(connection):
|
|||
@pytest.fixture
|
||||
def executor(connection):
|
||||
return mycli.sqlexecute.SQLExecute(
|
||||
database='mycli_test_db', user=USER,
|
||||
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
|
||||
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
|
||||
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
|
||||
database="mycli_test_db",
|
||||
user=USER,
|
||||
host=HOST,
|
||||
password=PASSWORD,
|
||||
port=PORT,
|
||||
socket=None,
|
||||
charset=CHARSET,
|
||||
local_infile=False,
|
||||
ssl=None,
|
||||
ssh_user=SSH_USER,
|
||||
ssh_host=SSH_HOST,
|
||||
ssh_port=SSH_PORT,
|
||||
ssh_password=None,
|
||||
ssh_key_filename=None,
|
||||
)
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import pymysql
|
||||
|
||||
|
||||
def create_db(hostname='localhost', port=3306, username=None,
|
||||
password=None, dbname=None):
|
||||
def create_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
|
||||
"""Create test database.
|
||||
|
||||
:param hostname: string
|
||||
|
@ -14,17 +13,12 @@ def create_db(hostname='localhost', port=3306, username=None,
|
|||
|
||||
"""
|
||||
cn = pymysql.connect(
|
||||
host=hostname,
|
||||
port=port,
|
||||
user=username,
|
||||
password=password,
|
||||
charset='utf8mb4',
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
|
||||
with cn.cursor() as cr:
|
||||
cr.execute('drop database if exists ' + dbname)
|
||||
cr.execute('create database ' + dbname)
|
||||
cr.execute("drop database if exists " + dbname)
|
||||
cr.execute("create database " + dbname)
|
||||
|
||||
cn.close()
|
||||
|
||||
|
@ -44,20 +38,13 @@ def create_cn(hostname, port, password, username, dbname):
|
|||
|
||||
"""
|
||||
cn = pymysql.connect(
|
||||
host=hostname,
|
||||
port=port,
|
||||
user=username,
|
||||
password=password,
|
||||
db=dbname,
|
||||
charset='utf8mb4',
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
|
||||
return cn
|
||||
|
||||
|
||||
def drop_db(hostname='localhost', port=3306, username=None,
|
||||
password=None, dbname=None):
|
||||
def drop_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
|
||||
"""Drop database.
|
||||
|
||||
:param hostname: string
|
||||
|
@ -68,17 +55,11 @@ def drop_db(hostname='localhost', port=3306, username=None,
|
|||
|
||||
"""
|
||||
cn = pymysql.connect(
|
||||
host=hostname,
|
||||
port=port,
|
||||
user=username,
|
||||
password=password,
|
||||
db=dbname,
|
||||
charset='utf8mb4',
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
|
||||
with cn.cursor() as cr:
|
||||
cr.execute('drop database if exists ' + dbname)
|
||||
cr.execute("drop database if exists " + dbname)
|
||||
|
||||
close_cn(cn)
|
||||
|
||||
|
|
|
@ -9,96 +9,72 @@ import pexpect
|
|||
|
||||
from steps.wrappers import run_cli, wait_prompt
|
||||
|
||||
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
|
||||
test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log")
|
||||
|
||||
|
||||
SELF_CONNECTING_FEATURES = (
|
||||
'test/features/connection.feature',
|
||||
)
|
||||
SELF_CONNECTING_FEATURES = ("test/features/connection.feature",)
|
||||
|
||||
|
||||
MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
|
||||
MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
|
||||
MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
|
||||
MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
|
||||
MY_CNF_PATH = os.path.expanduser("~/.my.cnf")
|
||||
MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup"
|
||||
MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf")
|
||||
MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup"
|
||||
|
||||
|
||||
def get_db_name_from_context(context):
|
||||
return context.config.userdata.get(
|
||||
'my_test_db', None
|
||||
) or "mycli_behave_tests"
|
||||
|
||||
return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests"
|
||||
|
||||
|
||||
def before_all(context):
|
||||
"""Set env parameters."""
|
||||
os.environ['LINES'] = "100"
|
||||
os.environ['COLUMNS'] = "100"
|
||||
os.environ['EDITOR'] = 'ex'
|
||||
os.environ['LC_ALL'] = 'en_US.UTF-8'
|
||||
os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
|
||||
os.environ['MYCLI_HISTFILE'] = os.devnull
|
||||
os.environ["LINES"] = "100"
|
||||
os.environ["COLUMNS"] = "100"
|
||||
os.environ["EDITOR"] = "ex"
|
||||
os.environ["LC_ALL"] = "en_US.UTF-8"
|
||||
os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
|
||||
os.environ["MYCLI_HISTFILE"] = os.devnull
|
||||
|
||||
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
|
||||
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
||||
# test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
# login_path_file = os.path.join(test_dir, "mylogin.cnf")
|
||||
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
||||
|
||||
context.package_root = os.path.abspath(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
|
||||
'.coveragerc')
|
||||
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc")
|
||||
|
||||
context.exit_sent = False
|
||||
|
||||
vi = '_'.join([str(x) for x in sys.version_info[:3]])
|
||||
vi = "_".join([str(x) for x in sys.version_info[:3]])
|
||||
db_name = get_db_name_from_context(context)
|
||||
db_name_full = '{0}_{1}'.format(db_name, vi)
|
||||
db_name_full = "{0}_{1}".format(db_name, vi)
|
||||
|
||||
# Store get params from config/environment variables
|
||||
context.conf = {
|
||||
'host': context.config.userdata.get(
|
||||
'my_test_host',
|
||||
os.getenv('PYTEST_HOST', 'localhost')
|
||||
),
|
||||
'port': context.config.userdata.get(
|
||||
'my_test_port',
|
||||
int(os.getenv('PYTEST_PORT', '3306'))
|
||||
),
|
||||
'user': context.config.userdata.get(
|
||||
'my_test_user',
|
||||
os.getenv('PYTEST_USER', 'root')
|
||||
),
|
||||
'pass': context.config.userdata.get(
|
||||
'my_test_pass',
|
||||
os.getenv('PYTEST_PASSWORD', None)
|
||||
),
|
||||
'cli_command': context.config.userdata.get(
|
||||
'my_cli_command', None) or
|
||||
sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
|
||||
'dbname': db_name,
|
||||
'dbname_tmp': db_name_full + '_tmp',
|
||||
'vi': vi,
|
||||
'pager_boundary': '---boundary---',
|
||||
"host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")),
|
||||
"port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))),
|
||||
"user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")),
|
||||
"pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)),
|
||||
"cli_command": context.config.userdata.get("my_cli_command", None)
|
||||
or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
|
||||
"dbname": db_name,
|
||||
"dbname_tmp": db_name_full + "_tmp",
|
||||
"vi": vi,
|
||||
"pager_boundary": "---boundary---",
|
||||
}
|
||||
|
||||
_, my_cnf = mkstemp()
|
||||
with open(my_cnf, 'w') as f:
|
||||
with open(my_cnf, "w") as f:
|
||||
f.write(
|
||||
'[client]\n'
|
||||
'pager={0} {1} {2}\n'.format(
|
||||
sys.executable, os.path.join(context.package_root,
|
||||
'test/features/wrappager.py'),
|
||||
context.conf['pager_boundary'])
|
||||
"[client]\n" "pager={0} {1} {2}\n".format(
|
||||
sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"]
|
||||
)
|
||||
)
|
||||
context.conf['defaults-file'] = my_cnf
|
||||
context.conf['myclirc'] = os.path.join(context.package_root, 'test',
|
||||
'myclirc')
|
||||
context.conf["defaults-file"] = my_cnf
|
||||
context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc")
|
||||
|
||||
context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
|
||||
context.conf['user'],
|
||||
context.conf['pass'],
|
||||
context.conf['dbname'])
|
||||
context.cn = dbutils.create_db(
|
||||
context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]
|
||||
)
|
||||
|
||||
context.fixture_data = fixutils.read_fixture_files()
|
||||
|
||||
|
@ -106,12 +82,10 @@ def before_all(context):
|
|||
def after_all(context):
|
||||
"""Unset env parameters."""
|
||||
dbutils.close_cn(context.cn)
|
||||
dbutils.drop_db(context.conf['host'], context.conf['port'],
|
||||
context.conf['user'], context.conf['pass'],
|
||||
context.conf['dbname'])
|
||||
dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"])
|
||||
|
||||
# Restore env vars.
|
||||
#for k, v in context.pgenv.items():
|
||||
# for k, v in context.pgenv.items():
|
||||
# if k in os.environ and v is None:
|
||||
# del os.environ[k]
|
||||
# elif v:
|
||||
|
@ -123,8 +97,8 @@ def before_step(context, _):
|
|||
|
||||
|
||||
def before_scenario(context, arg):
|
||||
with open(test_log_file, 'w') as f:
|
||||
f.write('')
|
||||
with open(test_log_file, "w") as f:
|
||||
f.write("")
|
||||
if arg.location.filename not in SELF_CONNECTING_FEATURES:
|
||||
run_cli(context)
|
||||
wait_prompt(context)
|
||||
|
@ -140,23 +114,18 @@ def after_scenario(context, _):
|
|||
"""Cleans up after each test complete."""
|
||||
with open(test_log_file) as f:
|
||||
for line in f:
|
||||
if 'error' in line.lower():
|
||||
raise RuntimeError(f'Error in log file: {line}')
|
||||
if "error" in line.lower():
|
||||
raise RuntimeError(f"Error in log file: {line}")
|
||||
|
||||
if hasattr(context, 'cli') and not context.exit_sent:
|
||||
if hasattr(context, "cli") and not context.exit_sent:
|
||||
# Quit nicely.
|
||||
if not context.atprompt:
|
||||
user = context.conf['user']
|
||||
host = context.conf['host']
|
||||
user = context.conf["user"]
|
||||
host = context.conf["host"]
|
||||
dbname = context.currentdb
|
||||
context.cli.expect_exact(
|
||||
'{0}@{1}:{2}>'.format(
|
||||
user, host, dbname
|
||||
),
|
||||
timeout=5
|
||||
)
|
||||
context.cli.sendcontrol('c')
|
||||
context.cli.sendcontrol('d')
|
||||
context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5)
|
||||
context.cli.sendcontrol("c")
|
||||
context.cli.sendcontrol("d")
|
||||
context.cli.expect_exact(pexpect.EOF, timeout=5)
|
||||
|
||||
if os.path.exists(MY_CNF_BACKUP_PATH):
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import io
|
||||
|
||||
|
||||
def read_fixture_lines(filename):
|
||||
|
@ -20,9 +19,9 @@ def read_fixture_files():
|
|||
fixture_dict = {}
|
||||
|
||||
current_dir = os.path.dirname(__file__)
|
||||
fixture_dir = os.path.join(current_dir, 'fixture_data/')
|
||||
fixture_dir = os.path.join(current_dir, "fixture_data/")
|
||||
for filename in os.listdir(fixture_dir):
|
||||
if filename not in ['.', '..']:
|
||||
if filename not in [".", ".."]:
|
||||
fullname = os.path.join(fixture_dir, filename)
|
||||
fixture_dict[filename] = read_fixture_lines(fullname)
|
||||
|
||||
|
|
|
@ -6,41 +6,42 @@ import wrappers
|
|||
from utils import parse_cli_args_to_dict
|
||||
|
||||
|
||||
@when('we run dbcli with {arg}')
|
||||
@when("we run dbcli with {arg}")
|
||||
def step_run_cli_with_arg(context, arg):
|
||||
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
|
||||
|
||||
|
||||
@when('we execute a small query')
|
||||
@when("we execute a small query")
|
||||
def step_execute_small_query(context):
|
||||
context.cli.sendline('select 1')
|
||||
context.cli.sendline("select 1")
|
||||
|
||||
|
||||
@when('we execute a large query')
|
||||
@when("we execute a large query")
|
||||
def step_execute_large_query(context):
|
||||
context.cli.sendline(
|
||||
'select {}'.format(','.join([str(n) for n in range(1, 50)])))
|
||||
context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
|
||||
|
||||
|
||||
@then('we see small results in horizontal format')
|
||||
@then("we see small results in horizontal format")
|
||||
def step_see_small_results(context):
|
||||
wrappers.expect_pager(context, dedent("""\
|
||||
wrappers.expect_pager(
|
||||
context,
|
||||
dedent("""\
|
||||
+---+\r
|
||||
| 1 |\r
|
||||
+---+\r
|
||||
| 1 |\r
|
||||
+---+\r
|
||||
\r
|
||||
"""), timeout=5)
|
||||
wrappers.expect_exact(context, '1 row in set', timeout=2)
|
||||
"""),
|
||||
timeout=5,
|
||||
)
|
||||
wrappers.expect_exact(context, "1 row in set", timeout=2)
|
||||
|
||||
|
||||
@then('we see large results in vertical format')
|
||||
@then("we see large results in vertical format")
|
||||
def step_see_large_results(context):
|
||||
rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
|
||||
expected = ('***************************[ 1. row ]'
|
||||
'***************************\r\n' +
|
||||
'{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
|
||||
rows = ["{n:3}| {n}".format(n=str(n)) for n in range(1, 50)]
|
||||
expected = "***************************[ 1. row ]" "***************************\r\n" + "{}\r\n".format("\r\n".join(rows) + "\r\n")
|
||||
|
||||
wrappers.expect_pager(context, expected, timeout=10)
|
||||
wrappers.expect_exact(context, '1 row in set', timeout=2)
|
||||
wrappers.expect_exact(context, "1 row in set", timeout=2)
|
||||
|
|
|
@ -5,18 +5,18 @@ to call the step in "*.feature" file.
|
|||
|
||||
"""
|
||||
|
||||
from behave import when
|
||||
from behave import when, then
|
||||
from textwrap import dedent
|
||||
import tempfile
|
||||
import wrappers
|
||||
|
||||
|
||||
@when('we run dbcli')
|
||||
@when("we run dbcli")
|
||||
def step_run_cli(context):
|
||||
wrappers.run_cli(context)
|
||||
|
||||
|
||||
@when('we wait for prompt')
|
||||
@when("we wait for prompt")
|
||||
def step_wait_prompt(context):
|
||||
wrappers.wait_prompt(context)
|
||||
|
||||
|
@ -24,77 +24,75 @@ def step_wait_prompt(context):
|
|||
@when('we send "ctrl + d"')
|
||||
def step_ctrl_d(context):
|
||||
"""Send Ctrl + D to hopefully exit."""
|
||||
context.cli.sendcontrol('d')
|
||||
context.cli.sendcontrol("d")
|
||||
context.exit_sent = True
|
||||
|
||||
|
||||
@when('we send "\?" command')
|
||||
@when(r'we send "\?" command')
|
||||
def step_send_help(context):
|
||||
"""Send \?
|
||||
r"""Send \?
|
||||
|
||||
to see help.
|
||||
|
||||
"""
|
||||
context.cli.sendline('\\?')
|
||||
wrappers.expect_exact(
|
||||
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
|
||||
context.cli.sendline("\\?")
|
||||
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
||||
|
||||
|
||||
@when(u'we send source command')
|
||||
@when("we send source command")
|
||||
def step_send_source_command(context):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.write(b'\?')
|
||||
f.write(b"\\?")
|
||||
f.flush()
|
||||
context.cli.sendline('\. {0}'.format(f.name))
|
||||
wrappers.expect_exact(
|
||||
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
|
||||
context.cli.sendline("\\. {0}".format(f.name))
|
||||
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
||||
|
||||
|
||||
@when(u'we run query to check application_name')
|
||||
@when("we run query to check application_name")
|
||||
def step_check_application_name(context):
|
||||
context.cli.sendline(
|
||||
"SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
|
||||
)
|
||||
|
||||
|
||||
@then(u'we see found')
|
||||
@then("we see found")
|
||||
def step_see_found(context):
|
||||
wrappers.expect_exact(
|
||||
context,
|
||||
context.conf['pager_boundary'] + '\r' + dedent('''
|
||||
context.conf["pager_boundary"]
|
||||
+ "\r"
|
||||
+ dedent("""
|
||||
+-------+\r
|
||||
| found |\r
|
||||
+-------+\r
|
||||
| found |\r
|
||||
+-------+\r
|
||||
\r
|
||||
''') + context.conf['pager_boundary'],
|
||||
timeout=5
|
||||
""")
|
||||
+ context.conf["pager_boundary"],
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
|
||||
@then(u'we confirm the destructive warning')
|
||||
def step_confirm_destructive_command(context):
|
||||
@then("we confirm the destructive warning")
|
||||
def step_confirm_destructive_command(context): # noqa
|
||||
"""Confirm destructive command."""
|
||||
wrappers.expect_exact(
|
||||
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
|
||||
context.cli.sendline('y')
|
||||
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
|
||||
context.cli.sendline("y")
|
||||
|
||||
|
||||
@when(u'we answer the destructive warning with "{confirmation}"')
|
||||
def step_confirm_destructive_command(context, confirmation):
|
||||
@when('we answer the destructive warning with "{confirmation}"')
|
||||
def step_confirm_destructive_command(context, confirmation): # noqa
|
||||
"""Confirm destructive command."""
|
||||
wrappers.expect_exact(
|
||||
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
|
||||
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
|
||||
context.cli.sendline(confirmation)
|
||||
|
||||
|
||||
@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
|
||||
def step_confirm_destructive_command(context, confirmation, text):
|
||||
@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
|
||||
def step_confirm_destructive_command(context, confirmation, text): # noqa
|
||||
"""Confirm destructive command."""
|
||||
wrappers.expect_exact(
|
||||
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
|
||||
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
|
||||
context.cli.sendline(confirmation)
|
||||
wrappers.expect_exact(context, text, timeout=2)
|
||||
# we must exit the Click loop, or the feature will hang
|
||||
context.cli.sendline('n')
|
||||
context.cli.sendline("n")
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import io
|
||||
import os
|
||||
import shlex
|
||||
|
||||
from behave import when, then
|
||||
import pexpect
|
||||
|
||||
import wrappers
|
||||
from test.features.steps.utils import parse_cli_args_to_dict
|
||||
|
@ -12,60 +10,44 @@ from test.utils import HOST, PORT, USER, PASSWORD
|
|||
from mycli.config import encrypt_mylogin_cnf
|
||||
|
||||
|
||||
TEST_LOGIN_PATH = 'test_login_path'
|
||||
TEST_LOGIN_PATH = "test_login_path"
|
||||
|
||||
|
||||
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
|
||||
@when('we run mycli without arguments "{excluded_args}"')
|
||||
def step_run_cli_without_args(context, excluded_args, exact_args=''):
|
||||
wrappers.run_cli(
|
||||
context,
|
||||
run_args=parse_cli_args_to_dict(exact_args),
|
||||
exclude_args=parse_cli_args_to_dict(excluded_args).keys()
|
||||
)
|
||||
def step_run_cli_without_args(context, excluded_args, exact_args=""):
|
||||
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys())
|
||||
|
||||
|
||||
@then('status contains "{expression}"')
|
||||
def status_contains(context, expression):
|
||||
wrappers.expect_exact(context, f'{expression}', timeout=5)
|
||||
wrappers.expect_exact(context, f"{expression}", timeout=5)
|
||||
|
||||
# Normally, the shutdown after scenario waits for the prompt.
|
||||
# But we may have changed the prompt, depending on parameters,
|
||||
# so let's wait for its last character
|
||||
context.cli.expect_exact('>')
|
||||
context.cli.expect_exact(">")
|
||||
context.atprompt = True
|
||||
|
||||
|
||||
@when('we create my.cnf file')
|
||||
@when("we create my.cnf file")
|
||||
def step_create_my_cnf_file(context):
|
||||
my_cnf = (
|
||||
'[client]\n'
|
||||
f'host = {HOST}\n'
|
||||
f'port = {PORT}\n'
|
||||
f'user = {USER}\n'
|
||||
f'password = {PASSWORD}\n'
|
||||
)
|
||||
with open(MY_CNF_PATH, 'w') as f:
|
||||
my_cnf = "[client]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
|
||||
with open(MY_CNF_PATH, "w") as f:
|
||||
f.write(my_cnf)
|
||||
|
||||
|
||||
@when('we create mylogin.cnf file')
|
||||
@when("we create mylogin.cnf file")
|
||||
def step_create_mylogin_cnf_file(context):
|
||||
os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
|
||||
mylogin_cnf = (
|
||||
f'[{TEST_LOGIN_PATH}]\n'
|
||||
f'host = {HOST}\n'
|
||||
f'port = {PORT}\n'
|
||||
f'user = {USER}\n'
|
||||
f'password = {PASSWORD}\n'
|
||||
)
|
||||
with open(MYLOGIN_CNF_PATH, 'wb') as f:
|
||||
os.environ.pop("MYSQL_TEST_LOGIN_FILE", None)
|
||||
mylogin_cnf = f"[{TEST_LOGIN_PATH}]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
|
||||
with open(MYLOGIN_CNF_PATH, "wb") as f:
|
||||
input_file = io.StringIO(mylogin_cnf)
|
||||
f.write(encrypt_mylogin_cnf(input_file).read())
|
||||
|
||||
|
||||
@then('we are logged in')
|
||||
@then("we are logged in")
|
||||
def we_are_logged_in(context):
|
||||
db_name = get_db_name_from_context(context)
|
||||
context.cli.expect_exact(f'{db_name}>', timeout=5)
|
||||
context.cli.expect_exact(f"{db_name}>", timeout=5)
|
||||
context.atprompt = True
|
||||
|
|
|
@ -11,105 +11,99 @@ import wrappers
|
|||
from behave import when, then
|
||||
|
||||
|
||||
@when('we create database')
|
||||
@when("we create database")
|
||||
def step_db_create(context):
|
||||
"""Send create database."""
|
||||
context.cli.sendline('create database {0};'.format(
|
||||
context.conf['dbname_tmp']))
|
||||
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
|
||||
|
||||
context.response = {
|
||||
'database_name': context.conf['dbname_tmp']
|
||||
}
|
||||
context.response = {"database_name": context.conf["dbname_tmp"]}
|
||||
|
||||
|
||||
@when('we drop database')
|
||||
@when("we drop database")
|
||||
def step_db_drop(context):
|
||||
"""Send drop database."""
|
||||
context.cli.sendline('drop database {0};'.format(
|
||||
context.conf['dbname_tmp']))
|
||||
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
|
||||
|
||||
|
||||
@when('we connect to test database')
|
||||
@when("we connect to test database")
|
||||
def step_db_connect_test(context):
|
||||
"""Send connect to database."""
|
||||
db_name = context.conf['dbname']
|
||||
db_name = context.conf["dbname"]
|
||||
context.currentdb = db_name
|
||||
context.cli.sendline('use {0};'.format(db_name))
|
||||
context.cli.sendline("use {0};".format(db_name))
|
||||
|
||||
|
||||
@when('we connect to quoted test database')
|
||||
@when("we connect to quoted test database")
|
||||
def step_db_connect_quoted_tmp(context):
|
||||
"""Send connect to database."""
|
||||
db_name = context.conf['dbname']
|
||||
db_name = context.conf["dbname"]
|
||||
context.currentdb = db_name
|
||||
context.cli.sendline('use `{0}`;'.format(db_name))
|
||||
context.cli.sendline("use `{0}`;".format(db_name))
|
||||
|
||||
|
||||
@when('we connect to tmp database')
|
||||
@when("we connect to tmp database")
|
||||
def step_db_connect_tmp(context):
|
||||
"""Send connect to database."""
|
||||
db_name = context.conf['dbname_tmp']
|
||||
db_name = context.conf["dbname_tmp"]
|
||||
context.currentdb = db_name
|
||||
context.cli.sendline('use {0}'.format(db_name))
|
||||
context.cli.sendline("use {0}".format(db_name))
|
||||
|
||||
|
||||
@when('we connect to dbserver')
|
||||
@when("we connect to dbserver")
|
||||
def step_db_connect_dbserver(context):
|
||||
"""Send connect to database."""
|
||||
context.currentdb = 'mysql'
|
||||
context.cli.sendline('use mysql')
|
||||
context.currentdb = "mysql"
|
||||
context.cli.sendline("use mysql")
|
||||
|
||||
|
||||
@then('dbcli exits')
|
||||
@then("dbcli exits")
|
||||
def step_wait_exit(context):
|
||||
"""Make sure the cli exits."""
|
||||
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
|
||||
|
||||
|
||||
@then('we see dbcli prompt')
|
||||
@then("we see dbcli prompt")
|
||||
def step_see_prompt(context):
|
||||
"""Wait to see the prompt."""
|
||||
user = context.conf['user']
|
||||
host = context.conf['host']
|
||||
user = context.conf["user"]
|
||||
host = context.conf["host"]
|
||||
dbname = context.currentdb
|
||||
wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
|
||||
wrappers.wait_prompt(context, "{0}@{1}:{2}> ".format(user, host, dbname))
|
||||
|
||||
|
||||
@then('we see help output')
|
||||
@then("we see help output")
|
||||
def step_see_help(context):
|
||||
for expected_line in context.fixture_data['help_commands.txt']:
|
||||
for expected_line in context.fixture_data["help_commands.txt"]:
|
||||
wrappers.expect_exact(context, expected_line, timeout=1)
|
||||
|
||||
|
||||
@then('we see database created')
|
||||
@then("we see database created")
|
||||
def step_see_db_created(context):
|
||||
"""Wait to see create database output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
|
||||
|
||||
|
||||
@then('we see database dropped')
|
||||
@then("we see database dropped")
|
||||
def step_see_db_dropped(context):
|
||||
"""Wait to see drop database output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
|
||||
|
||||
|
||||
@then('we see database dropped and no default database')
|
||||
@then("we see database dropped and no default database")
|
||||
def step_see_db_dropped_no_default(context):
|
||||
"""Wait to see drop database output."""
|
||||
user = context.conf['user']
|
||||
host = context.conf['host']
|
||||
database = '(none)'
|
||||
user = context.conf["user"]
|
||||
host = context.conf["host"]
|
||||
database = "(none)"
|
||||
context.currentdb = None
|
||||
|
||||
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
|
||||
wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
|
||||
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
|
||||
wrappers.wait_prompt(context, "{0}@{1}:{2}>".format(user, host, database))
|
||||
|
||||
|
||||
@then('we see database connected')
|
||||
@then("we see database connected")
|
||||
def step_see_db_connected(context):
|
||||
"""Wait to see drop database output."""
|
||||
wrappers.expect_exact(
|
||||
context, 'You are now connected to database "', timeout=2)
|
||||
wrappers.expect_exact(context, 'You are now connected to database "', timeout=2)
|
||||
wrappers.expect_exact(context, '"', timeout=2)
|
||||
wrappers.expect_exact(context, ' as user "{0}"'.format(
|
||||
context.conf['user']), timeout=2)
|
||||
wrappers.expect_exact(context, ' as user "{0}"'.format(context.conf["user"]), timeout=2)
|
||||
|
|
|
@ -10,103 +10,109 @@ from behave import when, then
|
|||
from textwrap import dedent
|
||||
|
||||
|
||||
@when('we create table')
|
||||
@when("we create table")
|
||||
def step_create_table(context):
|
||||
"""Send create table."""
|
||||
context.cli.sendline('create table a(x text);')
|
||||
context.cli.sendline("create table a(x text);")
|
||||
|
||||
|
||||
@when('we insert into table')
|
||||
@when("we insert into table")
|
||||
def step_insert_into_table(context):
|
||||
"""Send insert into table."""
|
||||
context.cli.sendline('''insert into a(x) values('xxx');''')
|
||||
context.cli.sendline("""insert into a(x) values('xxx');""")
|
||||
|
||||
|
||||
@when('we update table')
|
||||
@when("we update table")
|
||||
def step_update_table(context):
|
||||
"""Send insert into table."""
|
||||
context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
|
||||
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
|
||||
|
||||
|
||||
@when('we select from table')
|
||||
@when("we select from table")
|
||||
def step_select_from_table(context):
|
||||
"""Send select from table."""
|
||||
context.cli.sendline('select * from a;')
|
||||
context.cli.sendline("select * from a;")
|
||||
|
||||
|
||||
@when('we delete from table')
|
||||
@when("we delete from table")
|
||||
def step_delete_from_table(context):
|
||||
"""Send deete from table."""
|
||||
context.cli.sendline('''delete from a where x = 'yyy';''')
|
||||
context.cli.sendline("""delete from a where x = 'yyy';""")
|
||||
|
||||
|
||||
@when('we drop table')
|
||||
@when("we drop table")
|
||||
def step_drop_table(context):
|
||||
"""Send drop table."""
|
||||
context.cli.sendline('drop table a;')
|
||||
context.cli.sendline("drop table a;")
|
||||
|
||||
|
||||
@then('we see table created')
|
||||
@then("we see table created")
|
||||
def step_see_table_created(context):
|
||||
"""Wait to see create table output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
|
||||
|
||||
|
||||
@then('we see record inserted')
|
||||
@then("we see record inserted")
|
||||
def step_see_record_inserted(context):
|
||||
"""Wait to see insert output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
|
||||
|
||||
|
||||
@then('we see record updated')
|
||||
@then("we see record updated")
|
||||
def step_see_record_updated(context):
|
||||
"""Wait to see update output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
|
||||
|
||||
|
||||
@then('we see data selected')
|
||||
@then("we see data selected")
|
||||
def step_see_data_selected(context):
|
||||
"""Wait to see select output."""
|
||||
wrappers.expect_pager(
|
||||
context, dedent("""\
|
||||
context,
|
||||
dedent("""\
|
||||
+-----+\r
|
||||
| x |\r
|
||||
+-----+\r
|
||||
| yyy |\r
|
||||
+-----+\r
|
||||
\r
|
||||
"""), timeout=2)
|
||||
wrappers.expect_exact(context, '1 row in set', timeout=2)
|
||||
"""),
|
||||
timeout=2,
|
||||
)
|
||||
wrappers.expect_exact(context, "1 row in set", timeout=2)
|
||||
|
||||
|
||||
@then('we see record deleted')
|
||||
@then("we see record deleted")
|
||||
def step_see_data_deleted(context):
|
||||
"""Wait to see delete output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
|
||||
|
||||
|
||||
@then('we see table dropped')
|
||||
@then("we see table dropped")
|
||||
def step_see_table_dropped(context):
|
||||
"""Wait to see drop output."""
|
||||
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
|
||||
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
|
||||
|
||||
|
||||
@when('we select null')
|
||||
@when("we select null")
|
||||
def step_select_null(context):
|
||||
"""Send select null."""
|
||||
context.cli.sendline('select null;')
|
||||
context.cli.sendline("select null;")
|
||||
|
||||
|
||||
@then('we see null selected')
|
||||
@then("we see null selected")
|
||||
def step_see_null_selected(context):
|
||||
"""Wait to see null output."""
|
||||
wrappers.expect_pager(
|
||||
context, dedent("""\
|
||||
context,
|
||||
dedent("""\
|
||||
+--------+\r
|
||||
| NULL |\r
|
||||
+--------+\r
|
||||
| <null> |\r
|
||||
+--------+\r
|
||||
\r
|
||||
"""), timeout=2)
|
||||
wrappers.expect_exact(context, '1 row in set', timeout=2)
|
||||
"""),
|
||||
timeout=2,
|
||||
)
|
||||
wrappers.expect_exact(context, "1 row in set", timeout=2)
|
||||
|
|
|
@ -5,101 +5,93 @@ from behave import when, then
|
|||
from textwrap import dedent
|
||||
|
||||
|
||||
@when('we start external editor providing a file name')
|
||||
@when("we start external editor providing a file name")
|
||||
def step_edit_file(context):
|
||||
"""Edit file with external editor."""
|
||||
context.editor_file_name = os.path.join(
|
||||
context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
|
||||
context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"]))
|
||||
if os.path.exists(context.editor_file_name):
|
||||
os.remove(context.editor_file_name)
|
||||
context.cli.sendline('\e {0}'.format(
|
||||
os.path.basename(context.editor_file_name)))
|
||||
wrappers.expect_exact(
|
||||
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
|
||||
wrappers.expect_exact(context, '\r\n:', timeout=2)
|
||||
context.cli.sendline("\\e {0}".format(os.path.basename(context.editor_file_name)))
|
||||
wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
|
||||
wrappers.expect_exact(context, "\r\n:", timeout=2)
|
||||
|
||||
|
||||
@when('we type "{query}" in the editor')
|
||||
def step_edit_type_sql(context, query):
|
||||
context.cli.sendline('i')
|
||||
context.cli.sendline("i")
|
||||
context.cli.sendline(query)
|
||||
context.cli.sendline('.')
|
||||
wrappers.expect_exact(context, '\r\n:', timeout=2)
|
||||
context.cli.sendline(".")
|
||||
wrappers.expect_exact(context, "\r\n:", timeout=2)
|
||||
|
||||
|
||||
@when('we exit the editor')
|
||||
@when("we exit the editor")
|
||||
def step_edit_quit(context):
|
||||
context.cli.sendline('x')
|
||||
context.cli.sendline("x")
|
||||
wrappers.expect_exact(context, "written", timeout=2)
|
||||
|
||||
|
||||
@then('we see "{query}" in prompt')
|
||||
def step_edit_done_sql(context, query):
|
||||
for match in query.split(' '):
|
||||
for match in query.split(" "):
|
||||
wrappers.expect_exact(context, match, timeout=5)
|
||||
# Cleanup the command line.
|
||||
context.cli.sendcontrol('c')
|
||||
context.cli.sendcontrol("c")
|
||||
# Cleanup the edited file.
|
||||
if context.editor_file_name and os.path.exists(context.editor_file_name):
|
||||
os.remove(context.editor_file_name)
|
||||
|
||||
|
||||
@when(u'we tee output')
|
||||
@when("we tee output")
|
||||
def step_tee_ouptut(context):
|
||||
context.tee_file_name = os.path.join(
|
||||
context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
|
||||
context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]))
|
||||
if os.path.exists(context.tee_file_name):
|
||||
os.remove(context.tee_file_name)
|
||||
context.cli.sendline('tee {0}'.format(
|
||||
os.path.basename(context.tee_file_name)))
|
||||
context.cli.sendline("tee {0}".format(os.path.basename(context.tee_file_name)))
|
||||
|
||||
|
||||
@when(u'we select "select {param}"')
|
||||
@when('we select "select {param}"')
|
||||
def step_query_select_number(context, param):
|
||||
context.cli.sendline(u'select {}'.format(param))
|
||||
wrappers.expect_pager(context, dedent(u"""\
|
||||
context.cli.sendline("select {}".format(param))
|
||||
wrappers.expect_pager(
|
||||
context,
|
||||
dedent(
|
||||
"""\
|
||||
+{dashes}+\r
|
||||
| {param} |\r
|
||||
+{dashes}+\r
|
||||
| {param} |\r
|
||||
+{dashes}+\r
|
||||
\r
|
||||
""".format(param=param, dashes='-' * (len(param) + 2))
|
||||
), timeout=5)
|
||||
wrappers.expect_exact(context, '1 row in set', timeout=2)
|
||||
|
||||
|
||||
@then(u'we see result "{result}"')
|
||||
def step_see_result(context, result):
|
||||
wrappers.expect_exact(
|
||||
context,
|
||||
u"| {} |".format(result),
|
||||
timeout=2
|
||||
""".format(param=param, dashes="-" * (len(param) + 2))
|
||||
),
|
||||
timeout=5,
|
||||
)
|
||||
wrappers.expect_exact(context, "1 row in set", timeout=2)
|
||||
|
||||
|
||||
@when(u'we query "{query}"')
|
||||
@then('we see result "{result}"')
|
||||
def step_see_result(context, result):
|
||||
wrappers.expect_exact(context, "| {} |".format(result), timeout=2)
|
||||
|
||||
|
||||
@when('we query "{query}"')
|
||||
def step_query(context, query):
|
||||
context.cli.sendline(query)
|
||||
|
||||
|
||||
@when(u'we notee output')
|
||||
@when("we notee output")
|
||||
def step_notee_output(context):
|
||||
context.cli.sendline('notee')
|
||||
context.cli.sendline("notee")
|
||||
|
||||
|
||||
@then(u'we see 123456 in tee output')
|
||||
@then("we see 123456 in tee output")
|
||||
def step_see_123456_in_ouput(context):
|
||||
with open(context.tee_file_name) as f:
|
||||
assert '123456' in f.read()
|
||||
assert "123456" in f.read()
|
||||
if os.path.exists(context.tee_file_name):
|
||||
os.remove(context.tee_file_name)
|
||||
|
||||
|
||||
@then(u'delimiter is set to "{delimiter}"')
|
||||
@then('delimiter is set to "{delimiter}"')
|
||||
def delimiter_is_set(context, delimiter):
|
||||
wrappers.expect_exact(
|
||||
context,
|
||||
u'Changed delimiter to {}'.format(delimiter),
|
||||
timeout=2
|
||||
)
|
||||
wrappers.expect_exact(context, "Changed delimiter to {}".format(delimiter), timeout=2)
|
||||
|
|
|
@ -9,82 +9,79 @@ import wrappers
|
|||
from behave import when, then
|
||||
|
||||
|
||||
@when('we save a named query')
|
||||
@when("we save a named query")
|
||||
def step_save_named_query(context):
|
||||
"""Send \fs command."""
|
||||
context.cli.sendline('\\fs foo SELECT 12345')
|
||||
context.cli.sendline("\\fs foo SELECT 12345")
|
||||
|
||||
|
||||
@when('we use a named query')
|
||||
@when("we use a named query")
|
||||
def step_use_named_query(context):
|
||||
"""Send \f command."""
|
||||
context.cli.sendline('\\f foo')
|
||||
context.cli.sendline("\\f foo")
|
||||
|
||||
|
||||
@when('we delete a named query')
|
||||
@when("we delete a named query")
|
||||
def step_delete_named_query(context):
|
||||
"""Send \fd command."""
|
||||
context.cli.sendline('\\fd foo')
|
||||
context.cli.sendline("\\fd foo")
|
||||
|
||||
|
||||
@then('we see the named query saved')
|
||||
@then("we see the named query saved")
|
||||
def step_see_named_query_saved(context):
|
||||
"""Wait to see query saved."""
|
||||
wrappers.expect_exact(context, 'Saved.', timeout=2)
|
||||
wrappers.expect_exact(context, "Saved.", timeout=2)
|
||||
|
||||
|
||||
@then('we see the named query executed')
|
||||
@then("we see the named query executed")
|
||||
def step_see_named_query_executed(context):
|
||||
"""Wait to see select output."""
|
||||
wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
|
||||
wrappers.expect_exact(context, "SELECT 12345", timeout=2)
|
||||
|
||||
|
||||
@then('we see the named query deleted')
|
||||
@then("we see the named query deleted")
|
||||
def step_see_named_query_deleted(context):
|
||||
"""Wait to see query deleted."""
|
||||
wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
|
||||
wrappers.expect_exact(context, "foo: Deleted", timeout=2)
|
||||
|
||||
|
||||
@when('we save a named query with parameters')
|
||||
@when("we save a named query with parameters")
|
||||
def step_save_named_query_with_parameters(context):
|
||||
"""Send \fs command for query with parameters."""
|
||||
context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
|
||||
|
||||
|
||||
@when('we use named query with parameters')
|
||||
@when("we use named query with parameters")
|
||||
def step_use_named_query_with_parameters(context):
|
||||
"""Send \f command with parameters."""
|
||||
context.cli.sendline('\\f foo_args 101 second "third value"')
|
||||
|
||||
|
||||
@then('we see the named query with parameters executed')
|
||||
@then("we see the named query with parameters executed")
|
||||
def step_see_named_query_with_parameters_executed(context):
|
||||
"""Wait to see select output."""
|
||||
wrappers.expect_exact(
|
||||
context, 'SELECT 101, "second", "third value"', timeout=2)
|
||||
wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2)
|
||||
|
||||
|
||||
@when('we use named query with too few parameters')
|
||||
@when("we use named query with too few parameters")
|
||||
def step_use_named_query_with_too_few_parameters(context):
|
||||
"""Send \f command with missing parameters."""
|
||||
context.cli.sendline('\\f foo_args 101')
|
||||
context.cli.sendline("\\f foo_args 101")
|
||||
|
||||
|
||||
@then('we see the named query with parameters fail with missing parameters')
|
||||
@then("we see the named query with parameters fail with missing parameters")
|
||||
def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
|
||||
"""Wait to see select output."""
|
||||
wrappers.expect_exact(
|
||||
context, 'missing substitution for $2 in query:', timeout=2)
|
||||
wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2)
|
||||
|
||||
|
||||
@when('we use named query with too many parameters')
|
||||
@when("we use named query with too many parameters")
|
||||
def step_use_named_query_with_too_many_parameters(context):
|
||||
"""Send \f command with extra parameters."""
|
||||
context.cli.sendline('\\f foo_args 101 102 103 104')
|
||||
context.cli.sendline("\\f foo_args 101 102 103 104")
|
||||
|
||||
|
||||
@then('we see the named query with parameters fail with extra parameters')
|
||||
@then("we see the named query with parameters fail with extra parameters")
|
||||
def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
|
||||
"""Wait to see select output."""
|
||||
wrappers.expect_exact(
|
||||
context, 'query does not have substitution parameter $4:', timeout=2)
|
||||
wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2)
|
||||
|
|
|
@ -9,10 +9,10 @@ import wrappers
|
|||
from behave import when, then
|
||||
|
||||
|
||||
@when('we refresh completions')
|
||||
@when("we refresh completions")
|
||||
def step_refresh_completions(context):
|
||||
"""Send refresh command."""
|
||||
context.cli.sendline('rehash')
|
||||
context.cli.sendline("rehash")
|
||||
|
||||
|
||||
@then('we see text "{text}"')
|
||||
|
@ -20,8 +20,8 @@ def step_see_text(context, text):
|
|||
"""Wait to see given text message."""
|
||||
wrappers.expect_exact(context, text, timeout=2)
|
||||
|
||||
@then('we see completions refresh started')
|
||||
|
||||
@then("we see completions refresh started")
|
||||
def step_see_refresh_started(context):
|
||||
"""Wait to see refresh output."""
|
||||
wrappers.expect_exact(
|
||||
context, 'Auto-completion refresh started in the background.', timeout=2)
|
||||
wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2)
|
||||
|
|
|
@ -4,8 +4,8 @@ import shlex
|
|||
def parse_cli_args_to_dict(cli_args: str):
|
||||
args_dict = {}
|
||||
for arg in shlex.split(cli_args):
|
||||
if '=' in arg:
|
||||
key, value = arg.split('=')
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=")
|
||||
args_dict[key] = value
|
||||
else:
|
||||
args_dict[arg] = None
|
||||
|
|
|
@ -18,10 +18,9 @@ def expect_exact(context, expected, timeout):
|
|||
timedout = True
|
||||
if timedout:
|
||||
# Strip color codes out of the output.
|
||||
actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
|
||||
'', context.cli.before)
|
||||
actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
|
||||
raise Exception(
|
||||
textwrap.dedent('''\
|
||||
textwrap.dedent("""\
|
||||
Expected:
|
||||
---
|
||||
{0!r}
|
||||
|
@ -34,17 +33,12 @@ def expect_exact(context, expected, timeout):
|
|||
---
|
||||
{2!r}
|
||||
---
|
||||
''').format(
|
||||
expected,
|
||||
actual,
|
||||
context.logfile.getvalue()
|
||||
)
|
||||
""").format(expected, actual, context.logfile.getvalue())
|
||||
)
|
||||
|
||||
|
||||
def expect_pager(context, expected, timeout):
|
||||
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
|
||||
context.conf['pager_boundary'], expected), timeout=timeout)
|
||||
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), timeout=timeout)
|
||||
|
||||
|
||||
def run_cli(context, run_args=None, exclude_args=None):
|
||||
|
@ -63,55 +57,49 @@ def run_cli(context, run_args=None, exclude_args=None):
|
|||
else:
|
||||
rendered_args.append(key)
|
||||
|
||||
if conf.get('host', None):
|
||||
add_arg('host', '-h', conf['host'])
|
||||
if conf.get('user', None):
|
||||
add_arg('user', '-u', conf['user'])
|
||||
if conf.get('pass', None):
|
||||
add_arg('pass', '-p', conf['pass'])
|
||||
if conf.get('port', None):
|
||||
add_arg('port', '-P', str(conf['port']))
|
||||
if conf.get('dbname', None):
|
||||
add_arg('dbname', '-D', conf['dbname'])
|
||||
if conf.get('defaults-file', None):
|
||||
add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
|
||||
if conf.get('myclirc', None):
|
||||
add_arg('myclirc', '--myclirc', conf['myclirc'])
|
||||
if conf.get('login_path'):
|
||||
add_arg('login_path', '--login-path', conf['login_path'])
|
||||
if conf.get("host", None):
|
||||
add_arg("host", "-h", conf["host"])
|
||||
if conf.get("user", None):
|
||||
add_arg("user", "-u", conf["user"])
|
||||
if conf.get("pass", None):
|
||||
add_arg("pass", "-p", conf["pass"])
|
||||
if conf.get("port", None):
|
||||
add_arg("port", "-P", str(conf["port"]))
|
||||
if conf.get("dbname", None):
|
||||
add_arg("dbname", "-D", conf["dbname"])
|
||||
if conf.get("defaults-file", None):
|
||||
add_arg("defaults_file", "--defaults-file", conf["defaults-file"])
|
||||
if conf.get("myclirc", None):
|
||||
add_arg("myclirc", "--myclirc", conf["myclirc"])
|
||||
if conf.get("login_path"):
|
||||
add_arg("login_path", "--login-path", conf["login_path"])
|
||||
|
||||
for arg_name, arg_value in conf.items():
|
||||
if arg_name.startswith('-'):
|
||||
if arg_name.startswith("-"):
|
||||
add_arg(arg_name, arg_name, arg_value)
|
||||
|
||||
try:
|
||||
cli_cmd = context.conf['cli_command']
|
||||
cli_cmd = context.conf["cli_command"]
|
||||
except KeyError:
|
||||
cli_cmd = (
|
||||
'{0!s} -c "'
|
||||
'import coverage ; '
|
||||
'coverage.process_startup(); '
|
||||
'import mycli.main; '
|
||||
'mycli.main.cli()'
|
||||
'"'
|
||||
).format(sys.executable)
|
||||
cli_cmd = ('{0!s} -c "' "import coverage ; " "coverage.process_startup(); " "import mycli.main; " "mycli.main.cli()" '"').format(
|
||||
sys.executable
|
||||
)
|
||||
|
||||
cmd_parts = [cli_cmd] + rendered_args
|
||||
cmd = ' '.join(cmd_parts)
|
||||
cmd = " ".join(cmd_parts)
|
||||
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
|
||||
context.logfile = StringIO()
|
||||
context.cli.logfile = context.logfile
|
||||
context.exit_sent = False
|
||||
context.currentdb = context.conf['dbname']
|
||||
context.currentdb = context.conf["dbname"]
|
||||
|
||||
|
||||
def wait_prompt(context, prompt=None):
|
||||
"""Make sure prompt is displayed."""
|
||||
if prompt is None:
|
||||
user = context.conf['user']
|
||||
host = context.conf['host']
|
||||
user = context.conf["user"]
|
||||
host = context.conf["host"]
|
||||
dbname = context.currentdb
|
||||
prompt = '{0}@{1}:{2}>'.format(
|
||||
user, host, dbname),
|
||||
prompt = ("{0}@{1}:{2}>".format(user, host, dbname),)
|
||||
expect_exact(context, prompt, timeout=5)
|
||||
context.atprompt = True
|
||||
|
|
|
@ -153,6 +153,7 @@ output.null = "#808080"
|
|||
# Favorite queries.
|
||||
[favorite_queries]
|
||||
check = 'select "✔"'
|
||||
foo_args = 'SELECT $1, "$2", "$3"'
|
||||
|
||||
# Use the -d option to reference a DSN.
|
||||
# Special characters in passwords and other strings can be escaped with URL encoding.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Test the mycli.clistyle module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from pygments.style import Style
|
||||
|
@ -10,9 +11,9 @@ from mycli.clistyle import style_factory
|
|||
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
|
||||
def test_style_factory():
|
||||
"""Test that a Pygments Style class is created."""
|
||||
header = 'bold underline #ansired'
|
||||
cli_style = {'Token.Output.Header': header}
|
||||
style = style_factory('default', cli_style)
|
||||
header = "bold underline #ansired"
|
||||
cli_style = {"Token.Output.Header": header}
|
||||
style = style_factory("default", cli_style)
|
||||
|
||||
assert isinstance(style(), Style)
|
||||
assert Token.Output.Header in style.styles
|
||||
|
@ -22,6 +23,6 @@ def test_style_factory():
|
|||
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
|
||||
def test_style_factory_unknown_name():
|
||||
"""Test that an unrecognized name will not throw an error."""
|
||||
style = style_factory('foobar', {})
|
||||
style = style_factory("foobar", {})
|
||||
|
||||
assert isinstance(style(), Style)
|
||||
|
|
|
@ -8,494 +8,528 @@ def sorted_dicts(dicts):
|
|||
|
||||
|
||||
def test_select_suggests_cols_with_visible_table_scope():
|
||||
suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tabl']},
|
||||
{'type': 'column', 'tables': [(None, 'tabl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tabl"]},
|
||||
{"type": "column", "tables": [(None, "tabl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_select_suggests_cols_with_qualified_table_scope():
|
||||
suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tabl']},
|
||||
{'type': 'column', 'tables': [('sch', 'tabl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tabl"]},
|
||||
{"type": "column", "tables": [("sch", "tabl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM tabl WHERE ',
|
||||
'SELECT * FROM tabl WHERE (',
|
||||
'SELECT * FROM tabl WHERE foo = ',
|
||||
'SELECT * FROM tabl WHERE bar OR ',
|
||||
'SELECT * FROM tabl WHERE foo = 1 AND ',
|
||||
'SELECT * FROM tabl WHERE (bar > 10 AND ',
|
||||
'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
|
||||
'SELECT * FROM tabl WHERE 10 < ',
|
||||
'SELECT * FROM tabl WHERE foo BETWEEN ',
|
||||
'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM tabl WHERE ",
|
||||
"SELECT * FROM tabl WHERE (",
|
||||
"SELECT * FROM tabl WHERE foo = ",
|
||||
"SELECT * FROM tabl WHERE bar OR ",
|
||||
"SELECT * FROM tabl WHERE foo = 1 AND ",
|
||||
"SELECT * FROM tabl WHERE (bar > 10 AND ",
|
||||
"SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
|
||||
"SELECT * FROM tabl WHERE 10 < ",
|
||||
"SELECT * FROM tabl WHERE foo BETWEEN ",
|
||||
"SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
|
||||
],
|
||||
)
|
||||
def test_where_suggests_columns_functions(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tabl']},
|
||||
{'type': 'column', 'tables': [(None, 'tabl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tabl"]},
|
||||
{"type": "column", "tables": [(None, "tabl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM tabl WHERE foo IN (',
|
||||
'SELECT * FROM tabl WHERE foo IN (bar, ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM tabl WHERE foo IN (",
|
||||
"SELECT * FROM tabl WHERE foo IN (bar, ",
|
||||
],
|
||||
)
|
||||
def test_where_in_suggests_columns(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tabl']},
|
||||
{'type': 'column', 'tables': [(None, 'tabl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tabl"]},
|
||||
{"type": "column", "tables": [(None, "tabl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_where_equals_any_suggests_columns_or_keywords():
|
||||
text = 'SELECT * FROM tabl WHERE foo = ANY('
|
||||
text = "SELECT * FROM tabl WHERE foo = ANY("
|
||||
suggestions = suggest_type(text, text)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tabl']},
|
||||
{'type': 'column', 'tables': [(None, 'tabl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tabl"]},
|
||||
{"type": "column", "tables": [(None, "tabl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_lparen_suggests_cols():
|
||||
suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
|
||||
assert suggestion == [
|
||||
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
|
||||
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
|
||||
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
|
||||
|
||||
|
||||
def test_operand_inside_function_suggests_cols1():
|
||||
suggestion = suggest_type(
|
||||
'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
|
||||
assert suggestion == [
|
||||
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
|
||||
suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ")
|
||||
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
|
||||
|
||||
|
||||
def test_operand_inside_function_suggests_cols2():
|
||||
suggestion = suggest_type(
|
||||
'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
|
||||
assert suggestion == [
|
||||
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
|
||||
suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ")
|
||||
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
|
||||
|
||||
|
||||
def test_select_suggests_cols_and_funcs():
|
||||
suggestions = suggest_type('SELECT ', 'SELECT ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': []},
|
||||
{'type': 'column', 'tables': []},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("SELECT ", "SELECT ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": []},
|
||||
{"type": "column", "tables": []},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM ',
|
||||
'INSERT INTO ',
|
||||
'COPY ',
|
||||
'UPDATE ',
|
||||
'DESCRIBE ',
|
||||
'DESC ',
|
||||
'EXPLAIN ',
|
||||
'SELECT * FROM foo JOIN ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM ",
|
||||
"INSERT INTO ",
|
||||
"COPY ",
|
||||
"UPDATE ",
|
||||
"DESCRIBE ",
|
||||
"DESC ",
|
||||
"EXPLAIN ",
|
||||
"SELECT * FROM foo JOIN ",
|
||||
],
|
||||
)
|
||||
def test_expression_suggests_tables_views_and_schemas(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM sch.',
|
||||
'INSERT INTO sch.',
|
||||
'COPY sch.',
|
||||
'UPDATE sch.',
|
||||
'DESCRIBE sch.',
|
||||
'DESC sch.',
|
||||
'EXPLAIN sch.',
|
||||
'SELECT * FROM foo JOIN sch.',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM sch.",
|
||||
"INSERT INTO sch.",
|
||||
"COPY sch.",
|
||||
"UPDATE sch.",
|
||||
"DESCRIBE sch.",
|
||||
"DESC sch.",
|
||||
"EXPLAIN sch.",
|
||||
"SELECT * FROM foo JOIN sch.",
|
||||
],
|
||||
)
|
||||
def test_expression_suggests_qualified_tables_views_and_schemas(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': 'sch'},
|
||||
{'type': 'view', 'schema': 'sch'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}])
|
||||
|
||||
|
||||
def test_truncate_suggests_tables_and_schemas():
|
||||
suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_truncate_suggests_qualified_tables():
|
||||
suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': 'sch'}])
|
||||
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}])
|
||||
|
||||
|
||||
def test_distinct_suggests_cols():
|
||||
suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
|
||||
assert suggestions == [{'type': 'column', 'tables': []}]
|
||||
suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ")
|
||||
assert suggestions == [{"type": "column", "tables": []}]
|
||||
|
||||
|
||||
def test_col_comma_suggests_cols():
|
||||
suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tbl']},
|
||||
{'type': 'column', 'tables': [(None, 'tbl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tbl"]},
|
||||
{"type": "column", "tables": [(None, "tbl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_table_comma_suggests_tables_and_schemas():
|
||||
suggestions = suggest_type('SELECT a, b FROM tbl1, ',
|
||||
'SELECT a, b FROM tbl1, ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_into_suggests_tables_and_schemas():
|
||||
suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
|
||||
assert sorted_dicts(suggestion) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
|
||||
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_insert_into_lparen_suggests_cols():
|
||||
suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
|
||||
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
|
||||
suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (")
|
||||
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
|
||||
|
||||
|
||||
def test_insert_into_lparen_partial_text_suggests_cols():
|
||||
suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
|
||||
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
|
||||
suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
|
||||
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
|
||||
|
||||
|
||||
def test_insert_into_lparen_comma_suggests_cols():
|
||||
suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
|
||||
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
|
||||
suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
|
||||
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
|
||||
|
||||
|
||||
def test_partially_typed_col_name_suggests_col_names():
|
||||
suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
|
||||
'SELECT * FROM tabl WHERE col_n')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['tabl']},
|
||||
{'type': 'column', 'tables': [(None, 'tabl', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["tabl"]},
|
||||
{"type": "column", "tables": [(None, "tabl", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
|
||||
suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'tabl', None)]},
|
||||
{'type': 'table', 'schema': 'tabl'},
|
||||
{'type': 'view', 'schema': 'tabl'},
|
||||
{'type': 'function', 'schema': 'tabl'}])
|
||||
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "column", "tables": [(None, "tabl", None)]},
|
||||
{"type": "table", "schema": "tabl"},
|
||||
{"type": "view", "schema": "tabl"},
|
||||
{"type": "function", "schema": "tabl"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_dot_suggests_cols_of_an_alias():
|
||||
suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
|
||||
'SELECT t1.')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': 't1'},
|
||||
{'type': 'view', 'schema': 't1'},
|
||||
{'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
|
||||
{'type': 'function', 'schema': 't1'}])
|
||||
suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "table", "schema": "t1"},
|
||||
{"type": "view", "schema": "t1"},
|
||||
{"type": "column", "tables": [(None, "tabl1", "t1")]},
|
||||
{"type": "function", "schema": "t1"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
|
||||
suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
|
||||
'SELECT t1.a, t2.')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
|
||||
{'type': 'table', 'schema': 't2'},
|
||||
{'type': 'view', 'schema': 't2'},
|
||||
{'type': 'function', 'schema': 't2'}])
|
||||
suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "column", "tables": [(None, "tabl2", "t2")]},
|
||||
{"type": "table", "schema": "t2"},
|
||||
{"type": "view", "schema": "t2"},
|
||||
{"type": "function", "schema": "t2"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM (',
|
||||
'SELECT * FROM foo WHERE EXISTS (',
|
||||
'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
|
||||
'SELECT 1 AS',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM (",
|
||||
"SELECT * FROM foo WHERE EXISTS (",
|
||||
"SELECT * FROM foo WHERE bar AND NOT EXISTS (",
|
||||
"SELECT 1 AS",
|
||||
],
|
||||
)
|
||||
def test_sub_select_suggests_keyword(expression):
|
||||
suggestion = suggest_type(expression, expression)
|
||||
assert suggestion == [{'type': 'keyword'}]
|
||||
assert suggestion == [{"type": "keyword"}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM (S',
|
||||
'SELECT * FROM foo WHERE EXISTS (S',
|
||||
'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM (S",
|
||||
"SELECT * FROM foo WHERE EXISTS (S",
|
||||
"SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
|
||||
],
|
||||
)
|
||||
def test_sub_select_partial_text_suggests_keyword(expression):
|
||||
suggestion = suggest_type(expression, expression)
|
||||
assert suggestion == [{'type': 'keyword'}]
|
||||
assert suggestion == [{"type": "keyword"}]
|
||||
|
||||
|
||||
def test_outer_table_reference_in_exists_subquery_suggests_columns():
|
||||
q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
|
||||
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
|
||||
suggestions = suggest_type(q, q)
|
||||
assert suggestions == [
|
||||
{'type': 'column', 'tables': [(None, 'foo', 'f')]},
|
||||
{'type': 'table', 'schema': 'f'},
|
||||
{'type': 'view', 'schema': 'f'},
|
||||
{'type': 'function', 'schema': 'f'}]
|
||||
{"type": "column", "tables": [(None, "foo", "f")]},
|
||||
{"type": "table", "schema": "f"},
|
||||
{"type": "view", "schema": "f"},
|
||||
{"type": "function", "schema": "f"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT * FROM (SELECT * FROM ',
|
||||
'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
|
||||
'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT * FROM (SELECT * FROM ",
|
||||
"SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
|
||||
"SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
|
||||
],
|
||||
)
|
||||
def test_sub_select_table_name_completion(expression):
|
||||
suggestion = suggest_type(expression, expression)
|
||||
assert sorted_dicts(suggestion) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_sub_select_col_name_completion():
|
||||
suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
|
||||
'SELECT * FROM (SELECT ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['abc']},
|
||||
{'type': 'column', 'tables': [(None, 'abc', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["abc"]},
|
||||
{"type": "column", "tables": [(None, "abc", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_sub_select_multiple_col_name_completion():
|
||||
suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
|
||||
'SELECT * FROM (SELECT a, ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'abc', None)]},
|
||||
{'type': 'function', 'schema': []}])
|
||||
suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[{"type": "column", "tables": [(None, "abc", None)]}, {"type": "function", "schema": []}]
|
||||
)
|
||||
|
||||
|
||||
def test_sub_select_dot_col_name_completion():
|
||||
suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
|
||||
'SELECT * FROM (SELECT t.')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'tabl', 't')]},
|
||||
{'type': 'table', 'schema': 't'},
|
||||
{'type': 'view', 'schema': 't'},
|
||||
{'type': 'function', 'schema': 't'}])
|
||||
suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "column", "tables": [(None, "tabl", "t")]},
|
||||
{"type": "table", "schema": "t"},
|
||||
{"type": "view", "schema": "t"},
|
||||
{"type": "function", "schema": "t"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
|
||||
@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
|
||||
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
|
||||
@pytest.mark.parametrize("tbl_alias", ["", "foo"])
|
||||
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
|
||||
text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
|
||||
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
|
||||
suggestion = suggest_type(text, text)
|
||||
assert sorted_dicts(suggestion) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'SELECT * FROM abc a JOIN def d ON a.',
|
||||
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"SELECT * FROM abc a JOIN def d ON a.",
|
||||
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
|
||||
],
|
||||
)
|
||||
def test_join_alias_dot_suggests_cols1(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'abc', 'a')]},
|
||||
{'type': 'table', 'schema': 'a'},
|
||||
{'type': 'view', 'schema': 'a'},
|
||||
{'type': 'function', 'schema': 'a'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "column", "tables": [(None, "abc", "a")]},
|
||||
{"type": "table", "schema": "a"},
|
||||
{"type": "view", "schema": "a"},
|
||||
{"type": "function", "schema": "a"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'SELECT * FROM abc a JOIN def d ON a.id = d.',
|
||||
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"SELECT * FROM abc a JOIN def d ON a.id = d.",
|
||||
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
|
||||
],
|
||||
)
|
||||
def test_join_alias_dot_suggests_cols2(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'def', 'd')]},
|
||||
{'type': 'table', 'schema': 'd'},
|
||||
{'type': 'view', 'schema': 'd'},
|
||||
{'type': 'function', 'schema': 'd'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "column", "tables": [(None, "def", "d")]},
|
||||
{"type": "table", "schema": "d"},
|
||||
{"type": "view", "schema": "d"},
|
||||
{"type": "function", "schema": "d"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'select a.x, b.y from abc a join bcd b on ',
|
||||
'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"select a.x, b.y from abc a join bcd b on ",
|
||||
"select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
|
||||
],
|
||||
)
|
||||
def test_on_suggests_aliases(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
|
||||
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'select abc.x, bcd.y from abc join bcd on ',
|
||||
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"select abc.x, bcd.y from abc join bcd on ",
|
||||
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
|
||||
],
|
||||
)
|
||||
def test_on_suggests_tables(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
|
||||
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'select a.x, b.y from abc a join bcd b on a.id = ',
|
||||
'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"select a.x, b.y from abc a join bcd b on a.id = ",
|
||||
"select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
|
||||
],
|
||||
)
|
||||
def test_on_suggests_aliases_right_side(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
|
||||
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'select abc.x, bcd.y from abc join bcd on ',
|
||||
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"select abc.x, bcd.y from abc join bcd on ",
|
||||
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
|
||||
],
|
||||
)
|
||||
def test_on_suggests_tables_right_side(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
|
||||
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('col_list', ['', 'col1, '])
|
||||
@pytest.mark.parametrize("col_list", ["", "col1, "])
|
||||
def test_join_using_suggests_common_columns(col_list):
|
||||
text = 'select * from abc inner join def using (' + col_list
|
||||
assert suggest_type(text, text) == [
|
||||
{'type': 'column',
|
||||
'tables': [(None, 'abc', None), (None, 'def', None)],
|
||||
'drop_unique': True}]
|
||||
text = "select * from abc inner join def using (" + col_list
|
||||
assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}]
|
||||
|
||||
@pytest.mark.parametrize('sql', [
|
||||
'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.',
|
||||
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.',
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql",
|
||||
[
|
||||
"SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.",
|
||||
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.",
|
||||
],
|
||||
)
|
||||
def test_two_join_alias_dot_suggests_cols1(sql):
|
||||
suggestions = suggest_type(sql, sql)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'column', 'tables': [(None, 'ghi', 'g')]},
|
||||
{'type': 'table', 'schema': 'g'},
|
||||
{'type': 'view', 'schema': 'g'},
|
||||
{'type': 'function', 'schema': 'g'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "column", "tables": [(None, "ghi", "g")]},
|
||||
{"type": "table", "schema": "g"},
|
||||
{"type": "view", "schema": "g"},
|
||||
{"type": "function", "schema": "g"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_2_statements_2nd_current():
|
||||
suggestions = suggest_type('select * from a; select * from ',
|
||||
'select * from a; select * from ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
suggestions = suggest_type('select * from a; select from b',
|
||||
'select * from a; select ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['b']},
|
||||
{'type': 'column', 'tables': [(None, 'b', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("select * from a; select from b", "select * from a; select ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["b"]},
|
||||
{"type": "column", "tables": [(None, "b", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
# Should work even if first statement is invalid
|
||||
suggestions = suggest_type('select * from; select * from ',
|
||||
'select * from; select * from ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("select * from; select * from ", "select * from; select * from ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_2_statements_1st_current():
|
||||
suggestions = suggest_type('select * from ; select * from b',
|
||||
'select * from ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("select * from ; select * from b", "select * from ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
suggestions = suggest_type('select from a; select * from b',
|
||||
'select ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['a']},
|
||||
{'type': 'column', 'tables': [(None, 'a', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("select from a; select * from b", "select ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["a"]},
|
||||
{"type": "column", "tables": [(None, "a", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_3_statements_2nd_current():
|
||||
suggestions = suggest_type('select * from a; select * from ; select * from c',
|
||||
'select * from a; select * from ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
suggestions = suggest_type('select * from a; select from b; select * from c',
|
||||
'select * from a; select ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'alias', 'aliases': ['b']},
|
||||
{'type': 'column', 'tables': [(None, 'b', None)]},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'keyword'},
|
||||
])
|
||||
suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts(
|
||||
[
|
||||
{"type": "alias", "aliases": ["b"]},
|
||||
{"type": "column", "tables": [(None, "b", None)]},
|
||||
{"type": "function", "schema": []},
|
||||
{"type": "keyword"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_create_db_with_template():
|
||||
suggestions = suggest_type('create database foo with template ',
|
||||
'create database foo with template ')
|
||||
suggestions = suggest_type("create database foo with template ", "create database foo with template ")
|
||||
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
|
||||
@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"])
|
||||
def test_specials_included_for_initial_completion(initial_text):
|
||||
suggestions = suggest_type(initial_text, initial_text)
|
||||
|
||||
assert sorted_dicts(suggestions) == \
|
||||
sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}])
|
||||
|
||||
|
||||
def test_specials_not_included_after_initial_token():
|
||||
suggestions = suggest_type('create table foo (dt d',
|
||||
'create table foo (dt d')
|
||||
suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")
|
||||
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}])
|
||||
|
||||
|
||||
def test_drop_schema_qualified_table_suggests_only_tables():
|
||||
text = 'DROP TABLE schema_name.table_name'
|
||||
text = "DROP TABLE schema_name.table_name"
|
||||
suggestions = suggest_type(text, text)
|
||||
assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
|
||||
assert suggestions == [{"type": "table", "schema": "schema_name"}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
|
||||
@pytest.mark.parametrize("text", [",", " ,", "sel ,"])
|
||||
def test_handle_pre_completion_comma_gracefully(text):
|
||||
suggestions = suggest_type(text, text)
|
||||
|
||||
|
@ -503,53 +537,59 @@ def test_handle_pre_completion_comma_gracefully(text):
|
|||
|
||||
|
||||
def test_cross_join():
|
||||
text = 'select * from v1 cross join v2 JOIN v1.id, '
|
||||
text = "select * from v1 cross join v2 JOIN v1.id, "
|
||||
suggestions = suggest_type(text, text)
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'SELECT 1 AS ',
|
||||
'SELECT 1 FROM tabl AS ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"SELECT 1 AS ",
|
||||
"SELECT 1 FROM tabl AS ",
|
||||
],
|
||||
)
|
||||
def test_after_as(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert set(suggestions) == set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('expression', [
|
||||
'\\. ',
|
||||
'select 1; \\. ',
|
||||
'select 1;\\. ',
|
||||
'select 1 ; \\. ',
|
||||
'source ',
|
||||
'truncate table test; source ',
|
||||
'truncate table test ; source ',
|
||||
'truncate table test;source ',
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"\\. ",
|
||||
"select 1; \\. ",
|
||||
"select 1;\\. ",
|
||||
"select 1 ; \\. ",
|
||||
"source ",
|
||||
"truncate table test; source ",
|
||||
"truncate table test ; source ",
|
||||
"truncate table test;source ",
|
||||
],
|
||||
)
|
||||
def test_source_is_file(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert suggestions == [{'type': 'file_name'}]
|
||||
assert suggestions == [{"type": "file_name"}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expression", [
|
||||
"\\f ",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"expression",
|
||||
[
|
||||
"\\f ",
|
||||
],
|
||||
)
|
||||
def test_favorite_name_suggestion(expression):
|
||||
suggestions = suggest_type(expression, expression)
|
||||
assert suggestions == [{'type': 'favoritequery'}]
|
||||
assert suggestions == [{"type": "favoritequery"}]
|
||||
|
||||
|
||||
def test_order_by():
|
||||
text = 'select * from foo order by '
|
||||
text = "select * from foo order by "
|
||||
suggestions = suggest_type(text, text)
|
||||
assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]
|
||||
assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}]
|
||||
|
||||
|
||||
def test_quoted_where():
|
||||
text = "'where i=';"
|
||||
suggestions = suggest_type(text, text)
|
||||
assert suggestions == [{'type': 'keyword'}]
|
||||
assert suggestions == [{"type": "keyword"}]
|
||||
|
|
|
@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
|
|||
@pytest.fixture
|
||||
def refresher():
|
||||
from mycli.completion_refresher import CompletionRefresher
|
||||
|
||||
return CompletionRefresher()
|
||||
|
||||
|
||||
|
@ -18,8 +19,7 @@ def test_ctor(refresher):
|
|||
"""
|
||||
assert len(refresher.refreshers) > 0
|
||||
actual_handlers = list(refresher.refreshers.keys())
|
||||
expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
|
||||
'special_commands', 'show_commands', 'keywords']
|
||||
expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"]
|
||||
assert expected_handlers == actual_handlers
|
||||
|
||||
|
||||
|
@ -32,12 +32,12 @@ def test_refresh_called_once(refresher):
|
|||
callbacks = Mock()
|
||||
sqlexecute = Mock()
|
||||
|
||||
with patch.object(refresher, '_bg_refresh') as bg_refresh:
|
||||
with patch.object(refresher, "_bg_refresh") as bg_refresh:
|
||||
actual = refresher.refresh(sqlexecute, callbacks)
|
||||
time.sleep(1) # Wait for the thread to work.
|
||||
assert len(actual) == 1
|
||||
assert len(actual[0]) == 4
|
||||
assert actual[0][3] == 'Auto-completion refresh started in the background.'
|
||||
assert actual[0][3] == "Auto-completion refresh started in the background."
|
||||
bg_refresh.assert_called_with(sqlexecute, callbacks, {})
|
||||
|
||||
|
||||
|
@ -61,13 +61,13 @@ def test_refresh_called_twice(refresher):
|
|||
time.sleep(1) # Wait for the thread to work.
|
||||
assert len(actual1) == 1
|
||||
assert len(actual1[0]) == 4
|
||||
assert actual1[0][3] == 'Auto-completion refresh started in the background.'
|
||||
assert actual1[0][3] == "Auto-completion refresh started in the background."
|
||||
|
||||
actual2 = refresher.refresh(sqlexecute, callbacks)
|
||||
time.sleep(1) # Wait for the thread to work.
|
||||
assert len(actual2) == 1
|
||||
assert len(actual2[0]) == 4
|
||||
assert actual2[0][3] == 'Auto-completion refresh restarted.'
|
||||
assert actual2[0][3] == "Auto-completion refresh restarted."
|
||||
|
||||
|
||||
def test_refresh_with_callbacks(refresher):
|
||||
|
@ -80,9 +80,9 @@ def test_refresh_with_callbacks(refresher):
|
|||
sqlexecute_class = Mock()
|
||||
sqlexecute = Mock()
|
||||
|
||||
with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
|
||||
with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class):
|
||||
# Set refreshers to 0: we're not testing refresh logic here
|
||||
refresher.refreshers = {}
|
||||
refresher.refresh(sqlexecute, callbacks)
|
||||
time.sleep(1) # Wait for the thread to work.
|
||||
assert (callbacks[0].call_count == 1)
|
||||
assert callbacks[0].call_count == 1
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Unit tests for the mycli.config module."""
|
||||
|
||||
from io import BytesIO, StringIO, TextIOWrapper
|
||||
import os
|
||||
import struct
|
||||
|
@ -6,21 +7,26 @@ import sys
|
|||
import tempfile
|
||||
import pytest
|
||||
|
||||
from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
|
||||
read_and_decrypt_mylogin_cnf, read_config_file,
|
||||
str_to_bool, strip_matching_quotes)
|
||||
from mycli.config import (
|
||||
get_mylogin_cnf_path,
|
||||
open_mylogin_cnf,
|
||||
read_and_decrypt_mylogin_cnf,
|
||||
read_config_file,
|
||||
str_to_bool,
|
||||
strip_matching_quotes,
|
||||
)
|
||||
|
||||
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'mylogin.cnf'))
|
||||
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "mylogin.cnf"))
|
||||
|
||||
|
||||
def open_bmylogin_cnf(name):
|
||||
"""Open contents of *name* in a BytesIO buffer."""
|
||||
with open(name, 'rb') as f:
|
||||
with open(name, "rb") as f:
|
||||
buf = BytesIO()
|
||||
buf.write(f.read())
|
||||
return buf
|
||||
|
||||
|
||||
def test_read_mylogin_cnf():
|
||||
"""Tests that a login path file can be read and decrypted."""
|
||||
mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
|
||||
|
@ -28,7 +34,7 @@ def test_read_mylogin_cnf():
|
|||
assert isinstance(mylogin_cnf, TextIOWrapper)
|
||||
|
||||
contents = mylogin_cnf.read()
|
||||
for word in ('[test]', 'user', 'password', 'host', 'port'):
|
||||
for word in ("[test]", "user", "password", "host", "port"):
|
||||
assert word in contents
|
||||
|
||||
|
||||
|
@ -46,7 +52,7 @@ def test_corrupted_login_key():
|
|||
buf.seek(4)
|
||||
|
||||
# Write null bytes over half the login key
|
||||
buf.write(b'\0\0\0\0\0\0\0\0\0\0')
|
||||
buf.write(b"\0\0\0\0\0\0\0\0\0\0")
|
||||
|
||||
buf.seek(0)
|
||||
mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
|
||||
|
@ -63,58 +69,58 @@ def test_corrupted_pad():
|
|||
|
||||
# Skip option group
|
||||
len_buf = buf.read(4)
|
||||
cipher_len, = struct.unpack("<i", len_buf)
|
||||
(cipher_len,) = struct.unpack("<i", len_buf)
|
||||
buf.read(cipher_len)
|
||||
|
||||
# Corrupt the pad for the user line
|
||||
len_buf = buf.read(4)
|
||||
cipher_len, = struct.unpack("<i", len_buf)
|
||||
(cipher_len,) = struct.unpack("<i", len_buf)
|
||||
buf.read(cipher_len - 1)
|
||||
buf.write(b'\0')
|
||||
buf.write(b"\0")
|
||||
|
||||
buf.seek(0)
|
||||
mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
|
||||
contents = mylogin_cnf.read()
|
||||
for word in ('[test]', 'password', 'host', 'port'):
|
||||
for word in ("[test]", "password", "host", "port"):
|
||||
assert word in contents
|
||||
assert 'user' not in contents
|
||||
assert "user" not in contents
|
||||
|
||||
|
||||
def test_get_mylogin_cnf_path():
|
||||
"""Tests that the path for .mylogin.cnf is detected."""
|
||||
original_env = None
|
||||
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
|
||||
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
|
||||
is_windows = sys.platform == 'win32'
|
||||
if "MYSQL_TEST_LOGIN_FILE" in os.environ:
|
||||
original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
|
||||
is_windows = sys.platform == "win32"
|
||||
|
||||
login_cnf_path = get_mylogin_cnf_path()
|
||||
|
||||
if original_env is not None:
|
||||
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
|
||||
os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
|
||||
|
||||
if login_cnf_path is not None:
|
||||
assert login_cnf_path.endswith('.mylogin.cnf')
|
||||
assert login_cnf_path.endswith(".mylogin.cnf")
|
||||
|
||||
if is_windows is True:
|
||||
assert 'MySQL' in login_cnf_path
|
||||
assert "MySQL" in login_cnf_path
|
||||
else:
|
||||
home_dir = os.path.expanduser('~')
|
||||
home_dir = os.path.expanduser("~")
|
||||
assert login_cnf_path.startswith(home_dir)
|
||||
|
||||
|
||||
def test_alternate_get_mylogin_cnf_path():
|
||||
"""Tests that the alternate path for .mylogin.cnf is detected."""
|
||||
original_env = None
|
||||
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
|
||||
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
|
||||
if "MYSQL_TEST_LOGIN_FILE" in os.environ:
|
||||
original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
|
||||
|
||||
_, temp_path = tempfile.mkstemp()
|
||||
os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
|
||||
os.environ["MYSQL_TEST_LOGIN_FILE"] = temp_path
|
||||
|
||||
login_cnf_path = get_mylogin_cnf_path()
|
||||
|
||||
if original_env is not None:
|
||||
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
|
||||
os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
|
||||
|
||||
assert temp_path == login_cnf_path
|
||||
|
||||
|
@ -124,17 +130,17 @@ def test_str_to_bool():
|
|||
|
||||
assert str_to_bool(False) is False
|
||||
assert str_to_bool(True) is True
|
||||
assert str_to_bool('False') is False
|
||||
assert str_to_bool('True') is True
|
||||
assert str_to_bool('TRUE') is True
|
||||
assert str_to_bool('1') is True
|
||||
assert str_to_bool('0') is False
|
||||
assert str_to_bool('on') is True
|
||||
assert str_to_bool('off') is False
|
||||
assert str_to_bool('off') is False
|
||||
assert str_to_bool("False") is False
|
||||
assert str_to_bool("True") is True
|
||||
assert str_to_bool("TRUE") is True
|
||||
assert str_to_bool("1") is True
|
||||
assert str_to_bool("0") is False
|
||||
assert str_to_bool("on") is True
|
||||
assert str_to_bool("off") is False
|
||||
assert str_to_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
str_to_bool('foo')
|
||||
str_to_bool("foo")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
str_to_bool(None)
|
||||
|
@ -143,19 +149,19 @@ def test_str_to_bool():
|
|||
def test_read_config_file_list_values_default():
|
||||
"""Test that reading a config file uses list_values by default."""
|
||||
|
||||
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
|
||||
f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
|
||||
config = read_config_file(f)
|
||||
|
||||
assert config['main']['weather'] == u"cloudy with a chance of meatballs"
|
||||
assert config["main"]["weather"] == "cloudy with a chance of meatballs"
|
||||
|
||||
|
||||
def test_read_config_file_list_values_off():
|
||||
"""Test that you can disable list_values when reading a config file."""
|
||||
|
||||
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
|
||||
f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
|
||||
config = read_config_file(f, list_values=False)
|
||||
|
||||
assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
|
||||
assert config["main"]["weather"] == "'cloudy with a chance of meatballs'"
|
||||
|
||||
|
||||
def test_strip_quotes_with_matching_quotes():
|
||||
|
@ -177,7 +183,7 @@ def test_strip_quotes_with_unmatching_quotes():
|
|||
def test_strip_quotes_with_empty_string():
|
||||
"""Test that an empty string is handled during unquoting."""
|
||||
|
||||
assert '' == strip_matching_quotes('')
|
||||
assert "" == strip_matching_quotes("")
|
||||
|
||||
|
||||
def test_strip_quotes_with_none():
|
||||
|
|
|
@ -4,39 +4,32 @@ from mycli.packages.special.utils import format_uptime
|
|||
|
||||
|
||||
def test_u_suggests_databases():
|
||||
suggestions = suggest_type('\\u ', '\\u ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'database'}])
|
||||
suggestions = suggest_type("\\u ", "\\u ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
|
||||
|
||||
|
||||
def test_describe_table():
|
||||
suggestions = suggest_type('\\dt', '\\dt ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("\\dt", "\\dt ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_list_or_show_create_tables():
|
||||
suggestions = suggest_type('\\dt+', '\\dt+ ')
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'}])
|
||||
suggestions = suggest_type("\\dt+", "\\dt+ ")
|
||||
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
|
||||
|
||||
|
||||
def test_format_uptime():
|
||||
seconds = 59
|
||||
assert '59 sec' == format_uptime(seconds)
|
||||
assert "59 sec" == format_uptime(seconds)
|
||||
|
||||
seconds = 120
|
||||
assert '2 min 0 sec' == format_uptime(seconds)
|
||||
assert "2 min 0 sec" == format_uptime(seconds)
|
||||
|
||||
seconds = 54890
|
||||
assert '15 hours 14 min 50 sec' == format_uptime(seconds)
|
||||
assert "15 hours 14 min 50 sec" == format_uptime(seconds)
|
||||
|
||||
seconds = 598244
|
||||
assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
|
||||
assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds)
|
||||
|
||||
seconds = 522600
|
||||
assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)
|
||||
assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds)
|
||||
|
|
|
@ -13,52 +13,62 @@ from textwrap import dedent
|
|||
from collections import namedtuple
|
||||
|
||||
from tempfile import NamedTemporaryFile
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
test_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
project_dir = os.path.dirname(test_dir)
|
||||
default_config_file = os.path.join(project_dir, 'test', 'myclirc')
|
||||
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
|
||||
default_config_file = os.path.join(project_dir, "test", "myclirc")
|
||||
login_path_file = os.path.join(test_dir, "mylogin.cnf")
|
||||
|
||||
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
|
||||
CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
|
||||
'--password', PASSWORD, '--myclirc', default_config_file,
|
||||
'--defaults-file', default_config_file,
|
||||
'mycli_test_db']
|
||||
os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file
|
||||
CLI_ARGS = [
|
||||
"--user",
|
||||
USER,
|
||||
"--host",
|
||||
HOST,
|
||||
"--port",
|
||||
PORT,
|
||||
"--password",
|
||||
PASSWORD,
|
||||
"--myclirc",
|
||||
default_config_file,
|
||||
"--defaults-file",
|
||||
default_config_file,
|
||||
"mycli_test_db",
|
||||
]
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_execute_arg(executor):
|
||||
run(executor, 'create table test (a text)')
|
||||
run(executor, "create table test (a text)")
|
||||
run(executor, 'insert into test values("abc")')
|
||||
|
||||
sql = 'select * from test;'
|
||||
sql = "select * from test;"
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert 'abc' in result.output
|
||||
assert "abc" in result.output
|
||||
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert 'abc' in result.output
|
||||
assert "abc" in result.output
|
||||
|
||||
expected = 'a\nabc\n'
|
||||
expected = "a\nabc\n"
|
||||
|
||||
assert expected in result.output
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_execute_arg_with_table(executor):
|
||||
run(executor, 'create table test (a text)')
|
||||
run(executor, "create table test (a text)")
|
||||
run(executor, 'insert into test values("abc")')
|
||||
|
||||
sql = 'select * from test;'
|
||||
sql = "select * from test;"
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
|
||||
expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"])
|
||||
expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n"
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert expected in result.output
|
||||
|
@ -66,12 +76,12 @@ def test_execute_arg_with_table(executor):
|
|||
|
||||
@dbtest
|
||||
def test_execute_arg_with_csv(executor):
|
||||
run(executor, 'create table test (a text)')
|
||||
run(executor, "create table test (a text)")
|
||||
run(executor, 'insert into test values("abc")')
|
||||
|
||||
sql = 'select * from test;'
|
||||
sql = "select * from test;"
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"])
|
||||
expected = '"a"\n"abc"\n'
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
@ -80,35 +90,29 @@ def test_execute_arg_with_csv(executor):
|
|||
|
||||
@dbtest
|
||||
def test_batch_mode(executor):
|
||||
run(executor, '''create table test(a text)''')
|
||||
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
|
||||
run(executor, """create table test(a text)""")
|
||||
run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
|
||||
|
||||
sql = (
|
||||
'select count(*) from test;\n'
|
||||
'select * from test limit 1;'
|
||||
)
|
||||
sql = "select count(*) from test;\n" "select * from test limit 1;"
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
|
||||
assert "count(*)\n3\na\nabc\n" in "".join(result.output)
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_batch_mode_table(executor):
|
||||
run(executor, '''create table test(a text)''')
|
||||
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
|
||||
run(executor, """create table test(a text)""")
|
||||
run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
|
||||
|
||||
sql = (
|
||||
'select count(*) from test;\n'
|
||||
'select * from test limit 1;'
|
||||
)
|
||||
sql = "select count(*) from test;\n" "select * from test limit 1;"
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql)
|
||||
|
||||
expected = (dedent("""\
|
||||
expected = dedent("""\
|
||||
+----------+
|
||||
| count(*) |
|
||||
+----------+
|
||||
|
@ -118,7 +122,7 @@ def test_batch_mode_table(executor):
|
|||
| a |
|
||||
+-----+
|
||||
| abc |
|
||||
+-----+"""))
|
||||
+-----+""")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert expected in result.output
|
||||
|
@ -126,14 +130,13 @@ def test_batch_mode_table(executor):
|
|||
|
||||
@dbtest
|
||||
def test_batch_mode_csv(executor):
|
||||
run(executor, '''create table test(a text, b text)''')
|
||||
run(executor,
|
||||
'''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
|
||||
run(executor, """create table test(a text, b text)""")
|
||||
run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""")
|
||||
|
||||
sql = 'select * from test;'
|
||||
sql = "select * from test;"
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
|
||||
|
||||
expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
|
||||
|
||||
|
@ -150,15 +153,15 @@ def test_help_strings_end_with_periods():
|
|||
"""Make sure click options have help text that end with a period."""
|
||||
for param in cli.params:
|
||||
if isinstance(param, click.core.Option):
|
||||
assert hasattr(param, 'help')
|
||||
assert param.help.endswith('.')
|
||||
assert hasattr(param, "help")
|
||||
assert param.help.endswith(".")
|
||||
|
||||
|
||||
def test_command_descriptions_end_with_periods():
|
||||
"""Make sure that mycli commands' descriptions end with a period."""
|
||||
MyCli()
|
||||
for _, command in SPECIAL_COMMANDS.items():
|
||||
assert command[3].endswith('.')
|
||||
assert command[3].endswith(".")
|
||||
|
||||
|
||||
def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
|
||||
|
@ -166,23 +169,23 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
|
|||
clickoutput = ""
|
||||
m = MyCli(myclirc=default_config_file)
|
||||
|
||||
class TestOutput():
|
||||
class TestOutput:
|
||||
def get_size(self):
|
||||
size = namedtuple('Size', 'rows columns')
|
||||
size = namedtuple("Size", "rows columns")
|
||||
size.columns, size.rows = terminal_size
|
||||
return size
|
||||
|
||||
class TestExecute():
|
||||
host = 'test'
|
||||
user = 'test'
|
||||
dbname = 'test'
|
||||
server_info = ServerInfo.from_version_string('unknown')
|
||||
class TestExecute:
|
||||
host = "test"
|
||||
user = "test"
|
||||
dbname = "test"
|
||||
server_info = ServerInfo.from_version_string("unknown")
|
||||
port = 0
|
||||
|
||||
def server_type(self):
|
||||
return ['test']
|
||||
return ["test"]
|
||||
|
||||
class PromptBuffer():
|
||||
class PromptBuffer:
|
||||
output = TestOutput()
|
||||
|
||||
m.prompt_app = PromptBuffer()
|
||||
|
@ -199,8 +202,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
|
|||
global clickoutput
|
||||
clickoutput += s + "\n"
|
||||
|
||||
monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
|
||||
monkeypatch.setattr(click, 'secho', secho)
|
||||
monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
|
||||
monkeypatch.setattr(click, "secho", secho)
|
||||
m.output(testdata)
|
||||
if clickoutput.endswith("\n"):
|
||||
clickoutput = clickoutput[:-1]
|
||||
|
@ -208,59 +211,29 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
|
|||
|
||||
|
||||
def test_conditional_pager(monkeypatch):
|
||||
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
|
||||
" ")
|
||||
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ")
|
||||
# User didn't set pager, output doesn't fit screen -> pager
|
||||
output(
|
||||
monkeypatch,
|
||||
terminal_size=(5, 10),
|
||||
testdata=testdata,
|
||||
explicit_pager=False,
|
||||
expect_pager=True
|
||||
)
|
||||
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True)
|
||||
# User didn't set pager, output fits screen -> no pager
|
||||
output(
|
||||
monkeypatch,
|
||||
terminal_size=(20, 20),
|
||||
testdata=testdata,
|
||||
explicit_pager=False,
|
||||
expect_pager=False
|
||||
)
|
||||
output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False)
|
||||
# User manually configured pager, output doesn't fit screen -> pager
|
||||
output(
|
||||
monkeypatch,
|
||||
terminal_size=(5, 10),
|
||||
testdata=testdata,
|
||||
explicit_pager=True,
|
||||
expect_pager=True
|
||||
)
|
||||
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True)
|
||||
# User manually configured pager, output fit screen -> pager
|
||||
output(
|
||||
monkeypatch,
|
||||
terminal_size=(20, 20),
|
||||
testdata=testdata,
|
||||
explicit_pager=True,
|
||||
expect_pager=True
|
||||
)
|
||||
output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True)
|
||||
|
||||
SPECIAL_COMMANDS['nopager'].handler()
|
||||
output(
|
||||
monkeypatch,
|
||||
terminal_size=(5, 10),
|
||||
testdata=testdata,
|
||||
explicit_pager=False,
|
||||
expect_pager=False
|
||||
)
|
||||
SPECIAL_COMMANDS['pager'].handler('')
|
||||
SPECIAL_COMMANDS["nopager"].handler()
|
||||
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False)
|
||||
SPECIAL_COMMANDS["pager"].handler("")
|
||||
|
||||
|
||||
def test_reserved_space_is_integer(monkeypatch):
|
||||
"""Make sure that reserved space is returned as an integer."""
|
||||
|
||||
def stub_terminal_size():
|
||||
return (5, 5)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(shutil, 'get_terminal_size', stub_terminal_size)
|
||||
m.setattr(shutil, "get_terminal_size", stub_terminal_size)
|
||||
mycli = MyCli()
|
||||
assert isinstance(mycli.get_reserved_space(), int)
|
||||
|
||||
|
@ -268,18 +241,20 @@ def test_reserved_space_is_integer(monkeypatch):
|
|||
def test_list_dsn():
|
||||
runner = CliRunner()
|
||||
# keep Windows from locking the file with delete=False
|
||||
with NamedTemporaryFile(mode="w",delete=False) as myclirc:
|
||||
myclirc.write(dedent("""\
|
||||
with NamedTemporaryFile(mode="w", delete=False) as myclirc:
|
||||
myclirc.write(
|
||||
dedent("""\
|
||||
[alias_dsn]
|
||||
test = mysql://test/test
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
myclirc.flush()
|
||||
args = ['--list-dsn', '--myclirc', myclirc.name]
|
||||
args = ["--list-dsn", "--myclirc", myclirc.name]
|
||||
result = runner.invoke(cli, args=args)
|
||||
assert result.output == "test\n"
|
||||
result = runner.invoke(cli, args=args + ['--verbose'])
|
||||
result = runner.invoke(cli, args=args + ["--verbose"])
|
||||
assert result.output == "test : mysql://test/test\n"
|
||||
|
||||
|
||||
# delete=False means we should try to clean up
|
||||
try:
|
||||
if os.path.exists(myclirc.name):
|
||||
|
@ -287,41 +262,41 @@ def test_list_dsn():
|
|||
except Exception as e:
|
||||
print(f"An error occurred while attempting to delete the file: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
def test_prettify_statement():
|
||||
statement = 'SELECT 1'
|
||||
statement = "SELECT 1"
|
||||
m = MyCli()
|
||||
pretty_statement = m.handle_prettify_binding(statement)
|
||||
assert pretty_statement == 'SELECT\n 1;'
|
||||
assert pretty_statement == "SELECT\n 1;"
|
||||
|
||||
|
||||
def test_unprettify_statement():
|
||||
statement = 'SELECT\n 1'
|
||||
statement = "SELECT\n 1"
|
||||
m = MyCli()
|
||||
unpretty_statement = m.handle_unprettify_binding(statement)
|
||||
assert unpretty_statement == 'SELECT 1;'
|
||||
assert unpretty_statement == "SELECT 1;"
|
||||
|
||||
|
||||
def test_list_ssh_config():
|
||||
runner = CliRunner()
|
||||
# keep Windows from locking the file with delete=False
|
||||
with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
|
||||
ssh_config.write(dedent("""\
|
||||
with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
|
||||
ssh_config.write(
|
||||
dedent("""\
|
||||
Host test
|
||||
Hostname test.example.com
|
||||
User joe
|
||||
Port 22222
|
||||
IdentityFile ~/.ssh/gateway
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
ssh_config.flush()
|
||||
args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
|
||||
args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name]
|
||||
result = runner.invoke(cli, args=args)
|
||||
assert "test\n" in result.output
|
||||
result = runner.invoke(cli, args=args + ['--verbose'])
|
||||
result = runner.invoke(cli, args=args + ["--verbose"])
|
||||
assert "test : test.example.com\n" in result.output
|
||||
|
||||
|
||||
# delete=False means we should try to clean up
|
||||
try:
|
||||
if os.path.exists(ssh_config.name):
|
||||
|
@ -343,7 +318,7 @@ def test_dsn(monkeypatch):
|
|||
pass
|
||||
|
||||
class MockMyCli:
|
||||
config = {'alias_dsn': {}}
|
||||
config = {"alias_dsn": {}}
|
||||
|
||||
def __init__(self, **args):
|
||||
self.logger = Logger()
|
||||
|
@ -357,97 +332,109 @@ def test_dsn(monkeypatch):
|
|||
pass
|
||||
|
||||
import mycli.main
|
||||
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
|
||||
|
||||
monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
|
||||
runner = CliRunner()
|
||||
|
||||
# When a user supplies a DSN as database argument to mycli,
|
||||
# use these values.
|
||||
result = runner.invoke(mycli.main.cli, args=[
|
||||
"mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
|
||||
)
|
||||
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"])
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["user"] == "dsn_user" and \
|
||||
MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
|
||||
MockMyCli.connect_args["host"] == "dsn_host" and \
|
||||
MockMyCli.connect_args["port"] == 1 and \
|
||||
MockMyCli.connect_args["database"] == "dsn_database"
|
||||
assert (
|
||||
MockMyCli.connect_args["user"] == "dsn_user"
|
||||
and MockMyCli.connect_args["passwd"] == "dsn_passwd"
|
||||
and MockMyCli.connect_args["host"] == "dsn_host"
|
||||
and MockMyCli.connect_args["port"] == 1
|
||||
and MockMyCli.connect_args["database"] == "dsn_database"
|
||||
)
|
||||
|
||||
MockMyCli.connect_args = None
|
||||
|
||||
# When a use supplies a DSN as database argument to mycli,
|
||||
# and used command line arguments, use the command line
|
||||
# arguments.
|
||||
result = runner.invoke(mycli.main.cli, args=[
|
||||
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
|
||||
"--user", "arg_user",
|
||||
"--password", "arg_password",
|
||||
"--host", "arg_host",
|
||||
"--port", "3",
|
||||
"--database", "arg_database",
|
||||
])
|
||||
result = runner.invoke(
|
||||
mycli.main.cli,
|
||||
args=[
|
||||
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
|
||||
"--user",
|
||||
"arg_user",
|
||||
"--password",
|
||||
"arg_password",
|
||||
"--host",
|
||||
"arg_host",
|
||||
"--port",
|
||||
"3",
|
||||
"--database",
|
||||
"arg_database",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["user"] == "arg_user" and \
|
||||
MockMyCli.connect_args["passwd"] == "arg_password" and \
|
||||
MockMyCli.connect_args["host"] == "arg_host" and \
|
||||
MockMyCli.connect_args["port"] == 3 and \
|
||||
MockMyCli.connect_args["database"] == "arg_database"
|
||||
assert (
|
||||
MockMyCli.connect_args["user"] == "arg_user"
|
||||
and MockMyCli.connect_args["passwd"] == "arg_password"
|
||||
and MockMyCli.connect_args["host"] == "arg_host"
|
||||
and MockMyCli.connect_args["port"] == 3
|
||||
and MockMyCli.connect_args["database"] == "arg_database"
|
||||
)
|
||||
|
||||
MockMyCli.config = {
|
||||
'alias_dsn': {
|
||||
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
|
||||
}
|
||||
}
|
||||
MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
|
||||
MockMyCli.connect_args = None
|
||||
|
||||
# When a user uses a DSN from the configuration file (alias_dsn),
|
||||
# use these values.
|
||||
result = runner.invoke(cli, args=['--dsn', 'test'])
|
||||
result = runner.invoke(cli, args=["--dsn", "test"])
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["user"] == "alias_dsn_user" and \
|
||||
MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
|
||||
MockMyCli.connect_args["host"] == "alias_dsn_host" and \
|
||||
MockMyCli.connect_args["port"] == 4 and \
|
||||
MockMyCli.connect_args["database"] == "alias_dsn_database"
|
||||
assert (
|
||||
MockMyCli.connect_args["user"] == "alias_dsn_user"
|
||||
and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd"
|
||||
and MockMyCli.connect_args["host"] == "alias_dsn_host"
|
||||
and MockMyCli.connect_args["port"] == 4
|
||||
and MockMyCli.connect_args["database"] == "alias_dsn_database"
|
||||
)
|
||||
|
||||
MockMyCli.config = {
|
||||
'alias_dsn': {
|
||||
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
|
||||
}
|
||||
}
|
||||
MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
|
||||
MockMyCli.connect_args = None
|
||||
|
||||
# When a user uses a DSN from the configuration file (alias_dsn)
|
||||
# and used command line arguments, use the command line arguments.
|
||||
result = runner.invoke(cli, args=[
|
||||
'--dsn', 'test', '',
|
||||
"--user", "arg_user",
|
||||
"--password", "arg_password",
|
||||
"--host", "arg_host",
|
||||
"--port", "5",
|
||||
"--database", "arg_database",
|
||||
])
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["user"] == "arg_user" and \
|
||||
MockMyCli.connect_args["passwd"] == "arg_password" and \
|
||||
MockMyCli.connect_args["host"] == "arg_host" and \
|
||||
MockMyCli.connect_args["port"] == 5 and \
|
||||
MockMyCli.connect_args["database"] == "arg_database"
|
||||
|
||||
# Use a DSN without password
|
||||
result = runner.invoke(mycli.main.cli, args=[
|
||||
"mysql://dsn_user@dsn_host:6/dsn_database"]
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
args=[
|
||||
"--dsn",
|
||||
"test",
|
||||
"",
|
||||
"--user",
|
||||
"arg_user",
|
||||
"--password",
|
||||
"arg_password",
|
||||
"--host",
|
||||
"arg_host",
|
||||
"--port",
|
||||
"5",
|
||||
"--database",
|
||||
"arg_database",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["user"] == "dsn_user" and \
|
||||
MockMyCli.connect_args["passwd"] is None and \
|
||||
MockMyCli.connect_args["host"] == "dsn_host" and \
|
||||
MockMyCli.connect_args["port"] == 6 and \
|
||||
MockMyCli.connect_args["database"] == "dsn_database"
|
||||
assert (
|
||||
MockMyCli.connect_args["user"] == "arg_user"
|
||||
and MockMyCli.connect_args["passwd"] == "arg_password"
|
||||
and MockMyCli.connect_args["host"] == "arg_host"
|
||||
and MockMyCli.connect_args["port"] == 5
|
||||
and MockMyCli.connect_args["database"] == "arg_database"
|
||||
)
|
||||
|
||||
# Use a DSN without password
|
||||
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user@dsn_host:6/dsn_database"])
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert (
|
||||
MockMyCli.connect_args["user"] == "dsn_user"
|
||||
and MockMyCli.connect_args["passwd"] is None
|
||||
and MockMyCli.connect_args["host"] == "dsn_host"
|
||||
and MockMyCli.connect_args["port"] == 6
|
||||
and MockMyCli.connect_args["database"] == "dsn_database"
|
||||
)
|
||||
|
||||
|
||||
def test_ssh_config(monkeypatch):
|
||||
|
@ -463,7 +450,7 @@ def test_ssh_config(monkeypatch):
|
|||
pass
|
||||
|
||||
class MockMyCli:
|
||||
config = {'alias_dsn': {}}
|
||||
config = {"alias_dsn": {}}
|
||||
|
||||
def __init__(self, **args):
|
||||
self.logger = Logger()
|
||||
|
@ -477,58 +464,62 @@ def test_ssh_config(monkeypatch):
|
|||
pass
|
||||
|
||||
import mycli.main
|
||||
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
|
||||
|
||||
monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
|
||||
runner = CliRunner()
|
||||
|
||||
# Setup temporary configuration
|
||||
# keep Windows from locking the file with delete=False
|
||||
with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
|
||||
ssh_config.write(dedent("""\
|
||||
with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
|
||||
ssh_config.write(
|
||||
dedent("""\
|
||||
Host test
|
||||
Hostname test.example.com
|
||||
User joe
|
||||
Port 22222
|
||||
IdentityFile ~/.ssh/gateway
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
ssh_config.flush()
|
||||
|
||||
# When a user supplies a ssh config.
|
||||
result = runner.invoke(mycli.main.cli, args=[
|
||||
"--ssh-config-path",
|
||||
ssh_config.name,
|
||||
"--ssh-config-host",
|
||||
"test"
|
||||
])
|
||||
assert result.exit_code == 0, result.output + \
|
||||
" " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["ssh_user"] == "joe" and \
|
||||
MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
|
||||
MockMyCli.connect_args["ssh_port"] == 22222 and \
|
||||
MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser(
|
||||
"~") + "/.ssh/gateway"
|
||||
result = runner.invoke(mycli.main.cli, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"])
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert (
|
||||
MockMyCli.connect_args["ssh_user"] == "joe"
|
||||
and MockMyCli.connect_args["ssh_host"] == "test.example.com"
|
||||
and MockMyCli.connect_args["ssh_port"] == 22222
|
||||
and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway"
|
||||
)
|
||||
|
||||
# When a user supplies a ssh config host as argument to mycli,
|
||||
# and used command line arguments, use the command line
|
||||
# arguments.
|
||||
result = runner.invoke(mycli.main.cli, args=[
|
||||
"--ssh-config-path",
|
||||
ssh_config.name,
|
||||
"--ssh-config-host",
|
||||
"test",
|
||||
"--ssh-user", "arg_user",
|
||||
"--ssh-host", "arg_host",
|
||||
"--ssh-port", "3",
|
||||
"--ssh-key-filename", "/path/to/key"
|
||||
])
|
||||
assert result.exit_code == 0, result.output + \
|
||||
" " + str(result.exception)
|
||||
assert \
|
||||
MockMyCli.connect_args["ssh_user"] == "arg_user" and \
|
||||
MockMyCli.connect_args["ssh_host"] == "arg_host" and \
|
||||
MockMyCli.connect_args["ssh_port"] == 3 and \
|
||||
MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
|
||||
|
||||
result = runner.invoke(
|
||||
mycli.main.cli,
|
||||
args=[
|
||||
"--ssh-config-path",
|
||||
ssh_config.name,
|
||||
"--ssh-config-host",
|
||||
"test",
|
||||
"--ssh-user",
|
||||
"arg_user",
|
||||
"--ssh-host",
|
||||
"arg_host",
|
||||
"--ssh-port",
|
||||
"3",
|
||||
"--ssh-key-filename",
|
||||
"/path/to/key",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0, result.output + " " + str(result.exception)
|
||||
assert (
|
||||
MockMyCli.connect_args["ssh_user"] == "arg_user"
|
||||
and MockMyCli.connect_args["ssh_host"] == "arg_host"
|
||||
and MockMyCli.connect_args["ssh_port"] == 3
|
||||
and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
|
||||
)
|
||||
|
||||
# delete=False means we should try to clean up
|
||||
try:
|
||||
if os.path.exists(ssh_config.name):
|
||||
|
@ -542,9 +533,7 @@ def test_init_command_arg(executor):
|
|||
init_command = "set sql_select_limit=1000"
|
||||
sql = 'show variables like "sql_select_limit";'
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
|
||||
)
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
|
||||
|
||||
expected = "sql_select_limit\t1000\n"
|
||||
assert result.exit_code == 0
|
||||
|
@ -553,18 +542,13 @@ def test_init_command_arg(executor):
|
|||
|
||||
@dbtest
|
||||
def test_init_command_multiple_arg(executor):
|
||||
init_command = 'set sql_select_limit=2000; set max_join_size=20000'
|
||||
sql = (
|
||||
'show variables like "sql_select_limit";\n'
|
||||
'show variables like "max_join_size"'
|
||||
)
|
||||
init_command = "set sql_select_limit=2000; set max_join_size=20000"
|
||||
sql = 'show variables like "sql_select_limit";\n' 'show variables like "max_join_size"'
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
|
||||
)
|
||||
result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
|
||||
|
||||
expected_sql_select_limit = 'sql_select_limit\t2000\n'
|
||||
expected_max_join_size = 'max_join_size\t20000\n'
|
||||
expected_sql_select_limit = "sql_select_limit\t2000\n"
|
||||
expected_max_join_size = "max_join_size\t20000\n"
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert expected_sql_select_limit in result.output
|
||||
|
|
|
@ -6,56 +6,48 @@ from prompt_toolkit.document import Document
|
|||
@pytest.fixture
|
||||
def completer():
|
||||
import mycli.sqlcompleter as sqlcompleter
|
||||
|
||||
return sqlcompleter.SQLCompleter(smart_completion=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complete_event():
|
||||
from unittest.mock import Mock
|
||||
|
||||
return Mock()
|
||||
|
||||
|
||||
def test_empty_string_completion(completer, complete_event):
|
||||
text = ''
|
||||
text = ""
|
||||
position = 0
|
||||
result = list(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(map(Completion, completer.all_completions))
|
||||
|
||||
|
||||
def test_select_keyword_completion(completer, complete_event):
|
||||
text = 'SEL'
|
||||
position = len('SEL')
|
||||
result = list(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
assert result == list([Completion(text='SELECT', start_position=-3)])
|
||||
text = "SEL"
|
||||
position = len("SEL")
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list([Completion(text="SELECT", start_position=-3)])
|
||||
|
||||
|
||||
def test_function_name_completion(completer, complete_event):
|
||||
text = 'SELECT MA'
|
||||
position = len('SELECT MA')
|
||||
result = list(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
text = "SELECT MA"
|
||||
position = len("SELECT MA")
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert sorted(x.text for x in result) == ["MASTER", "MAX"]
|
||||
|
||||
|
||||
def test_column_name_completion(completer, complete_event):
|
||||
text = 'SELECT FROM users'
|
||||
position = len('SELECT ')
|
||||
result = list(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
text = "SELECT FROM users"
|
||||
position = len("SELECT ")
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(map(Completion, completer.all_completions))
|
||||
|
||||
|
||||
def test_special_name_completion(completer, complete_event):
|
||||
text = '\\'
|
||||
position = len('\\')
|
||||
result = set(completer.get_completions(
|
||||
Document(text=text, cursor_position=position),
|
||||
complete_event))
|
||||
text = "\\"
|
||||
position = len("\\")
|
||||
result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
# Special commands will NOT be suggested during naive completion mode.
|
||||
assert result == set()
|
||||
|
|
|
@ -1,67 +1,72 @@
|
|||
import pytest
|
||||
from mycli.packages.parseutils import (
|
||||
extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
|
||||
is_dropping_database)
|
||||
extract_tables,
|
||||
query_starts_with,
|
||||
queries_start_with,
|
||||
is_destructive,
|
||||
query_has_where_clause,
|
||||
is_dropping_database,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_string():
|
||||
tables = extract_tables('')
|
||||
tables = extract_tables("")
|
||||
assert tables == []
|
||||
|
||||
|
||||
def test_simple_select_single_table():
|
||||
tables = extract_tables('select * from abc')
|
||||
assert tables == [(None, 'abc', None)]
|
||||
tables = extract_tables("select * from abc")
|
||||
assert tables == [(None, "abc", None)]
|
||||
|
||||
|
||||
def test_simple_select_single_table_schema_qualified():
|
||||
tables = extract_tables('select * from abc.def')
|
||||
assert tables == [('abc', 'def', None)]
|
||||
tables = extract_tables("select * from abc.def")
|
||||
assert tables == [("abc", "def", None)]
|
||||
|
||||
|
||||
def test_simple_select_multiple_tables():
|
||||
tables = extract_tables('select * from abc, def')
|
||||
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
|
||||
tables = extract_tables("select * from abc, def")
|
||||
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
|
||||
|
||||
|
||||
def test_simple_select_multiple_tables_schema_qualified():
|
||||
tables = extract_tables('select * from abc.def, ghi.jkl')
|
||||
assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
|
||||
tables = extract_tables("select * from abc.def, ghi.jkl")
|
||||
assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)]
|
||||
|
||||
|
||||
def test_simple_select_with_cols_single_table():
|
||||
tables = extract_tables('select a,b from abc')
|
||||
assert tables == [(None, 'abc', None)]
|
||||
tables = extract_tables("select a,b from abc")
|
||||
assert tables == [(None, "abc", None)]
|
||||
|
||||
|
||||
def test_simple_select_with_cols_single_table_schema_qualified():
|
||||
tables = extract_tables('select a,b from abc.def')
|
||||
assert tables == [('abc', 'def', None)]
|
||||
tables = extract_tables("select a,b from abc.def")
|
||||
assert tables == [("abc", "def", None)]
|
||||
|
||||
|
||||
def test_simple_select_with_cols_multiple_tables():
|
||||
tables = extract_tables('select a,b from abc, def')
|
||||
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
|
||||
tables = extract_tables("select a,b from abc, def")
|
||||
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
|
||||
|
||||
|
||||
def test_simple_select_with_cols_multiple_tables_with_schema():
|
||||
tables = extract_tables('select a,b from abc.def, def.ghi')
|
||||
assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
|
||||
tables = extract_tables("select a,b from abc.def, def.ghi")
|
||||
assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)]
|
||||
|
||||
|
||||
def test_select_with_hanging_comma_single_table():
|
||||
tables = extract_tables('select a, from abc')
|
||||
assert tables == [(None, 'abc', None)]
|
||||
tables = extract_tables("select a, from abc")
|
||||
assert tables == [(None, "abc", None)]
|
||||
|
||||
|
||||
def test_select_with_hanging_comma_multiple_tables():
|
||||
tables = extract_tables('select a, from abc, def')
|
||||
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
|
||||
tables = extract_tables("select a, from abc, def")
|
||||
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
|
||||
|
||||
|
||||
def test_select_with_hanging_period_multiple_tables():
|
||||
tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
|
||||
assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
|
||||
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
|
||||
assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")]
|
||||
|
||||
|
||||
def test_simple_insert_single_table():
|
||||
|
@ -69,97 +74,80 @@ def test_simple_insert_single_table():
|
|||
|
||||
# sqlparse mistakenly assigns an alias to the table
|
||||
# assert tables == [(None, 'abc', None)]
|
||||
assert tables == [(None, 'abc', 'abc')]
|
||||
assert tables == [(None, "abc", "abc")]
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_simple_insert_single_table_schema_qualified():
|
||||
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
|
||||
assert tables == [('abc', 'def', None)]
|
||||
assert tables == [("abc", "def", None)]
|
||||
|
||||
|
||||
def test_simple_update_table():
|
||||
tables = extract_tables('update abc set id = 1')
|
||||
assert tables == [(None, 'abc', None)]
|
||||
tables = extract_tables("update abc set id = 1")
|
||||
assert tables == [(None, "abc", None)]
|
||||
|
||||
|
||||
def test_simple_update_table_with_schema():
|
||||
tables = extract_tables('update abc.def set id = 1')
|
||||
assert tables == [('abc', 'def', None)]
|
||||
tables = extract_tables("update abc.def set id = 1")
|
||||
assert tables == [("abc", "def", None)]
|
||||
|
||||
|
||||
def test_join_table():
|
||||
tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
|
||||
assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
|
||||
tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num")
|
||||
assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")]
|
||||
|
||||
|
||||
def test_join_table_schema_qualified():
|
||||
tables = extract_tables(
|
||||
'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
|
||||
assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
|
||||
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
|
||||
assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")]
|
||||
|
||||
|
||||
def test_join_as_table():
|
||||
tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
|
||||
assert tables == [(None, 'my_table', 'm')]
|
||||
tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
|
||||
assert tables == [(None, "my_table", "m")]
|
||||
|
||||
|
||||
def test_query_starts_with():
|
||||
query = 'USE test;'
|
||||
assert query_starts_with(query, ('use', )) is True
|
||||
query = "USE test;"
|
||||
assert query_starts_with(query, ("use",)) is True
|
||||
|
||||
query = 'DROP DATABASE test;'
|
||||
assert query_starts_with(query, ('use', )) is False
|
||||
query = "DROP DATABASE test;"
|
||||
assert query_starts_with(query, ("use",)) is False
|
||||
|
||||
|
||||
def test_query_starts_with_comment():
|
||||
query = '# comment\nUSE test;'
|
||||
assert query_starts_with(query, ('use', )) is True
|
||||
query = "# comment\nUSE test;"
|
||||
assert query_starts_with(query, ("use",)) is True
|
||||
|
||||
|
||||
def test_queries_start_with():
|
||||
sql = (
|
||||
'# comment\n'
|
||||
'show databases;'
|
||||
'use foo;'
|
||||
)
|
||||
assert queries_start_with(sql, ('show', 'select')) is True
|
||||
assert queries_start_with(sql, ('use', 'drop')) is True
|
||||
assert queries_start_with(sql, ('delete', 'update')) is False
|
||||
sql = "# comment\n" "show databases;" "use foo;"
|
||||
assert queries_start_with(sql, ("show", "select")) is True
|
||||
assert queries_start_with(sql, ("use", "drop")) is True
|
||||
assert queries_start_with(sql, ("delete", "update")) is False
|
||||
|
||||
|
||||
def test_is_destructive():
|
||||
sql = (
|
||||
'use test;\n'
|
||||
'show databases;\n'
|
||||
'drop database foo;'
|
||||
)
|
||||
sql = "use test;\n" "show databases;\n" "drop database foo;"
|
||||
assert is_destructive(sql) is True
|
||||
|
||||
|
||||
def test_is_destructive_update_with_where_clause():
|
||||
sql = (
|
||||
'use test;\n'
|
||||
'show databases;\n'
|
||||
'UPDATE test SET x = 1 WHERE id = 1;'
|
||||
)
|
||||
sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1 WHERE id = 1;"
|
||||
assert is_destructive(sql) is False
|
||||
|
||||
|
||||
def test_is_destructive_update_without_where_clause():
|
||||
sql = (
|
||||
'use test;\n'
|
||||
'show databases;\n'
|
||||
'UPDATE test SET x = 1;'
|
||||
)
|
||||
sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1;"
|
||||
assert is_destructive(sql) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('sql', 'has_where_clause'),
|
||||
("sql", "has_where_clause"),
|
||||
[
|
||||
('update test set dummy = 1;', False),
|
||||
('update test set dummy = 1 where id = 1);', True),
|
||||
("update test set dummy = 1;", False),
|
||||
("update test set dummy = 1 where id = 1);", True),
|
||||
],
|
||||
)
|
||||
def test_query_has_where_clause(sql, has_where_clause):
|
||||
|
@ -167,24 +155,20 @@ def test_query_has_where_clause(sql, has_where_clause):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('sql', 'dbname', 'is_dropping'),
|
||||
("sql", "dbname", "is_dropping"),
|
||||
[
|
||||
('select bar from foo', 'foo', False),
|
||||
('drop database "foo";', '`foo`', True),
|
||||
('drop schema foo', 'foo', True),
|
||||
('drop schema foo', 'bar', False),
|
||||
('drop database bar', 'foo', False),
|
||||
('drop database foo', None, False),
|
||||
('drop database foo; create database foo', 'foo', False),
|
||||
('drop database foo; create database bar', 'foo', True),
|
||||
('select bar from foo; drop database bazz', 'foo', False),
|
||||
('select bar from foo; drop database bazz', 'bazz', True),
|
||||
('-- dropping database \n '
|
||||
'drop -- really dropping \n '
|
||||
'schema abc -- now it is dropped',
|
||||
'abc',
|
||||
True)
|
||||
]
|
||||
("select bar from foo", "foo", False),
|
||||
('drop database "foo";', "`foo`", True),
|
||||
("drop schema foo", "foo", True),
|
||||
("drop schema foo", "bar", False),
|
||||
("drop database bar", "foo", False),
|
||||
("drop database foo", None, False),
|
||||
("drop database foo; create database foo", "foo", False),
|
||||
("drop database foo; create database bar", "foo", True),
|
||||
("select bar from foo; drop database bazz", "foo", False),
|
||||
("select bar from foo; drop database bazz", "bazz", True),
|
||||
("-- dropping database \n " "drop -- really dropping \n " "schema abc -- now it is dropped", "abc", True),
|
||||
],
|
||||
)
|
||||
def test_is_dropping_database(sql, dbname, is_dropping):
|
||||
assert is_dropping_database(sql, dbname) == is_dropping
|
||||
|
|
|
@ -4,8 +4,8 @@ from mycli.packages.prompt_utils import confirm_destructive_query
|
|||
|
||||
|
||||
def test_confirm_destructive_query_notty():
|
||||
stdin = click.get_text_stream('stdin')
|
||||
stdin = click.get_text_stream("stdin")
|
||||
assert stdin.isatty() is False
|
||||
|
||||
sql = 'drop database foo;'
|
||||
sql = "drop database foo;"
|
||||
assert confirm_destructive_query(sql) is None
|
||||
|
|
|
@ -43,49 +43,35 @@ def complete_event():
|
|||
def test_special_name_completion(completer, complete_event):
|
||||
text = "\\d"
|
||||
position = len("\\d")
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
|
||||
assert result == [Completion(text="\\dt", start_position=-2)]
|
||||
|
||||
|
||||
def test_empty_string_completion(completer, complete_event):
|
||||
text = ""
|
||||
position = 0
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
assert (
|
||||
list(map(Completion, completer.keywords + completer.special_commands)) == result
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert list(map(Completion, completer.keywords + completer.special_commands)) == result
|
||||
|
||||
|
||||
def test_select_keyword_completion(completer, complete_event):
|
||||
text = "SEL"
|
||||
position = len("SEL")
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
|
||||
assert list(result) == list([Completion(text="SELECT", start_position=-3)])
|
||||
|
||||
|
||||
def test_select_star(completer, complete_event):
|
||||
text = "SELECT * "
|
||||
position = len(text)
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
|
||||
assert list(result) == list(map(Completion, completer.keywords))
|
||||
|
||||
|
||||
def test_table_completion(completer, complete_event):
|
||||
text = "SELECT * FROM "
|
||||
position = len(text)
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
|
||||
assert list(result) == list(
|
||||
[
|
||||
Completion(text="users", start_position=0),
|
||||
|
@ -99,9 +85,7 @@ def test_table_completion(completer, complete_event):
|
|||
def test_function_name_completion(completer, complete_event):
|
||||
text = "SELECT MA"
|
||||
position = len("SELECT MA")
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
|
||||
assert list(result) == list(
|
||||
[
|
||||
Completion(text="MAX", start_position=-2),
|
||||
|
@ -127,11 +111,7 @@ def test_suggested_column_names(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT from users"
|
||||
position = len("SELECT ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -157,9 +137,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT MAX( from users"
|
||||
position = len("SELECT MAX(")
|
||||
result = completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
|
||||
assert list(result) == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -181,11 +159,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT users. from users"
|
||||
position = len("SELECT users.")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -207,11 +181,7 @@ def test_suggested_column_names_with_alias(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT u. from users u"
|
||||
position = len("SELECT u.")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -234,11 +204,7 @@ def test_suggested_multiple_column_names(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT id, from users u"
|
||||
position = len("SELECT id, ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -264,11 +230,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT u.id, u. from users u"
|
||||
position = len("SELECT u.id, u.")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -291,11 +253,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
|
|||
"""
|
||||
text = "SELECT users.id, users. from users u"
|
||||
position = len("SELECT users.id, users.")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -310,11 +268,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
|
|||
def test_suggested_aliases_after_on(completer, complete_event):
|
||||
text = "SELECT u.name, o.id FROM users u JOIN orders o ON "
|
||||
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="u", start_position=0),
|
||||
|
@ -326,11 +280,7 @@ def test_suggested_aliases_after_on(completer, complete_event):
|
|||
def test_suggested_aliases_after_on_right_side(completer, complete_event):
|
||||
text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = "
|
||||
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="u", start_position=0),
|
||||
|
@ -342,11 +292,7 @@ def test_suggested_aliases_after_on_right_side(completer, complete_event):
|
|||
def test_suggested_tables_after_on(completer, complete_event):
|
||||
text = "SELECT users.name, orders.id FROM users JOIN orders ON "
|
||||
position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="users", start_position=0),
|
||||
|
@ -357,14 +303,8 @@ def test_suggested_tables_after_on(completer, complete_event):
|
|||
|
||||
def test_suggested_tables_after_on_right_side(completer, complete_event):
|
||||
text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
|
||||
position = len(
|
||||
"SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
|
||||
)
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ")
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="users", start_position=0),
|
||||
|
@ -376,11 +316,7 @@ def test_suggested_tables_after_on_right_side(completer, complete_event):
|
|||
def test_table_names_after_from(completer, complete_event):
|
||||
text = "SELECT * FROM "
|
||||
position = len("SELECT * FROM ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="users", start_position=0),
|
||||
|
@ -394,29 +330,21 @@ def test_table_names_after_from(completer, complete_event):
|
|||
def test_auto_escaped_col_names(completer, complete_event):
|
||||
text = "SELECT from `select`"
|
||||
position = len("SELECT ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == [
|
||||
Completion(text="*", start_position=0),
|
||||
Completion(text="id", start_position=0),
|
||||
Completion(text="`insert`", start_position=0),
|
||||
Completion(text="`ABC`", start_position=0),
|
||||
] + list(map(Completion, completer.functions)) + [
|
||||
Completion(text="select", start_position=0)
|
||||
] + list(map(Completion, completer.keywords))
|
||||
] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list(
|
||||
map(Completion, completer.keywords)
|
||||
)
|
||||
|
||||
|
||||
def test_un_escaped_table_names(completer, complete_event):
|
||||
text = "SELECT from réveillé"
|
||||
position = len("SELECT ")
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
assert result == list(
|
||||
[
|
||||
Completion(text="*", start_position=0),
|
||||
|
@ -464,10 +392,6 @@ def dummy_list_path(dir_name):
|
|||
)
|
||||
def test_file_name_completion(completer, complete_event, text, expected):
|
||||
position = len(text)
|
||||
result = list(
|
||||
completer.get_completions(
|
||||
Document(text=text, cursor_position=position), complete_event
|
||||
)
|
||||
)
|
||||
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
|
||||
expected = list((Completion(txt, pos) for txt, pos in expected))
|
||||
assert result == expected
|
||||
|
|
|
@ -17,11 +17,11 @@ def test_set_get_pager():
|
|||
assert mycli.packages.special.is_pager_enabled()
|
||||
mycli.packages.special.set_pager_enabled(False)
|
||||
assert not mycli.packages.special.is_pager_enabled()
|
||||
mycli.packages.special.set_pager('less')
|
||||
assert os.environ['PAGER'] == "less"
|
||||
mycli.packages.special.set_pager("less")
|
||||
assert os.environ["PAGER"] == "less"
|
||||
mycli.packages.special.set_pager(False)
|
||||
assert os.environ['PAGER'] == "less"
|
||||
del os.environ['PAGER']
|
||||
assert os.environ["PAGER"] == "less"
|
||||
del os.environ["PAGER"]
|
||||
mycli.packages.special.set_pager(False)
|
||||
mycli.packages.special.disable_pager()
|
||||
assert not mycli.packages.special.is_pager_enabled()
|
||||
|
@ -42,45 +42,44 @@ def test_set_get_expanded_output():
|
|||
|
||||
|
||||
def test_editor_command():
|
||||
assert mycli.packages.special.editor_command(r'hello\e')
|
||||
assert mycli.packages.special.editor_command(r'\ehello')
|
||||
assert not mycli.packages.special.editor_command(r'hello')
|
||||
assert mycli.packages.special.editor_command(r"hello\e")
|
||||
assert mycli.packages.special.editor_command(r"\ehello")
|
||||
assert not mycli.packages.special.editor_command(r"hello")
|
||||
|
||||
assert mycli.packages.special.get_filename(r'\e filename') == "filename"
|
||||
assert mycli.packages.special.get_filename(r"\e filename") == "filename"
|
||||
|
||||
os.environ['EDITOR'] = 'true'
|
||||
os.environ['VISUAL'] = 'true'
|
||||
os.environ["EDITOR"] = "true"
|
||||
os.environ["VISUAL"] = "true"
|
||||
# Set the editor to Notepad on Windows
|
||||
if os.name != 'nt':
|
||||
mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
|
||||
if os.name != "nt":
|
||||
mycli.packages.special.open_external_editor(sql=r"select 1") == "select 1"
|
||||
else:
|
||||
pytest.skip('Skipping on Windows platform.')
|
||||
|
||||
pytest.skip("Skipping on Windows platform.")
|
||||
|
||||
|
||||
def test_tee_command():
|
||||
mycli.packages.special.write_tee(u"hello world") # write without file set
|
||||
mycli.packages.special.write_tee("hello world") # write without file set
|
||||
# keep Windows from locking the file with delete=False
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
mycli.packages.special.execute(None, u"tee " + f.name)
|
||||
mycli.packages.special.write_tee(u"hello world")
|
||||
if os.name=='nt':
|
||||
mycli.packages.special.execute(None, "tee " + f.name)
|
||||
mycli.packages.special.write_tee("hello world")
|
||||
if os.name == "nt":
|
||||
assert f.read() == b"hello world\r\n"
|
||||
else:
|
||||
assert f.read() == b"hello world\n"
|
||||
|
||||
mycli.packages.special.execute(None, u"tee -o " + f.name)
|
||||
mycli.packages.special.write_tee(u"hello world")
|
||||
mycli.packages.special.execute(None, "tee -o " + f.name)
|
||||
mycli.packages.special.write_tee("hello world")
|
||||
f.seek(0)
|
||||
if os.name=='nt':
|
||||
if os.name == "nt":
|
||||
assert f.read() == b"hello world\r\n"
|
||||
else:
|
||||
assert f.read() == b"hello world\n"
|
||||
|
||||
mycli.packages.special.execute(None, u"notee")
|
||||
mycli.packages.special.write_tee(u"hello world")
|
||||
mycli.packages.special.execute(None, "notee")
|
||||
mycli.packages.special.write_tee("hello world")
|
||||
f.seek(0)
|
||||
if os.name=='nt':
|
||||
if os.name == "nt":
|
||||
assert f.read() == b"hello world\r\n"
|
||||
else:
|
||||
assert f.read() == b"hello world\n"
|
||||
|
@ -92,52 +91,49 @@ def test_tee_command():
|
|||
os.remove(f.name)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while attempting to delete the file: {e}")
|
||||
|
||||
|
||||
|
||||
def test_tee_command_error():
|
||||
with pytest.raises(TypeError):
|
||||
mycli.packages.special.execute(None, 'tee')
|
||||
mycli.packages.special.execute(None, "tee")
|
||||
|
||||
with pytest.raises(OSError):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
|
||||
mycli.packages.special.execute(None, 'tee {}'.format(f.name))
|
||||
mycli.packages.special.execute(None, "tee {}".format(f.name))
|
||||
|
||||
|
||||
@dbtest
|
||||
|
||||
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
|
||||
def test_favorite_query():
|
||||
with db_connection().cursor() as cur:
|
||||
query = u'select "✔"'
|
||||
mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
|
||||
assert next(mycli.packages.special.execute(
|
||||
cur, u'\\f check'))[0] == "> " + query
|
||||
query = 'select "✔"'
|
||||
mycli.packages.special.execute(cur, "\\fs check {0}".format(query))
|
||||
assert next(mycli.packages.special.execute(cur, "\\f check"))[0] == "> " + query
|
||||
|
||||
|
||||
def test_once_command():
|
||||
with pytest.raises(TypeError):
|
||||
mycli.packages.special.execute(None, u"\\once")
|
||||
mycli.packages.special.execute(None, "\\once")
|
||||
|
||||
with pytest.raises(OSError):
|
||||
mycli.packages.special.execute(None, u"\\once /proc/access-denied")
|
||||
mycli.packages.special.execute(None, "\\once /proc/access-denied")
|
||||
|
||||
mycli.packages.special.write_once(u"hello world") # write without file set
|
||||
mycli.packages.special.write_once("hello world") # write without file set
|
||||
# keep Windows from locking the file with delete=False
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
mycli.packages.special.execute(None, u"\\once " + f.name)
|
||||
mycli.packages.special.write_once(u"hello world")
|
||||
if os.name=='nt':
|
||||
mycli.packages.special.execute(None, "\\once " + f.name)
|
||||
mycli.packages.special.write_once("hello world")
|
||||
if os.name == "nt":
|
||||
assert f.read() == b"hello world\r\n"
|
||||
else:
|
||||
assert f.read() == b"hello world\n"
|
||||
|
||||
mycli.packages.special.execute(None, u"\\once -o " + f.name)
|
||||
mycli.packages.special.write_once(u"hello world line 1")
|
||||
mycli.packages.special.write_once(u"hello world line 2")
|
||||
mycli.packages.special.execute(None, "\\once -o " + f.name)
|
||||
mycli.packages.special.write_once("hello world line 1")
|
||||
mycli.packages.special.write_once("hello world line 2")
|
||||
f.seek(0)
|
||||
if os.name=='nt':
|
||||
if os.name == "nt":
|
||||
assert f.read() == b"hello world line 1\r\nhello world line 2\r\n"
|
||||
else:
|
||||
assert f.read() == b"hello world line 1\nhello world line 2\n"
|
||||
|
@ -151,52 +147,47 @@ def test_once_command():
|
|||
|
||||
def test_pipe_once_command():
|
||||
with pytest.raises(IOError):
|
||||
mycli.packages.special.execute(None, u"\\pipe_once")
|
||||
mycli.packages.special.execute(None, "\\pipe_once")
|
||||
|
||||
with pytest.raises(OSError):
|
||||
mycli.packages.special.execute(
|
||||
None, u"\\pipe_once /proc/access-denied")
|
||||
mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied")
|
||||
|
||||
if os.name == 'nt':
|
||||
mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
|
||||
mycli.packages.special.write_once(u"hello world")
|
||||
if os.name == "nt":
|
||||
mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
|
||||
mycli.packages.special.write_once("hello world")
|
||||
mycli.packages.special.unset_pipe_once_if_written()
|
||||
else:
|
||||
mycli.packages.special.execute(None, u"\\pipe_once wc")
|
||||
mycli.packages.special.write_once(u"hello world")
|
||||
mycli.packages.special.unset_pipe_once_if_written()
|
||||
# how to assert on wc output?
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
mycli.packages.special.execute(None, "\\pipe_once tee " + f.name)
|
||||
mycli.packages.special.write_pipe_once("hello world")
|
||||
mycli.packages.special.unset_pipe_once_if_written()
|
||||
f.seek(0)
|
||||
assert f.read() == b"hello world\n"
|
||||
|
||||
|
||||
def test_parseargfile():
|
||||
"""Test that parseargfile expands the user directory."""
|
||||
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
|
||||
'mode': 'a'}
|
||||
|
||||
if os.name=='nt':
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile(
|
||||
'~\\filename')
|
||||
else:
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile(
|
||||
'~/filename')
|
||||
expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "a"}
|
||||
|
||||
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
|
||||
'mode': 'w'}
|
||||
if os.name=='nt':
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile(
|
||||
'-o ~\\filename')
|
||||
if os.name == "nt":
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename")
|
||||
else:
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile(
|
||||
'-o ~/filename')
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile("~/filename")
|
||||
|
||||
expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "w"}
|
||||
if os.name == "nt":
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename")
|
||||
else:
|
||||
assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename")
|
||||
|
||||
|
||||
def test_parseargfile_no_file():
|
||||
"""Test that parseargfile raises a TypeError if there is no filename."""
|
||||
with pytest.raises(TypeError):
|
||||
mycli.packages.special.iocommands.parseargfile('')
|
||||
mycli.packages.special.iocommands.parseargfile("")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
mycli.packages.special.iocommands.parseargfile('-o ')
|
||||
mycli.packages.special.iocommands.parseargfile("-o ")
|
||||
|
||||
|
||||
@dbtest
|
||||
|
@ -205,11 +196,9 @@ def test_watch_query_iteration():
|
|||
the desired query and returns the given results."""
|
||||
expected_value = "1"
|
||||
query = "SELECT {0!s}".format(expected_value)
|
||||
expected_title = '> {0!s}'.format(query)
|
||||
expected_title = "> {0!s}".format(query)
|
||||
with db_connection().cursor() as cur:
|
||||
result = next(mycli.packages.special.iocommands.watch_query(
|
||||
arg=query, cur=cur
|
||||
))
|
||||
result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur))
|
||||
assert result[0] == expected_title
|
||||
assert result[2][0] == expected_value
|
||||
|
||||
|
@ -230,14 +219,12 @@ def test_watch_query_full():
|
|||
wait_interval = 1
|
||||
expected_value = "1"
|
||||
query = "SELECT {0!s}".format(expected_value)
|
||||
expected_title = '> {0!s}'.format(query)
|
||||
expected_title = "> {0!s}".format(query)
|
||||
expected_results = 4
|
||||
ctrl_c_process = send_ctrl_c(wait_interval)
|
||||
with db_connection().cursor() as cur:
|
||||
results = list(
|
||||
result for result in mycli.packages.special.iocommands.watch_query(
|
||||
arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
|
||||
)
|
||||
result for result in mycli.packages.special.iocommands.watch_query(arg="{0!s} {1!s}".format(watch_seconds, query), cur=cur)
|
||||
)
|
||||
ctrl_c_process.join(1)
|
||||
assert len(results) == expected_results
|
||||
|
@ -247,14 +234,12 @@ def test_watch_query_full():
|
|||
|
||||
|
||||
@dbtest
|
||||
@patch('click.clear')
|
||||
@patch("click.clear")
|
||||
def test_watch_query_clear(clear_mock):
|
||||
"""Test that the screen is cleared with the -c flag of `watch` command
|
||||
before execute the query."""
|
||||
with db_connection().cursor() as cur:
|
||||
watch_gen = mycli.packages.special.iocommands.watch_query(
|
||||
arg='0.1 -c select 1;', cur=cur
|
||||
)
|
||||
watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur)
|
||||
assert not clear_mock.called
|
||||
next(watch_gen)
|
||||
assert clear_mock.called
|
||||
|
@ -271,19 +256,20 @@ def test_watch_query_bad_arguments():
|
|||
watch_query = mycli.packages.special.iocommands.watch_query
|
||||
with db_connection().cursor() as cur:
|
||||
with pytest.raises(ProgrammingError):
|
||||
next(watch_query('a select 1;', cur=cur))
|
||||
next(watch_query("a select 1;", cur=cur))
|
||||
with pytest.raises(ProgrammingError):
|
||||
next(watch_query('-a select 1;', cur=cur))
|
||||
next(watch_query("-a select 1;", cur=cur))
|
||||
with pytest.raises(ProgrammingError):
|
||||
next(watch_query('1 -a select 1;', cur=cur))
|
||||
next(watch_query("1 -a select 1;", cur=cur))
|
||||
with pytest.raises(ProgrammingError):
|
||||
next(watch_query('-c -a select 1;', cur=cur))
|
||||
next(watch_query("-c -a select 1;", cur=cur))
|
||||
|
||||
|
||||
@dbtest
|
||||
@patch('click.clear')
|
||||
@patch("click.clear")
|
||||
def test_watch_query_interval_clear(clear_mock):
|
||||
"""Test `watch` command with interval and clear flag."""
|
||||
|
||||
def test_asserts(gen):
|
||||
clear_mock.reset_mock()
|
||||
start = time()
|
||||
|
@ -296,46 +282,32 @@ def test_watch_query_interval_clear(clear_mock):
|
|||
seconds = 1.0
|
||||
watch_query = mycli.packages.special.iocommands.watch_query
|
||||
with db_connection().cursor() as cur:
|
||||
test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
|
||||
cur=cur))
|
||||
test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
|
||||
cur=cur))
|
||||
test_asserts(watch_query("{0!s} -c select 1;".format(seconds), cur=cur))
|
||||
test_asserts(watch_query("-c {0!s} select 1;".format(seconds), cur=cur))
|
||||
|
||||
|
||||
def test_split_sql_by_delimiter():
|
||||
for delimiter_str in (';', '$', '😀'):
|
||||
for delimiter_str in (";", "$", "😀"):
|
||||
mycli.packages.special.set_delimiter(delimiter_str)
|
||||
sql_input = "select 1{} select \ufffc2".format(delimiter_str)
|
||||
queries = (
|
||||
"select 1",
|
||||
"select \ufffc2"
|
||||
)
|
||||
for query, parsed_query in zip(
|
||||
queries, mycli.packages.special.split_queries(sql_input)):
|
||||
assert(query == parsed_query)
|
||||
queries = ("select 1", "select \ufffc2")
|
||||
for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
|
||||
assert query == parsed_query
|
||||
|
||||
|
||||
def test_switch_delimiter_within_query():
|
||||
mycli.packages.special.set_delimiter(';')
|
||||
mycli.packages.special.set_delimiter(";")
|
||||
sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
|
||||
queries = (
|
||||
"select 1",
|
||||
"delimiter $$ select 2 $$ select 3 $$",
|
||||
"select 2",
|
||||
"select 3"
|
||||
)
|
||||
for query, parsed_query in zip(
|
||||
queries,
|
||||
mycli.packages.special.split_queries(sql_input)):
|
||||
assert(query == parsed_query)
|
||||
queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$", "select 2", "select 3")
|
||||
for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
|
||||
assert query == parsed_query
|
||||
|
||||
|
||||
def test_set_delimiter():
|
||||
|
||||
for delim in ('foo', 'bar'):
|
||||
for delim in ("foo", "bar"):
|
||||
mycli.packages.special.set_delimiter(delim)
|
||||
assert mycli.packages.special.get_current_delimiter() == delim
|
||||
|
||||
|
||||
def teardown_function():
|
||||
mycli.packages.special.set_delimiter(';')
|
||||
mycli.packages.special.set_delimiter(";")
|
||||
|
|
|
@ -7,14 +7,11 @@ from mycli.sqlexecute import ServerInfo, ServerSpecies
|
|||
from .utils import run, dbtest, set_expanded_output, is_expanded_output
|
||||
|
||||
|
||||
def assert_result_equal(result, title=None, rows=None, headers=None,
|
||||
status=None, auto_status=True, assert_contains=False):
|
||||
def assert_result_equal(result, title=None, rows=None, headers=None, status=None, auto_status=True, assert_contains=False):
|
||||
"""Assert that an sqlexecute.run() result matches the expected values."""
|
||||
if status is None and auto_status and rows:
|
||||
status = '{} row{} in set'.format(
|
||||
len(rows), 's' if len(rows) > 1 else '')
|
||||
fields = {'title': title, 'rows': rows, 'headers': headers,
|
||||
'status': status}
|
||||
status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "")
|
||||
fields = {"title": title, "rows": rows, "headers": headers, "status": status}
|
||||
|
||||
if assert_contains:
|
||||
# Do a loose match on the results using the *in* operator.
|
||||
|
@ -28,34 +25,35 @@ def assert_result_equal(result, title=None, rows=None, headers=None,
|
|||
|
||||
@dbtest
|
||||
def test_conn(executor):
|
||||
run(executor, '''create table test(a text)''')
|
||||
run(executor, '''insert into test values('abc')''')
|
||||
results = run(executor, '''select * from test''')
|
||||
run(executor, """create table test(a text)""")
|
||||
run(executor, """insert into test values('abc')""")
|
||||
results = run(executor, """select * from test""")
|
||||
|
||||
assert_result_equal(results, headers=['a'], rows=[('abc',)])
|
||||
assert_result_equal(results, headers=["a"], rows=[("abc",)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_bools(executor):
|
||||
run(executor, '''create table test(a boolean)''')
|
||||
run(executor, '''insert into test values(True)''')
|
||||
results = run(executor, '''select * from test''')
|
||||
run(executor, """create table test(a boolean)""")
|
||||
run(executor, """insert into test values(True)""")
|
||||
results = run(executor, """select * from test""")
|
||||
|
||||
assert_result_equal(results, headers=['a'], rows=[(1,)])
|
||||
assert_result_equal(results, headers=["a"], rows=[(1,)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_binary(executor):
|
||||
run(executor, '''create table bt(geom linestring NOT NULL)''')
|
||||
run(executor, "INSERT INTO bt VALUES "
|
||||
"(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
|
||||
results = run(executor, '''select * from bt''')
|
||||
run(executor, """create table bt(geom linestring NOT NULL)""")
|
||||
run(executor, "INSERT INTO bt VALUES " "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
|
||||
results = run(executor, """select * from bt""")
|
||||
|
||||
geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
|
||||
b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
|
||||
b'\xac\xdeC@')
|
||||
geom = (
|
||||
b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n"
|
||||
b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9"
|
||||
b"\xac\xdeC@"
|
||||
)
|
||||
|
||||
assert_result_equal(results, headers=['geom'], rows=[(geom,)])
|
||||
assert_result_equal(results, headers=["geom"], rows=[(geom,)])
|
||||
|
||||
|
||||
@dbtest
|
||||
|
@ -63,49 +61,48 @@ def test_table_and_columns_query(executor):
|
|||
run(executor, "create table a(x text, y text)")
|
||||
run(executor, "create table b(z text)")
|
||||
|
||||
assert set(executor.tables()) == set([('a',), ('b',)])
|
||||
assert set(executor.table_columns()) == set(
|
||||
[('a', 'x'), ('a', 'y'), ('b', 'z')])
|
||||
assert set(executor.tables()) == set([("a",), ("b",)])
|
||||
assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_database_list(executor):
|
||||
databases = executor.databases()
|
||||
assert 'mycli_test_db' in databases
|
||||
assert "mycli_test_db" in databases
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_invalid_syntax(executor):
|
||||
with pytest.raises(pymysql.ProgrammingError) as excinfo:
|
||||
run(executor, 'invalid syntax!')
|
||||
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
|
||||
run(executor, "invalid syntax!")
|
||||
assert "You have an error in your SQL syntax;" in str(excinfo.value)
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_invalid_column_name(executor):
|
||||
with pytest.raises(pymysql.err.OperationalError) as excinfo:
|
||||
run(executor, 'select invalid command')
|
||||
run(executor, "select invalid command")
|
||||
assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_unicode_support_in_output(executor):
|
||||
run(executor, "create table unicodechars(t text)")
|
||||
run(executor, u"insert into unicodechars (t) values ('é')")
|
||||
run(executor, "insert into unicodechars (t) values ('é')")
|
||||
|
||||
# See issue #24, this raises an exception without proper handling
|
||||
results = run(executor, u"select * from unicodechars")
|
||||
assert_result_equal(results, headers=['t'], rows=[(u'é',)])
|
||||
results = run(executor, "select * from unicodechars")
|
||||
assert_result_equal(results, headers=["t"], rows=[("é",)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_multiple_queries_same_line(executor):
|
||||
results = run(executor, "select 'foo'; select 'bar'")
|
||||
|
||||
expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
|
||||
'status': '1 row in set'},
|
||||
{'title': None, 'headers': ['bar'], 'rows': [('bar',)],
|
||||
'status': '1 row in set'}]
|
||||
expected = [
|
||||
{"title": None, "headers": ["foo"], "rows": [("foo",)], "status": "1 row in set"},
|
||||
{"title": None, "headers": ["bar"], "rows": [("bar",)], "status": "1 row in set"},
|
||||
]
|
||||
assert expected == results
|
||||
|
||||
|
||||
|
@ -113,7 +110,7 @@ def test_multiple_queries_same_line(executor):
|
|||
def test_multiple_queries_same_line_syntaxerror(executor):
|
||||
with pytest.raises(pymysql.ProgrammingError) as excinfo:
|
||||
run(executor, "select 'foo'; invalid syntax")
|
||||
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
|
||||
assert "You have an error in your SQL syntax;" in str(excinfo.value)
|
||||
|
||||
|
||||
@dbtest
|
||||
|
@ -125,15 +122,13 @@ def test_favorite_query(executor):
|
|||
run(executor, "insert into test values('def')")
|
||||
|
||||
results = run(executor, "\\fs test-a select * from test where a like 'a%'")
|
||||
assert_result_equal(results, status='Saved.')
|
||||
assert_result_equal(results, status="Saved.")
|
||||
|
||||
results = run(executor, "\\f test-a")
|
||||
assert_result_equal(results,
|
||||
title="> select * from test where a like 'a%'",
|
||||
headers=['a'], rows=[('abc',)], auto_status=False)
|
||||
assert_result_equal(results, title="> select * from test where a like 'a%'", headers=["a"], rows=[("abc",)], auto_status=False)
|
||||
|
||||
results = run(executor, "\\fd test-a")
|
||||
assert_result_equal(results, status='test-a: Deleted')
|
||||
assert_result_equal(results, status="test-a: Deleted")
|
||||
|
||||
|
||||
@dbtest
|
||||
|
@ -144,158 +139,147 @@ def test_favorite_query_multiple_statement(executor):
|
|||
run(executor, "insert into test values('abc')")
|
||||
run(executor, "insert into test values('def')")
|
||||
|
||||
results = run(executor,
|
||||
"\\fs test-ad select * from test where a like 'a%'; "
|
||||
"select * from test where a like 'd%'")
|
||||
assert_result_equal(results, status='Saved.')
|
||||
results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " "select * from test where a like 'd%'")
|
||||
assert_result_equal(results, status="Saved.")
|
||||
|
||||
results = run(executor, "\\f test-ad")
|
||||
expected = [{'title': "> select * from test where a like 'a%'",
|
||||
'headers': ['a'], 'rows': [('abc',)], 'status': None},
|
||||
{'title': "> select * from test where a like 'd%'",
|
||||
'headers': ['a'], 'rows': [('def',)], 'status': None}]
|
||||
expected = [
|
||||
{"title": "> select * from test where a like 'a%'", "headers": ["a"], "rows": [("abc",)], "status": None},
|
||||
{"title": "> select * from test where a like 'd%'", "headers": ["a"], "rows": [("def",)], "status": None},
|
||||
]
|
||||
assert expected == results
|
||||
|
||||
results = run(executor, "\\fd test-ad")
|
||||
assert_result_equal(results, status='test-ad: Deleted')
|
||||
assert_result_equal(results, status="test-ad: Deleted")
|
||||
|
||||
|
||||
@dbtest
|
||||
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
|
||||
def test_favorite_query_expanded_output(executor):
|
||||
set_expanded_output(False)
|
||||
run(executor, '''create table test(a text)''')
|
||||
run(executor, '''insert into test values('abc')''')
|
||||
run(executor, """create table test(a text)""")
|
||||
run(executor, """insert into test values('abc')""")
|
||||
|
||||
results = run(executor, "\\fs test-ae select * from test")
|
||||
assert_result_equal(results, status='Saved.')
|
||||
assert_result_equal(results, status="Saved.")
|
||||
|
||||
results = run(executor, "\\f test-ae \\G")
|
||||
assert is_expanded_output() is True
|
||||
assert_result_equal(results, title='> select * from test',
|
||||
headers=['a'], rows=[('abc',)], auto_status=False)
|
||||
assert_result_equal(results, title="> select * from test", headers=["a"], rows=[("abc",)], auto_status=False)
|
||||
|
||||
set_expanded_output(False)
|
||||
|
||||
results = run(executor, "\\fd test-ae")
|
||||
assert_result_equal(results, status='test-ae: Deleted')
|
||||
assert_result_equal(results, status="test-ae: Deleted")
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_special_command(executor):
|
||||
results = run(executor, '\\?')
|
||||
assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
|
||||
headers='Command', assert_contains=True,
|
||||
auto_status=False)
|
||||
results = run(executor, "\\?")
|
||||
assert_result_equal(results, rows=("quit", "\\q", "Quit."), headers="Command", assert_contains=True, auto_status=False)
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_cd_command_without_a_folder_name(executor):
|
||||
results = run(executor, 'system cd')
|
||||
assert_result_equal(results, status='No folder name was provided.')
|
||||
results = run(executor, "system cd")
|
||||
assert_result_equal(results, status="No folder name was provided.")
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_system_command_not_found(executor):
|
||||
results = run(executor, 'system xyz')
|
||||
if os.name=='nt':
|
||||
assert_result_equal(results, status='OSError: The system cannot find the file specified',
|
||||
assert_contains=True)
|
||||
results = run(executor, "system xyz")
|
||||
if os.name == "nt":
|
||||
assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True)
|
||||
else:
|
||||
assert_result_equal(results, status='OSError: No such file or directory',
|
||||
assert_contains=True)
|
||||
assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True)
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_system_command_output(executor):
|
||||
eol = os.linesep
|
||||
test_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
test_file_path = os.path.join(test_dir, 'test.txt')
|
||||
results = run(executor, 'system cat {0}'.format(test_file_path))
|
||||
assert_result_equal(results, status=f'mycli rocks!{eol}')
|
||||
test_file_path = os.path.join(test_dir, "test.txt")
|
||||
results = run(executor, "system cat {0}".format(test_file_path))
|
||||
assert_result_equal(results, status=f"mycli rocks!{eol}")
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_cd_command_current_dir(executor):
|
||||
test_path = os.path.abspath(os.path.dirname(__file__))
|
||||
run(executor, 'system cd {0}'.format(test_path))
|
||||
run(executor, "system cd {0}".format(test_path))
|
||||
assert os.getcwd() == test_path
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_unicode_support(executor):
|
||||
results = run(executor, u"SELECT '日本語' AS japanese;")
|
||||
assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
|
||||
results = run(executor, "SELECT '日本語' AS japanese;")
|
||||
assert_result_equal(results, headers=["japanese"], rows=[("日本語",)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_timestamp_null(executor):
|
||||
run(executor, '''create table ts_null(a timestamp null)''')
|
||||
run(executor, '''insert into ts_null values(null)''')
|
||||
results = run(executor, '''select * from ts_null''')
|
||||
assert_result_equal(results, headers=['a'],
|
||||
rows=[(None,)])
|
||||
run(executor, """create table ts_null(a timestamp null)""")
|
||||
run(executor, """insert into ts_null values(null)""")
|
||||
results = run(executor, """select * from ts_null""")
|
||||
assert_result_equal(results, headers=["a"], rows=[(None,)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_datetime_null(executor):
|
||||
run(executor, '''create table dt_null(a datetime null)''')
|
||||
run(executor, '''insert into dt_null values(null)''')
|
||||
results = run(executor, '''select * from dt_null''')
|
||||
assert_result_equal(results, headers=['a'],
|
||||
rows=[(None,)])
|
||||
run(executor, """create table dt_null(a datetime null)""")
|
||||
run(executor, """insert into dt_null values(null)""")
|
||||
results = run(executor, """select * from dt_null""")
|
||||
assert_result_equal(results, headers=["a"], rows=[(None,)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_date_null(executor):
|
||||
run(executor, '''create table date_null(a date null)''')
|
||||
run(executor, '''insert into date_null values(null)''')
|
||||
results = run(executor, '''select * from date_null''')
|
||||
assert_result_equal(results, headers=['a'], rows=[(None,)])
|
||||
run(executor, """create table date_null(a date null)""")
|
||||
run(executor, """insert into date_null values(null)""")
|
||||
results = run(executor, """select * from date_null""")
|
||||
assert_result_equal(results, headers=["a"], rows=[(None,)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_time_null(executor):
|
||||
run(executor, '''create table time_null(a time null)''')
|
||||
run(executor, '''insert into time_null values(null)''')
|
||||
results = run(executor, '''select * from time_null''')
|
||||
assert_result_equal(results, headers=['a'], rows=[(None,)])
|
||||
run(executor, """create table time_null(a time null)""")
|
||||
run(executor, """insert into time_null values(null)""")
|
||||
results = run(executor, """select * from time_null""")
|
||||
assert_result_equal(results, headers=["a"], rows=[(None,)])
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_multiple_results(executor):
|
||||
query = '''CREATE PROCEDURE dmtest()
|
||||
query = """CREATE PROCEDURE dmtest()
|
||||
BEGIN
|
||||
SELECT 1;
|
||||
SELECT 2;
|
||||
END'''
|
||||
END"""
|
||||
executor.conn.cursor().execute(query)
|
||||
|
||||
results = run(executor, 'call dmtest;')
|
||||
results = run(executor, "call dmtest;")
|
||||
expected = [
|
||||
{'title': None, 'rows': [(1,)], 'headers': ['1'],
|
||||
'status': '1 row in set'},
|
||||
{'title': None, 'rows': [(2,)], 'headers': ['2'],
|
||||
'status': '1 row in set'}
|
||||
{"title": None, "rows": [(1,)], "headers": ["1"], "status": "1 row in set"},
|
||||
{"title": None, "rows": [(2,)], "headers": ["2"], "status": "1 row in set"},
|
||||
]
|
||||
assert results == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'version_string, species, parsed_version_string, version',
|
||||
"version_string, species, parsed_version_string, version",
|
||||
(
|
||||
('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100),
|
||||
('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200),
|
||||
('5.7.32-35', 'Percona', '5.7.32', 50732),
|
||||
('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
|
||||
('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
|
||||
('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
|
||||
('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
|
||||
('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
|
||||
('unexpected version string', None, '', 0),
|
||||
('', None, '', 0),
|
||||
(None, None, '', 0),
|
||||
)
|
||||
("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100),
|
||||
("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200),
|
||||
("5.7.32-35", "Percona", "5.7.32", 50732),
|
||||
("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732),
|
||||
("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
|
||||
("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
|
||||
("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016),
|
||||
("5.1.5a-alpha", "MySQL", "5.1.5", 50105),
|
||||
("unexpected version string", None, "", 0),
|
||||
("", None, "", 0),
|
||||
(None, None, "", 0),
|
||||
),
|
||||
)
|
||||
def test_version_parsing(version_string, species, parsed_version_string, version):
|
||||
server_info = ServerInfo.from_version_string(version_string)
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
from textwrap import dedent
|
||||
|
||||
from mycli.packages.tabular_output import sql_format
|
||||
from cli_helpers.tabular_output import TabularOutputFormatter
|
||||
|
||||
from .utils import USER, PASSWORD, HOST, PORT, dbtest
|
||||
|
||||
|
@ -23,20 +21,17 @@ def mycli():
|
|||
@dbtest
|
||||
def test_sql_output(mycli):
|
||||
"""Test the sql output adapter."""
|
||||
headers = ['letters', 'number', 'optional', 'float', 'binary']
|
||||
headers = ["letters", "number", "optional", "float", "binary"]
|
||||
|
||||
class FakeCursor(object):
|
||||
def __init__(self):
|
||||
self.data = [
|
||||
('abc', 1, None, 10.0, b'\xAA'),
|
||||
('d', 456, '1', 0.5, b'\xAA\xBB')
|
||||
]
|
||||
self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")]
|
||||
self.description = [
|
||||
(None, FIELD_TYPE.VARCHAR),
|
||||
(None, FIELD_TYPE.LONG),
|
||||
(None, FIELD_TYPE.LONG),
|
||||
(None, FIELD_TYPE.FLOAT),
|
||||
(None, FIELD_TYPE.BLOB)
|
||||
(None, FIELD_TYPE.BLOB),
|
||||
]
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -52,12 +47,11 @@ def test_sql_output(mycli):
|
|||
return self.description
|
||||
|
||||
# Test sql-update output format
|
||||
assert list(mycli.change_table_format("sql-update")) == \
|
||||
[(None, None, None, 'Changed table format to sql-update')]
|
||||
assert list(mycli.change_table_format("sql-update")) == [(None, None, None, "Changed table format to sql-update")]
|
||||
mycli.formatter.query = ""
|
||||
output = mycli.format_output(None, FakeCursor(), headers)
|
||||
actual = "\n".join(output)
|
||||
assert actual == dedent('''\
|
||||
assert actual == dedent("""\
|
||||
UPDATE `DUAL` SET
|
||||
`number` = 1
|
||||
, `optional` = NULL
|
||||
|
@ -69,13 +63,12 @@ def test_sql_output(mycli):
|
|||
, `optional` = '1'
|
||||
, `float` = 0.5e0
|
||||
, `binary` = X'aabb'
|
||||
WHERE `letters` = 'd';''')
|
||||
WHERE `letters` = 'd';""")
|
||||
# Test sql-update-2 output format
|
||||
assert list(mycli.change_table_format("sql-update-2")) == \
|
||||
[(None, None, None, 'Changed table format to sql-update-2')]
|
||||
assert list(mycli.change_table_format("sql-update-2")) == [(None, None, None, "Changed table format to sql-update-2")]
|
||||
mycli.formatter.query = ""
|
||||
output = mycli.format_output(None, FakeCursor(), headers)
|
||||
assert "\n".join(output) == dedent('''\
|
||||
assert "\n".join(output) == dedent("""\
|
||||
UPDATE `DUAL` SET
|
||||
`optional` = NULL
|
||||
, `float` = 10.0e0
|
||||
|
@ -85,34 +78,31 @@ def test_sql_output(mycli):
|
|||
`optional` = '1'
|
||||
, `float` = 0.5e0
|
||||
, `binary` = X'aabb'
|
||||
WHERE `letters` = 'd' AND `number` = 456;''')
|
||||
WHERE `letters` = 'd' AND `number` = 456;""")
|
||||
# Test sql-insert output format (without table name)
|
||||
assert list(mycli.change_table_format("sql-insert")) == \
|
||||
[(None, None, None, 'Changed table format to sql-insert')]
|
||||
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
|
||||
mycli.formatter.query = ""
|
||||
output = mycli.format_output(None, FakeCursor(), headers)
|
||||
assert "\n".join(output) == dedent('''\
|
||||
assert "\n".join(output) == dedent("""\
|
||||
INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
|
||||
('abc', 1, NULL, 10.0e0, X'aa')
|
||||
, ('d', 456, '1', 0.5e0, X'aabb')
|
||||
;''')
|
||||
;""")
|
||||
# Test sql-insert output format (with table name)
|
||||
assert list(mycli.change_table_format("sql-insert")) == \
|
||||
[(None, None, None, 'Changed table format to sql-insert')]
|
||||
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
|
||||
mycli.formatter.query = "SELECT * FROM `table`"
|
||||
output = mycli.format_output(None, FakeCursor(), headers)
|
||||
assert "\n".join(output) == dedent('''\
|
||||
assert "\n".join(output) == dedent("""\
|
||||
INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
|
||||
('abc', 1, NULL, 10.0e0, X'aa')
|
||||
, ('d', 456, '1', 0.5e0, X'aabb')
|
||||
;''')
|
||||
;""")
|
||||
# Test sql-insert output format (with database + table name)
|
||||
assert list(mycli.change_table_format("sql-insert")) == \
|
||||
[(None, None, None, 'Changed table format to sql-insert')]
|
||||
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
|
||||
mycli.formatter.query = "SELECT * FROM `database`.`table`"
|
||||
output = mycli.format_output(None, FakeCursor(), headers)
|
||||
assert "\n".join(output) == dedent('''\
|
||||
assert "\n".join(output) == dedent("""\
|
||||
INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
|
||||
('abc', 1, NULL, 10.0e0, X'aa')
|
||||
, ('d', 456, '1', 0.5e0, X'aabb')
|
||||
;''')
|
||||
;""")
|
||||
|
|
|
@ -9,20 +9,18 @@ import pytest
|
|||
|
||||
from mycli.main import special
|
||||
|
||||
PASSWORD = os.getenv('PYTEST_PASSWORD')
|
||||
USER = os.getenv('PYTEST_USER', 'root')
|
||||
HOST = os.getenv('PYTEST_HOST', 'localhost')
|
||||
PORT = int(os.getenv('PYTEST_PORT', 3306))
|
||||
CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
|
||||
SSH_USER = os.getenv('PYTEST_SSH_USER', None)
|
||||
SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
|
||||
SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
|
||||
PASSWORD = os.getenv("PYTEST_PASSWORD")
|
||||
USER = os.getenv("PYTEST_USER", "root")
|
||||
HOST = os.getenv("PYTEST_HOST", "localhost")
|
||||
PORT = int(os.getenv("PYTEST_PORT", 3306))
|
||||
CHARSET = os.getenv("PYTEST_CHARSET", "utf8")
|
||||
SSH_USER = os.getenv("PYTEST_SSH_USER", None)
|
||||
SSH_HOST = os.getenv("PYTEST_SSH_HOST", None)
|
||||
SSH_PORT = os.getenv("PYTEST_SSH_PORT", 22)
|
||||
|
||||
|
||||
def db_connection(dbname=None):
|
||||
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
|
||||
password=PASSWORD, charset=CHARSET,
|
||||
local_infile=False)
|
||||
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARSET, local_infile=False)
|
||||
conn.autocommit = True
|
||||
return conn
|
||||
|
||||
|
@ -30,20 +28,18 @@ def db_connection(dbname=None):
|
|||
try:
|
||||
db_connection()
|
||||
CAN_CONNECT_TO_DB = True
|
||||
except:
|
||||
except Exception:
|
||||
CAN_CONNECT_TO_DB = False
|
||||
|
||||
dbtest = pytest.mark.skipif(
|
||||
not CAN_CONNECT_TO_DB,
|
||||
reason="Need a mysql instance at localhost accessible by user 'root'")
|
||||
dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Need a mysql instance at localhost accessible by user 'root'")
|
||||
|
||||
|
||||
def create_db(dbname):
|
||||
with db_connection().cursor() as cur:
|
||||
try:
|
||||
cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''')
|
||||
cur.execute('''CREATE DATABASE mycli_test_db''')
|
||||
except:
|
||||
cur.execute("""DROP DATABASE IF EXISTS mycli_test_db""")
|
||||
cur.execute("""CREATE DATABASE mycli_test_db""")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -53,8 +49,7 @@ def run(executor, sql, rows_as_list=True):
|
|||
|
||||
for title, rows, headers, status in executor.run(sql):
|
||||
rows = list(rows) if (rows_as_list and rows) else rows
|
||||
result.append({'title': title, 'rows': rows, 'headers': headers,
|
||||
'status': status})
|
||||
result.append({"title": title, "rows": rows, "headers": headers, "status": status})
|
||||
|
||||
return result
|
||||
|
||||
|
@ -87,8 +82,6 @@ def send_ctrl_c(wait_seconds):
|
|||
Returns the `multiprocessing.Process` created.
|
||||
|
||||
"""
|
||||
ctrl_c_process = multiprocessing.Process(
|
||||
target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
|
||||
)
|
||||
ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds))
|
||||
ctrl_c_process.start()
|
||||
return ctrl_c_process
|
||||
|
|
20
tox.ini
20
tox.ini
|
@ -1,15 +1,21 @@
|
|||
[tox]
|
||||
envlist = py36, py37, py38
|
||||
envlist = py
|
||||
|
||||
[testenv]
|
||||
deps = pytest
|
||||
mock
|
||||
pexpect
|
||||
behave
|
||||
coverage
|
||||
commands = python setup.py test
|
||||
skip_install = true
|
||||
deps = uv
|
||||
passenv = PYTEST_HOST
|
||||
PYTEST_USER
|
||||
PYTEST_PASSWORD
|
||||
PYTEST_PORT
|
||||
PYTEST_CHARSET
|
||||
commands = uv pip install -e .[dev,ssh]
|
||||
coverage run -m pytest -v test
|
||||
coverage report -m
|
||||
behave test/features
|
||||
|
||||
[testenv:style]
|
||||
skip_install = true
|
||||
deps = ruff
|
||||
commands = ruff check --fix
|
||||
ruff format
|
||||
|
|
Loading…
Add table
Reference in a new issue