1
0
Fork 0

Adding upstream version 1.29.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 19:15:57 +01:00
parent 5bd6a68e8c
commit f9065f1bef
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
68 changed files with 3723 additions and 3336 deletions

View file

@ -1,13 +1,12 @@
import pytest
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
from .utils import HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT
import mycli.sqlexecute
@pytest.fixture(scope="function")
def connection():
create_db('mycli_test_db')
connection = db_connection('mycli_test_db')
create_db("mycli_test_db")
connection = db_connection("mycli_test_db")
yield connection
connection.close()
@ -22,8 +21,18 @@ def cursor(connection):
@pytest.fixture
def executor(connection):
return mycli.sqlexecute.SQLExecute(
database='mycli_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
database="mycli_test_db",
user=USER,
host=HOST,
password=PASSWORD,
port=PORT,
socket=None,
charset=CHARSET,
local_infile=False,
ssl=None,
ssh_user=SSH_USER,
ssh_host=SSH_HOST,
ssh_port=SSH_PORT,
ssh_password=None,
ssh_key_filename=None,
)

View file

@ -1,8 +1,7 @@
import pymysql
def create_db(hostname='localhost', port=3306, username=None,
password=None, dbname=None):
def create_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
"""Create test database.
:param hostname: string
@ -14,17 +13,12 @@ def create_db(hostname='localhost', port=3306, username=None,
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
cr.execute('drop database if exists ' + dbname)
cr.execute('create database ' + dbname)
cr.execute("drop database if exists " + dbname)
cr.execute("create database " + dbname)
cn.close()
@ -44,20 +38,13 @@ def create_cn(hostname, port, password, username, dbname):
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
return cn
def drop_db(hostname='localhost', port=3306, username=None,
password=None, dbname=None):
def drop_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
"""Drop database.
:param hostname: string
@ -68,17 +55,11 @@ def drop_db(hostname='localhost', port=3306, username=None,
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
cr.execute('drop database if exists ' + dbname)
cr.execute("drop database if exists " + dbname)
close_cn(cn)

View file

@ -9,96 +9,72 @@ import pexpect
from steps.wrappers import run_cli, wait_prompt
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log")
SELF_CONNECTING_FEATURES = (
'test/features/connection.feature',
)
SELF_CONNECTING_FEATURES = ("test/features/connection.feature",)
MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
MY_CNF_PATH = os.path.expanduser("~/.my.cnf")
MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup"
MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf")
MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup"
def get_db_name_from_context(context):
return context.config.userdata.get(
'my_test_db', None
) or "mycli_behave_tests"
return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests"
def before_all(context):
"""Set env parameters."""
os.environ['LINES'] = "100"
os.environ['COLUMNS'] = "100"
os.environ['EDITOR'] = 'ex'
os.environ['LC_ALL'] = 'en_US.UTF-8'
os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
os.environ['MYCLI_HISTFILE'] = os.devnull
os.environ["LINES"] = "100"
os.environ["COLUMNS"] = "100"
os.environ["EDITOR"] = "ex"
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
os.environ["MYCLI_HISTFILE"] = os.devnull
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
# test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
# login_path_file = os.path.join(test_dir, "mylogin.cnf")
# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
'.coveragerc')
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc")
context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]])
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = get_db_name_from_context(context)
db_name_full = '{0}_{1}'.format(db_name, vi)
db_name_full = "{0}_{1}".format(db_name, vi)
# Store get params from config/environment variables
context.conf = {
'host': context.config.userdata.get(
'my_test_host',
os.getenv('PYTEST_HOST', 'localhost')
),
'port': context.config.userdata.get(
'my_test_port',
int(os.getenv('PYTEST_PORT', '3306'))
),
'user': context.config.userdata.get(
'my_test_user',
os.getenv('PYTEST_USER', 'root')
),
'pass': context.config.userdata.get(
'my_test_pass',
os.getenv('PYTEST_PASSWORD', None)
),
'cli_command': context.config.userdata.get(
'my_cli_command', None) or
sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
'dbname': db_name,
'dbname_tmp': db_name_full + '_tmp',
'vi': vi,
'pager_boundary': '---boundary---',
"host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")),
"port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))),
"user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")),
"pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)),
"cli_command": context.config.userdata.get("my_cli_command", None)
or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
"dbname": db_name,
"dbname_tmp": db_name_full + "_tmp",
"vi": vi,
"pager_boundary": "---boundary---",
}
_, my_cnf = mkstemp()
with open(my_cnf, 'w') as f:
with open(my_cnf, "w") as f:
f.write(
'[client]\n'
'pager={0} {1} {2}\n'.format(
sys.executable, os.path.join(context.package_root,
'test/features/wrappager.py'),
context.conf['pager_boundary'])
"[client]\n" "pager={0} {1} {2}\n".format(
sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"]
)
)
context.conf['defaults-file'] = my_cnf
context.conf['myclirc'] = os.path.join(context.package_root, 'test',
'myclirc')
context.conf["defaults-file"] = my_cnf
context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc")
context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
context.conf['user'],
context.conf['pass'],
context.conf['dbname'])
context.cn = dbutils.create_db(
context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]
)
context.fixture_data = fixutils.read_fixture_files()
@ -106,12 +82,10 @@ def before_all(context):
def after_all(context):
"""Unset env parameters."""
dbutils.close_cn(context.cn)
dbutils.drop_db(context.conf['host'], context.conf['port'],
context.conf['user'], context.conf['pass'],
context.conf['dbname'])
dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"])
# Restore env vars.
#for k, v in context.pgenv.items():
# for k, v in context.pgenv.items():
# if k in os.environ and v is None:
# del os.environ[k]
# elif v:
@ -123,8 +97,8 @@ def before_step(context, _):
def before_scenario(context, arg):
with open(test_log_file, 'w') as f:
f.write('')
with open(test_log_file, "w") as f:
f.write("")
if arg.location.filename not in SELF_CONNECTING_FEATURES:
run_cli(context)
wait_prompt(context)
@ -140,23 +114,18 @@ def after_scenario(context, _):
"""Cleans up after each test complete."""
with open(test_log_file) as f:
for line in f:
if 'error' in line.lower():
raise RuntimeError(f'Error in log file: {line}')
if "error" in line.lower():
raise RuntimeError(f"Error in log file: {line}")
if hasattr(context, 'cli') and not context.exit_sent:
if hasattr(context, "cli") and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
user = context.conf['user']
host = context.conf['host']
user = context.conf["user"]
host = context.conf["host"]
dbname = context.currentdb
context.cli.expect_exact(
'{0}@{1}:{2}>'.format(
user, host, dbname
),
timeout=5
)
context.cli.sendcontrol('c')
context.cli.sendcontrol('d')
context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
context.cli.expect_exact(pexpect.EOF, timeout=5)
if os.path.exists(MY_CNF_BACKUP_PATH):

View file

@ -1,5 +1,4 @@
import os
import io
def read_fixture_lines(filename):
@ -20,9 +19,9 @@ def read_fixture_files():
fixture_dict = {}
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, 'fixture_data/')
fixture_dir = os.path.join(current_dir, "fixture_data/")
for filename in os.listdir(fixture_dir):
if filename not in ['.', '..']:
if filename not in [".", ".."]:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)

View file

@ -6,41 +6,42 @@ import wrappers
from utils import parse_cli_args_to_dict
@when('we run dbcli with {arg}')
@when("we run dbcli with {arg}")
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
@when('we execute a small query')
@when("we execute a small query")
def step_execute_small_query(context):
context.cli.sendline('select 1')
context.cli.sendline("select 1")
@when('we execute a large query')
@when("we execute a large query")
def step_execute_large_query(context):
context.cli.sendline(
'select {}'.format(','.join([str(n) for n in range(1, 50)])))
context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
@then('we see small results in horizontal format')
@then("we see small results in horizontal format")
def step_see_small_results(context):
wrappers.expect_pager(context, dedent("""\
wrappers.expect_pager(
context,
dedent("""\
+---+\r
| 1 |\r
+---+\r
| 1 |\r
+---+\r
\r
"""), timeout=5)
wrappers.expect_exact(context, '1 row in set', timeout=2)
"""),
timeout=5,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)
@then('we see large results in vertical format')
@then("we see large results in vertical format")
def step_see_large_results(context):
rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
expected = ('***************************[ 1. row ]'
'***************************\r\n' +
'{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
rows = ["{n:3}| {n}".format(n=str(n)) for n in range(1, 50)]
expected = "***************************[ 1. row ]" "***************************\r\n" + "{}\r\n".format("\r\n".join(rows) + "\r\n")
wrappers.expect_pager(context, expected, timeout=10)
wrappers.expect_exact(context, '1 row in set', timeout=2)
wrappers.expect_exact(context, "1 row in set", timeout=2)

View file

@ -5,18 +5,18 @@ to call the step in "*.feature" file.
"""
from behave import when
from behave import when, then
from textwrap import dedent
import tempfile
import wrappers
@when('we run dbcli')
@when("we run dbcli")
def step_run_cli(context):
wrappers.run_cli(context)
@when('we wait for prompt')
@when("we wait for prompt")
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@ -24,77 +24,75 @@ def step_wait_prompt(context):
@when('we send "ctrl + d"')
def step_ctrl_d(context):
"""Send Ctrl + D to hopefully exit."""
context.cli.sendcontrol('d')
context.cli.sendcontrol("d")
context.exit_sent = True
@when('we send "\?" command')
@when(r'we send "\?" command')
def step_send_help(context):
"""Send \?
r"""Send \?
to see help.
"""
context.cli.sendline('\\?')
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
context.cli.sendline("\\?")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when(u'we send source command')
@when("we send source command")
def step_send_source_command(context):
with tempfile.NamedTemporaryFile() as f:
f.write(b'\?')
f.write(b"\\?")
f.flush()
context.cli.sendline('\. {0}'.format(f.name))
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
context.cli.sendline("\\. {0}".format(f.name))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when(u'we run query to check application_name')
@when("we run query to check application_name")
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
)
@then(u'we see found')
@then("we see found")
def step_see_found(context):
wrappers.expect_exact(
context,
context.conf['pager_boundary'] + '\r' + dedent('''
context.conf["pager_boundary"]
+ "\r"
+ dedent("""
+-------+\r
| found |\r
+-------+\r
| found |\r
+-------+\r
\r
''') + context.conf['pager_boundary'],
timeout=5
""")
+ context.conf["pager_boundary"],
timeout=5,
)
@then(u'we confirm the destructive warning')
def step_confirm_destructive_command(context):
@then("we confirm the destructive warning")
def step_confirm_destructive_command(context): # noqa
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline('y')
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline("y")
@when(u'we answer the destructive warning with "{confirmation}"')
def step_confirm_destructive_command(context, confirmation):
@when('we answer the destructive warning with "{confirmation}"')
def step_confirm_destructive_command(context, confirmation): # noqa
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline(confirmation)
@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
def step_confirm_destructive_command(context, confirmation, text):
@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
def step_confirm_destructive_command(context, confirmation, text): # noqa
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline(confirmation)
wrappers.expect_exact(context, text, timeout=2)
# we must exit the Click loop, or the feature will hang
context.cli.sendline('n')
context.cli.sendline("n")

View file

@ -1,9 +1,7 @@
import io
import os
import shlex
from behave import when, then
import pexpect
import wrappers
from test.features.steps.utils import parse_cli_args_to_dict
@ -12,60 +10,44 @@ from test.utils import HOST, PORT, USER, PASSWORD
from mycli.config import encrypt_mylogin_cnf
TEST_LOGIN_PATH = 'test_login_path'
TEST_LOGIN_PATH = "test_login_path"
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
@when('we run mycli without arguments "{excluded_args}"')
def step_run_cli_without_args(context, excluded_args, exact_args=''):
wrappers.run_cli(
context,
run_args=parse_cli_args_to_dict(exact_args),
exclude_args=parse_cli_args_to_dict(excluded_args).keys()
)
def step_run_cli_without_args(context, excluded_args, exact_args=""):
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys())
@then('status contains "{expression}"')
def status_contains(context, expression):
wrappers.expect_exact(context, f'{expression}', timeout=5)
wrappers.expect_exact(context, f"{expression}", timeout=5)
# Normally, the shutdown after scenario waits for the prompt.
# But we may have changed the prompt, depending on parameters,
# so let's wait for its last character
context.cli.expect_exact('>')
context.cli.expect_exact(">")
context.atprompt = True
@when('we create my.cnf file')
@when("we create my.cnf file")
def step_create_my_cnf_file(context):
my_cnf = (
'[client]\n'
f'host = {HOST}\n'
f'port = {PORT}\n'
f'user = {USER}\n'
f'password = {PASSWORD}\n'
)
with open(MY_CNF_PATH, 'w') as f:
my_cnf = "[client]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
with open(MY_CNF_PATH, "w") as f:
f.write(my_cnf)
@when('we create mylogin.cnf file')
@when("we create mylogin.cnf file")
def step_create_mylogin_cnf_file(context):
os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
mylogin_cnf = (
f'[{TEST_LOGIN_PATH}]\n'
f'host = {HOST}\n'
f'port = {PORT}\n'
f'user = {USER}\n'
f'password = {PASSWORD}\n'
)
with open(MYLOGIN_CNF_PATH, 'wb') as f:
os.environ.pop("MYSQL_TEST_LOGIN_FILE", None)
mylogin_cnf = f"[{TEST_LOGIN_PATH}]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
with open(MYLOGIN_CNF_PATH, "wb") as f:
input_file = io.StringIO(mylogin_cnf)
f.write(encrypt_mylogin_cnf(input_file).read())
@then('we are logged in')
@then("we are logged in")
def we_are_logged_in(context):
db_name = get_db_name_from_context(context)
context.cli.expect_exact(f'{db_name}>', timeout=5)
context.cli.expect_exact(f"{db_name}>", timeout=5)
context.atprompt = True

View file

@ -11,105 +11,99 @@ import wrappers
from behave import when, then
@when('we create database')
@when("we create database")
def step_db_create(context):
"""Send create database."""
context.cli.sendline('create database {0};'.format(
context.conf['dbname_tmp']))
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.response = {
'database_name': context.conf['dbname_tmp']
}
context.response = {"database_name": context.conf["dbname_tmp"]}
@when('we drop database')
@when("we drop database")
def step_db_drop(context):
"""Send drop database."""
context.cli.sendline('drop database {0};'.format(
context.conf['dbname_tmp']))
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
@when('we connect to test database')
@when("we connect to test database")
def step_db_connect_test(context):
"""Send connect to database."""
db_name = context.conf['dbname']
db_name = context.conf["dbname"]
context.currentdb = db_name
context.cli.sendline('use {0};'.format(db_name))
context.cli.sendline("use {0};".format(db_name))
@when('we connect to quoted test database')
@when("we connect to quoted test database")
def step_db_connect_quoted_tmp(context):
"""Send connect to database."""
db_name = context.conf['dbname']
db_name = context.conf["dbname"]
context.currentdb = db_name
context.cli.sendline('use `{0}`;'.format(db_name))
context.cli.sendline("use `{0}`;".format(db_name))
@when('we connect to tmp database')
@when("we connect to tmp database")
def step_db_connect_tmp(context):
"""Send connect to database."""
db_name = context.conf['dbname_tmp']
db_name = context.conf["dbname_tmp"]
context.currentdb = db_name
context.cli.sendline('use {0}'.format(db_name))
context.cli.sendline("use {0}".format(db_name))
@when('we connect to dbserver')
@when("we connect to dbserver")
def step_db_connect_dbserver(context):
"""Send connect to database."""
context.currentdb = 'mysql'
context.cli.sendline('use mysql')
context.currentdb = "mysql"
context.cli.sendline("use mysql")
@then('dbcli exits')
@then("dbcli exits")
def step_wait_exit(context):
"""Make sure the cli exits."""
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
@then('we see dbcli prompt')
@then("we see dbcli prompt")
def step_see_prompt(context):
"""Wait to see the prompt."""
user = context.conf['user']
host = context.conf['host']
user = context.conf["user"]
host = context.conf["host"]
dbname = context.currentdb
wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
wrappers.wait_prompt(context, "{0}@{1}:{2}> ".format(user, host, dbname))
@then('we see help output')
@then("we see help output")
def step_see_help(context):
for expected_line in context.fixture_data['help_commands.txt']:
for expected_line in context.fixture_data["help_commands.txt"]:
wrappers.expect_exact(context, expected_line, timeout=1)
@then('we see database created')
@then("we see database created")
def step_see_db_created(context):
"""Wait to see create database output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see database dropped')
@then("we see database dropped")
def step_see_db_dropped(context):
"""Wait to see drop database output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
@then('we see database dropped and no default database')
@then("we see database dropped and no default database")
def step_see_db_dropped_no_default(context):
"""Wait to see drop database output."""
user = context.conf['user']
host = context.conf['host']
database = '(none)'
user = context.conf["user"]
host = context.conf["host"]
database = "(none)"
context.currentdb = None
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
wrappers.wait_prompt(context, "{0}@{1}:{2}>".format(user, host, database))
@then('we see database connected')
@then("we see database connected")
def step_see_db_connected(context):
"""Wait to see drop database output."""
wrappers.expect_exact(
context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, '"', timeout=2)
wrappers.expect_exact(context, ' as user "{0}"'.format(
context.conf['user']), timeout=2)
wrappers.expect_exact(context, ' as user "{0}"'.format(context.conf["user"]), timeout=2)

View file

@ -10,103 +10,109 @@ from behave import when, then
from textwrap import dedent
@when('we create table')
@when("we create table")
def step_create_table(context):
"""Send create table."""
context.cli.sendline('create table a(x text);')
context.cli.sendline("create table a(x text);")
@when('we insert into table')
@when("we insert into table")
def step_insert_into_table(context):
"""Send insert into table."""
context.cli.sendline('''insert into a(x) values('xxx');''')
context.cli.sendline("""insert into a(x) values('xxx');""")
@when('we update table')
@when("we update table")
def step_update_table(context):
"""Send insert into table."""
context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
@when('we select from table')
@when("we select from table")
def step_select_from_table(context):
"""Send select from table."""
context.cli.sendline('select * from a;')
context.cli.sendline("select * from a;")
@when('we delete from table')
@when("we delete from table")
def step_delete_from_table(context):
"""Send deete from table."""
context.cli.sendline('''delete from a where x = 'yyy';''')
context.cli.sendline("""delete from a where x = 'yyy';""")
@when('we drop table')
@when("we drop table")
def step_drop_table(context):
"""Send drop table."""
context.cli.sendline('drop table a;')
context.cli.sendline("drop table a;")
@then('we see table created')
@then("we see table created")
def step_see_table_created(context):
"""Wait to see create table output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
@then('we see record inserted')
@then("we see record inserted")
def step_see_record_inserted(context):
"""Wait to see insert output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see record updated')
@then("we see record updated")
def step_see_record_updated(context):
"""Wait to see update output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see data selected')
@then("we see data selected")
def step_see_data_selected(context):
"""Wait to see select output."""
wrappers.expect_pager(
context, dedent("""\
context,
dedent("""\
+-----+\r
| x |\r
+-----+\r
| yyy |\r
+-----+\r
\r
"""), timeout=2)
wrappers.expect_exact(context, '1 row in set', timeout=2)
"""),
timeout=2,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)
@then('we see record deleted')
@then("we see record deleted")
def step_see_data_deleted(context):
"""Wait to see delete output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
@then('we see table dropped')
@then("we see table dropped")
def step_see_table_dropped(context):
"""Wait to see drop output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
@when('we select null')
@when("we select null")
def step_select_null(context):
"""Send select null."""
context.cli.sendline('select null;')
context.cli.sendline("select null;")
@then('we see null selected')
@then("we see null selected")
def step_see_null_selected(context):
"""Wait to see null output."""
wrappers.expect_pager(
context, dedent("""\
context,
dedent("""\
+--------+\r
| NULL |\r
+--------+\r
| <null> |\r
+--------+\r
\r
"""), timeout=2)
wrappers.expect_exact(context, '1 row in set', timeout=2)
"""),
timeout=2,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)

View file

@ -5,101 +5,93 @@ from behave import when, then
from textwrap import dedent
@when('we start external editor providing a file name')
@when("we start external editor providing a file name")
def step_edit_file(context):
"""Edit file with external editor."""
context.editor_file_name = os.path.join(
context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"]))
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline('\e {0}'.format(
os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
wrappers.expect_exact(context, '\r\n:', timeout=2)
context.cli.sendline("\\e {0}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
wrappers.expect_exact(context, "\r\n:", timeout=2)
@when('we type "{query}" in the editor')
def step_edit_type_sql(context, query):
context.cli.sendline('i')
context.cli.sendline("i")
context.cli.sendline(query)
context.cli.sendline('.')
wrappers.expect_exact(context, '\r\n:', timeout=2)
context.cli.sendline(".")
wrappers.expect_exact(context, "\r\n:", timeout=2)
@when('we exit the editor')
@when("we exit the editor")
def step_edit_quit(context):
context.cli.sendline('x')
context.cli.sendline("x")
wrappers.expect_exact(context, "written", timeout=2)
@then('we see "{query}" in prompt')
def step_edit_done_sql(context, query):
for match in query.split(' '):
for match in query.split(" "):
wrappers.expect_exact(context, match, timeout=5)
# Cleanup the command line.
context.cli.sendcontrol('c')
context.cli.sendcontrol("c")
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
@when(u'we tee output')
@when("we tee output")
def step_tee_ouptut(context):
context.tee_file_name = os.path.join(
context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]))
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline('tee {0}'.format(
os.path.basename(context.tee_file_name)))
context.cli.sendline("tee {0}".format(os.path.basename(context.tee_file_name)))
@when(u'we select "select {param}"')
@when('we select "select {param}"')
def step_query_select_number(context, param):
context.cli.sendline(u'select {}'.format(param))
wrappers.expect_pager(context, dedent(u"""\
context.cli.sendline("select {}".format(param))
wrappers.expect_pager(
context,
dedent(
"""\
+{dashes}+\r
| {param} |\r
+{dashes}+\r
| {param} |\r
+{dashes}+\r
\r
""".format(param=param, dashes='-' * (len(param) + 2))
), timeout=5)
wrappers.expect_exact(context, '1 row in set', timeout=2)
@then(u'we see result "{result}"')
def step_see_result(context, result):
wrappers.expect_exact(
context,
u"| {} |".format(result),
timeout=2
""".format(param=param, dashes="-" * (len(param) + 2))
),
timeout=5,
)
wrappers.expect_exact(context, "1 row in set", timeout=2)
@when(u'we query "{query}"')
@then('we see result "{result}"')
def step_see_result(context, result):
wrappers.expect_exact(context, "| {} |".format(result), timeout=2)
@when('we query "{query}"')
def step_query(context, query):
context.cli.sendline(query)
@when(u'we notee output')
@when("we notee output")
def step_notee_output(context):
context.cli.sendline('notee')
context.cli.sendline("notee")
@then(u'we see 123456 in tee output')
@then("we see 123456 in tee output")
def step_see_123456_in_ouput(context):
with open(context.tee_file_name) as f:
assert '123456' in f.read()
assert "123456" in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
@then(u'delimiter is set to "{delimiter}"')
@then('delimiter is set to "{delimiter}"')
def delimiter_is_set(context, delimiter):
wrappers.expect_exact(
context,
u'Changed delimiter to {}'.format(delimiter),
timeout=2
)
wrappers.expect_exact(context, "Changed delimiter to {}".format(delimiter), timeout=2)

View file

@ -9,82 +9,79 @@ import wrappers
from behave import when, then
@when('we save a named query')
@when("we save a named query")
def step_save_named_query(context):
"""Send \fs command."""
context.cli.sendline('\\fs foo SELECT 12345')
context.cli.sendline("\\fs foo SELECT 12345")
@when('we use a named query')
@when("we use a named query")
def step_use_named_query(context):
"""Send \f command."""
context.cli.sendline('\\f foo')
context.cli.sendline("\\f foo")
@when('we delete a named query')
@when("we delete a named query")
def step_delete_named_query(context):
"""Send \fd command."""
context.cli.sendline('\\fd foo')
context.cli.sendline("\\fd foo")
@then('we see the named query saved')
@then("we see the named query saved")
def step_see_named_query_saved(context):
"""Wait to see query saved."""
wrappers.expect_exact(context, 'Saved.', timeout=2)
wrappers.expect_exact(context, "Saved.", timeout=2)
@then('we see the named query executed')
@then("we see the named query executed")
def step_see_named_query_executed(context):
"""Wait to see select output."""
wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
wrappers.expect_exact(context, "SELECT 12345", timeout=2)
@then('we see the named query deleted')
@then("we see the named query deleted")
def step_see_named_query_deleted(context):
"""Wait to see query deleted."""
wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
wrappers.expect_exact(context, "foo: Deleted", timeout=2)
@when('we save a named query with parameters')
@when("we save a named query with parameters")
def step_save_named_query_with_parameters(context):
"""Send \fs command for query with parameters."""
context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
@when('we use named query with parameters')
@when("we use named query with parameters")
def step_use_named_query_with_parameters(context):
"""Send \f command with parameters."""
context.cli.sendline('\\f foo_args 101 second "third value"')
@then('we see the named query with parameters executed')
@then("we see the named query with parameters executed")
def step_see_named_query_with_parameters_executed(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'SELECT 101, "second", "third value"', timeout=2)
wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2)
@when('we use named query with too few parameters')
@when("we use named query with too few parameters")
def step_use_named_query_with_too_few_parameters(context):
"""Send \f command with missing parameters."""
context.cli.sendline('\\f foo_args 101')
context.cli.sendline("\\f foo_args 101")
@then('we see the named query with parameters fail with missing parameters')
@then("we see the named query with parameters fail with missing parameters")
def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'missing substitution for $2 in query:', timeout=2)
wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2)
@when('we use named query with too many parameters')
@when("we use named query with too many parameters")
def step_use_named_query_with_too_many_parameters(context):
"""Send \f command with extra parameters."""
context.cli.sendline('\\f foo_args 101 102 103 104')
context.cli.sendline("\\f foo_args 101 102 103 104")
@then('we see the named query with parameters fail with extra parameters')
@then("we see the named query with parameters fail with extra parameters")
def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'query does not have substitution parameter $4:', timeout=2)
wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2)

View file

@ -9,10 +9,10 @@ import wrappers
from behave import when, then
@when('we refresh completions')
@when("we refresh completions")
def step_refresh_completions(context):
"""Send refresh command."""
context.cli.sendline('rehash')
context.cli.sendline("rehash")
@then('we see text "{text}"')
@ -20,8 +20,8 @@ def step_see_text(context, text):
"""Wait to see given text message."""
wrappers.expect_exact(context, text, timeout=2)
@then('we see completions refresh started')
@then("we see completions refresh started")
def step_see_refresh_started(context):
"""Wait to see refresh output."""
wrappers.expect_exact(
context, 'Auto-completion refresh started in the background.', timeout=2)
wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2)

View file

@ -4,8 +4,8 @@ import shlex
def parse_cli_args_to_dict(cli_args: str):
args_dict = {}
for arg in shlex.split(cli_args):
if '=' in arg:
key, value = arg.split('=')
if "=" in arg:
key, value = arg.split("=")
args_dict[key] = value
else:
args_dict[arg] = None

View file

@ -18,10 +18,9 @@ def expect_exact(context, expected, timeout):
timedout = True
if timedout:
# Strip color codes out of the output.
actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
'', context.cli.before)
actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
raise Exception(
textwrap.dedent('''\
textwrap.dedent("""\
Expected:
---
{0!r}
@ -34,17 +33,12 @@ def expect_exact(context, expected, timeout):
---
{2!r}
---
''').format(
expected,
actual,
context.logfile.getvalue()
)
""").format(expected, actual, context.logfile.getvalue())
)
def expect_pager(context, expected, timeout):
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
context.conf['pager_boundary'], expected), timeout=timeout)
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), timeout=timeout)
def run_cli(context, run_args=None, exclude_args=None):
@ -63,55 +57,49 @@ def run_cli(context, run_args=None, exclude_args=None):
else:
rendered_args.append(key)
if conf.get('host', None):
add_arg('host', '-h', conf['host'])
if conf.get('user', None):
add_arg('user', '-u', conf['user'])
if conf.get('pass', None):
add_arg('pass', '-p', conf['pass'])
if conf.get('port', None):
add_arg('port', '-P', str(conf['port']))
if conf.get('dbname', None):
add_arg('dbname', '-D', conf['dbname'])
if conf.get('defaults-file', None):
add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
if conf.get('myclirc', None):
add_arg('myclirc', '--myclirc', conf['myclirc'])
if conf.get('login_path'):
add_arg('login_path', '--login-path', conf['login_path'])
if conf.get("host", None):
add_arg("host", "-h", conf["host"])
if conf.get("user", None):
add_arg("user", "-u", conf["user"])
if conf.get("pass", None):
add_arg("pass", "-p", conf["pass"])
if conf.get("port", None):
add_arg("port", "-P", str(conf["port"]))
if conf.get("dbname", None):
add_arg("dbname", "-D", conf["dbname"])
if conf.get("defaults-file", None):
add_arg("defaults_file", "--defaults-file", conf["defaults-file"])
if conf.get("myclirc", None):
add_arg("myclirc", "--myclirc", conf["myclirc"])
if conf.get("login_path"):
add_arg("login_path", "--login-path", conf["login_path"])
for arg_name, arg_value in conf.items():
if arg_name.startswith('-'):
if arg_name.startswith("-"):
add_arg(arg_name, arg_name, arg_value)
try:
cli_cmd = context.conf['cli_command']
cli_cmd = context.conf["cli_command"]
except KeyError:
cli_cmd = (
'{0!s} -c "'
'import coverage ; '
'coverage.process_startup(); '
'import mycli.main; '
'mycli.main.cli()'
'"'
).format(sys.executable)
cli_cmd = ('{0!s} -c "' "import coverage ; " "coverage.process_startup(); " "import mycli.main; " "mycli.main.cli()" '"').format(
sys.executable
)
cmd_parts = [cli_cmd] + rendered_args
cmd = ' '.join(cmd_parts)
cmd = " ".join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = context.conf['dbname']
context.currentdb = context.conf["dbname"]
def wait_prompt(context, prompt=None):
"""Make sure prompt is displayed."""
if prompt is None:
user = context.conf['user']
host = context.conf['host']
user = context.conf["user"]
host = context.conf["host"]
dbname = context.currentdb
prompt = '{0}@{1}:{2}>'.format(
user, host, dbname),
prompt = ("{0}@{1}:{2}>".format(user, host, dbname),)
expect_exact(context, prompt, timeout=5)
context.atprompt = True

View file

@ -153,6 +153,7 @@ output.null = "#808080"
# Favorite queries.
[favorite_queries]
check = 'select "✔"'
foo_args = 'SELECT $1, "$2", "$3"'
# Use the -d option to reference a DSN.
# Special characters in passwords and other strings can be escaped with URL encoding.

View file

@ -1,4 +1,5 @@
"""Test the mycli.clistyle module."""
import pytest
from pygments.style import Style
@ -10,9 +11,9 @@ from mycli.clistyle import style_factory
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory():
"""Test that a Pygments Style class is created."""
header = 'bold underline #ansired'
cli_style = {'Token.Output.Header': header}
style = style_factory('default', cli_style)
header = "bold underline #ansired"
cli_style = {"Token.Output.Header": header}
style = style_factory("default", cli_style)
assert isinstance(style(), Style)
assert Token.Output.Header in style.styles
@ -22,6 +23,6 @@ def test_style_factory():
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory_unknown_name():
"""Test that an unrecognized name will not throw an error."""
style = style_factory('foobar', {})
style = style_factory("foobar", {})
assert isinstance(style(), Style)

View file

@ -8,494 +8,528 @@ def sorted_dicts(dicts):
def test_select_suggests_cols_with_visible_table_scope():
suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_select_suggests_cols_with_qualified_table_scope():
suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [('sch', 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [("sch", "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE (',
'SELECT * FROM tabl WHERE foo = ',
'SELECT * FROM tabl WHERE bar OR ',
'SELECT * FROM tabl WHERE foo = 1 AND ',
'SELECT * FROM tabl WHERE (bar > 10 AND ',
'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
'SELECT * FROM tabl WHERE 10 < ',
'SELECT * FROM tabl WHERE foo BETWEEN ',
'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM tabl WHERE ",
"SELECT * FROM tabl WHERE (",
"SELECT * FROM tabl WHERE foo = ",
"SELECT * FROM tabl WHERE bar OR ",
"SELECT * FROM tabl WHERE foo = 1 AND ",
"SELECT * FROM tabl WHERE (bar > 10 AND ",
"SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
"SELECT * FROM tabl WHERE 10 < ",
"SELECT * FROM tabl WHERE foo BETWEEN ",
"SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
],
)
def test_where_suggests_columns_functions(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE foo IN (',
'SELECT * FROM tabl WHERE foo IN (bar, ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM tabl WHERE foo IN (",
"SELECT * FROM tabl WHERE foo IN (bar, ",
],
)
def test_where_in_suggests_columns(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_where_equals_any_suggests_columns_or_keywords():
text = 'SELECT * FROM tabl WHERE foo = ANY('
text = "SELECT * FROM tabl WHERE foo = ANY("
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_lparen_suggests_cols():
suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_operand_inside_function_suggests_cols1():
suggestion = suggest_type(
'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ")
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_operand_inside_function_suggests_cols2():
suggestion = suggest_type(
'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ")
assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type('SELECT ', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': []},
{'type': 'column', 'tables': []},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT ", "SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": []},
{"type": "column", "tables": []},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM ',
'INSERT INTO ',
'COPY ',
'UPDATE ',
'DESCRIBE ',
'DESC ',
'EXPLAIN ',
'SELECT * FROM foo JOIN ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM ",
"INSERT INTO ",
"COPY ",
"UPDATE ",
"DESCRIBE ",
"DESC ",
"EXPLAIN ",
"SELECT * FROM foo JOIN ",
],
)
def test_expression_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
@pytest.mark.parametrize('expression', [
'SELECT * FROM sch.',
'INSERT INTO sch.',
'COPY sch.',
'UPDATE sch.',
'DESCRIBE sch.',
'DESC sch.',
'EXPLAIN sch.',
'SELECT * FROM foo JOIN sch.',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM sch.",
"INSERT INTO sch.",
"COPY sch.",
"UPDATE sch.",
"DESCRIBE sch.",
"DESC sch.",
"EXPLAIN sch.",
"SELECT * FROM foo JOIN sch.",
],
)
def test_expression_suggests_qualified_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 'sch'},
{'type': 'view', 'schema': 'sch'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}])
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "schema"}])
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 'sch'}])
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}])
def test_distinct_suggests_cols():
suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
assert suggestions == [{'type': 'column', 'tables': []}]
suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ")
assert suggestions == [{"type": "column", "tables": []}]
def test_col_comma_suggests_cols():
suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tbl']},
{'type': 'column', 'tables': [(None, 'tbl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tbl"]},
{"type": "column", "tables": [(None, "tbl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type('SELECT a, b FROM tbl1, ',
'SELECT a, b FROM tbl1, ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_insert_into_lparen_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (")
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_insert_into_lparen_partial_text_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_insert_into_lparen_comma_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_partially_typed_col_name_suggests_col_names():
suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
'SELECT * FROM tabl WHERE col_n')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'table', 'schema': 'tabl'},
{'type': 'view', 'schema': 'tabl'},
{'type': 'function', 'schema': 'tabl'}])
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "table", "schema": "tabl"},
{"type": "view", "schema": "tabl"},
{"type": "function", "schema": "tabl"},
]
)
def test_dot_suggests_cols_of_an_alias():
suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
'SELECT t1.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 't1'},
{'type': 'view', 'schema': 't1'},
{'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
{'type': 'function', 'schema': 't1'}])
suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "table", "schema": "t1"},
{"type": "view", "schema": "t1"},
{"type": "column", "tables": [(None, "tabl1", "t1")]},
{"type": "function", "schema": "t1"},
]
)
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
'SELECT t1.a, t2.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
{'type': 'table', 'schema': 't2'},
{'type': 'view', 'schema': 't2'},
{'type': 'function', 'schema': 't2'}])
suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "tabl2", "t2")]},
{"type": "table", "schema": "t2"},
{"type": "view", "schema": "t2"},
{"type": "function", "schema": "t2"},
]
)
@pytest.mark.parametrize('expression', [
'SELECT * FROM (',
'SELECT * FROM foo WHERE EXISTS (',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
'SELECT 1 AS',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (",
"SELECT * FROM foo WHERE EXISTS (",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (",
"SELECT 1 AS",
],
)
def test_sub_select_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == [{'type': 'keyword'}]
assert suggestion == [{"type": "keyword"}]
@pytest.mark.parametrize('expression', [
'SELECT * FROM (S',
'SELECT * FROM foo WHERE EXISTS (S',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (S",
"SELECT * FROM foo WHERE EXISTS (S",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
],
)
def test_sub_select_partial_text_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == [{'type': 'keyword'}]
assert suggestion == [{"type": "keyword"}]
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
assert suggestions == [
{'type': 'column', 'tables': [(None, 'foo', 'f')]},
{'type': 'table', 'schema': 'f'},
{'type': 'view', 'schema': 'f'},
{'type': 'function', 'schema': 'f'}]
{"type": "column", "tables": [(None, "foo", "f")]},
{"type": "table", "schema": "f"},
{"type": "view", "schema": "f"},
{"type": "function", "schema": "f"},
]
@pytest.mark.parametrize('expression', [
'SELECT * FROM (SELECT * FROM ',
'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (SELECT * FROM ",
"SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
],
)
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_sub_select_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
'SELECT * FROM (SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['abc']},
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["abc"]},
{"type": "column", "tables": [(None, "abc", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
@pytest.mark.xfail
def test_sub_select_multiple_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
'SELECT * FROM (SELECT a, ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function', 'schema': []}])
suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ")
assert sorted_dicts(suggestions) == sorted_dicts(
[{"type": "column", "tables": [(None, "abc", None)]}, {"type": "function", "schema": []}]
)
def test_sub_select_dot_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
'SELECT * FROM (SELECT t.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', 't')]},
{'type': 'table', 'schema': 't'},
{'type': 'view', 'schema': 't'},
{'type': 'function', 'schema': 't'}])
suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "tabl", "t")]},
{"type": "table", "schema": "t"},
{"type": "view", "schema": "t"},
{"type": "function", "schema": "t"},
]
)
@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
@pytest.mark.parametrize("tbl_alias", ["", "foo"])
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
suggestion = suggest_type(text, text)
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
])
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
],
)
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', 'a')]},
{'type': 'table', 'schema': 'a'},
{'type': 'view', 'schema': 'a'},
{'type': 'function', 'schema': 'a'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "abc", "a")]},
{"type": "table", "schema": "a"},
{"type": "view", "schema": "a"},
{"type": "function", "schema": "a"},
]
)
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.id = d.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
])
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.id = d.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
],
)
def test_join_alias_dot_suggests_cols2(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'def', 'd')]},
{'type': 'table', 'schema': 'd'},
{'type': 'view', 'schema': 'd'},
{'type': 'function', 'schema': 'd'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "def", "d")]},
{"type": "table", "schema": "d"},
{"type": "view", "schema": "d"},
{"type": "function", "schema": "d"},
]
)
@pytest.mark.parametrize('sql', [
'select a.x, b.y from abc a join bcd b on ',
'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
])
@pytest.mark.parametrize(
"sql",
[
"select a.x, b.y from abc a join bcd b on ",
"select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
],
)
def test_on_suggests_aliases(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
@pytest.mark.parametrize('sql', [
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
])
@pytest.mark.parametrize(
"sql",
[
"select abc.x, bcd.y from abc join bcd on ",
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
],
)
def test_on_suggests_tables(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
@pytest.mark.parametrize('sql', [
'select a.x, b.y from abc a join bcd b on a.id = ',
'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
])
@pytest.mark.parametrize(
"sql",
[
"select a.x, b.y from abc a join bcd b on a.id = ",
"select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
],
)
def test_on_suggests_aliases_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
@pytest.mark.parametrize('sql', [
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
])
@pytest.mark.parametrize(
"sql",
[
"select abc.x, bcd.y from abc join bcd on ",
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
],
)
def test_on_suggests_tables_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
@pytest.mark.parametrize('col_list', ['', 'col1, '])
@pytest.mark.parametrize("col_list", ["", "col1, "])
def test_join_using_suggests_common_columns(col_list):
text = 'select * from abc inner join def using (' + col_list
assert suggest_type(text, text) == [
{'type': 'column',
'tables': [(None, 'abc', None), (None, 'def', None)],
'drop_unique': True}]
text = "select * from abc inner join def using (" + col_list
assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}]
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.',
])
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.",
],
)
def test_two_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'ghi', 'g')]},
{'type': 'table', 'schema': 'g'},
{'type': 'view', 'schema': 'g'},
{'type': 'function', 'schema': 'g'}])
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "column", "tables": [(None, "ghi", "g")]},
{"type": "table", "schema": "g"},
{"type": "view", "schema": "g"},
{"type": "function", "schema": "g"},
]
)
def test_2_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ',
'select * from a; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
suggestions = suggest_type('select * from a; select from b',
'select * from a; select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['b']},
{'type': 'column', 'tables': [(None, 'b', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("select * from a; select from b", "select * from a; select ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["b"]},
{"type": "column", "tables": [(None, "b", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
# Should work even if first statement is invalid
suggestions = suggest_type('select * from; select * from ',
'select * from; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from; select * from ", "select * from; select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_2_statements_1st_current():
suggestions = suggest_type('select * from ; select * from b',
'select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from ; select * from b", "select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
suggestions = suggest_type('select from a; select * from b',
'select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['a']},
{'type': 'column', 'tables': [(None, 'a', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("select from a; select * from b", "select ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["a"]},
{"type": "column", "tables": [(None, "a", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_3_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ; select * from c',
'select * from a; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
suggestions = suggest_type('select * from a; select from b; select * from c',
'select * from a; select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['b']},
{'type': 'column', 'tables': [(None, 'b', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ")
assert sorted_dicts(suggestions) == sorted_dicts(
[
{"type": "alias", "aliases": ["b"]},
{"type": "column", "tables": [(None, "b", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
]
)
def test_create_db_with_template():
suggestions = suggest_type('create database foo with template ',
'create database foo with template ')
suggestions = suggest_type("create database foo with template ", "create database foo with template ")
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"])
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
assert sorted_dicts(suggestions) == \
sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}])
def test_specials_not_included_after_initial_token():
suggestions = suggest_type('create table foo (dt d',
'create table foo (dt d')
suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}])
def test_drop_schema_qualified_table_suggests_only_tables():
text = 'DROP TABLE schema_name.table_name'
text = "DROP TABLE schema_name.table_name"
suggestions = suggest_type(text, text)
assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
assert suggestions == [{"type": "table", "schema": "schema_name"}]
@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
@pytest.mark.parametrize("text", [",", " ,", "sel ,"])
def test_handle_pre_completion_comma_gracefully(text):
suggestions = suggest_type(text, text)
@ -503,53 +537,59 @@ def test_handle_pre_completion_comma_gracefully(text):
def test_cross_join():
text = 'select * from v1 cross join v2 JOIN v1.id, '
text = "select * from v1 cross join v2 JOIN v1.id, "
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
@pytest.mark.parametrize('expression', [
'SELECT 1 AS ',
'SELECT 1 FROM tabl AS ',
])
@pytest.mark.parametrize(
"expression",
[
"SELECT 1 AS ",
"SELECT 1 FROM tabl AS ",
],
)
def test_after_as(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set()
@pytest.mark.parametrize('expression', [
'\\. ',
'select 1; \\. ',
'select 1;\\. ',
'select 1 ; \\. ',
'source ',
'truncate table test; source ',
'truncate table test ; source ',
'truncate table test;source ',
])
@pytest.mark.parametrize(
"expression",
[
"\\. ",
"select 1; \\. ",
"select 1;\\. ",
"select 1 ; \\. ",
"source ",
"truncate table test; source ",
"truncate table test ; source ",
"truncate table test;source ",
],
)
def test_source_is_file(expression):
suggestions = suggest_type(expression, expression)
assert suggestions == [{'type': 'file_name'}]
assert suggestions == [{"type": "file_name"}]
@pytest.mark.parametrize("expression", [
"\\f ",
])
@pytest.mark.parametrize(
"expression",
[
"\\f ",
],
)
def test_favorite_name_suggestion(expression):
suggestions = suggest_type(expression, expression)
assert suggestions == [{'type': 'favoritequery'}]
assert suggestions == [{"type": "favoritequery"}]
def test_order_by():
text = 'select * from foo order by '
text = "select * from foo order by "
suggestions = suggest_type(text, text)
assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]
assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}]
def test_quoted_where():
text = "'where i=';"
suggestions = suggest_type(text, text)
assert suggestions == [{'type': 'keyword'}]
assert suggestions == [{"type": "keyword"}]

View file

@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
@pytest.fixture
def refresher():
from mycli.completion_refresher import CompletionRefresher
return CompletionRefresher()
@ -18,8 +19,7 @@ def test_ctor(refresher):
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
'special_commands', 'show_commands', 'keywords']
expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"]
assert expected_handlers == actual_handlers
@ -32,12 +32,12 @@ def test_refresh_called_once(refresher):
callbacks = Mock()
sqlexecute = Mock()
with patch.object(refresher, '_bg_refresh') as bg_refresh:
with patch.object(refresher, "_bg_refresh") as bg_refresh:
actual = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
assert actual[0][3] == 'Auto-completion refresh started in the background.'
assert actual[0][3] == "Auto-completion refresh started in the background."
bg_refresh.assert_called_with(sqlexecute, callbacks, {})
@ -61,13 +61,13 @@ def test_refresh_called_twice(refresher):
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
assert actual1[0][3] == 'Auto-completion refresh started in the background.'
assert actual1[0][3] == "Auto-completion refresh started in the background."
actual2 = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
assert actual2[0][3] == 'Auto-completion refresh restarted.'
assert actual2[0][3] == "Auto-completion refresh restarted."
def test_refresh_with_callbacks(refresher):
@ -80,9 +80,9 @@ def test_refresh_with_callbacks(refresher):
sqlexecute_class = Mock()
sqlexecute = Mock()
with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert (callbacks[0].call_count == 1)
assert callbacks[0].call_count == 1

View file

@ -1,4 +1,5 @@
"""Unit tests for the mycli.config module."""
from io import BytesIO, StringIO, TextIOWrapper
import os
import struct
@ -6,21 +7,26 @@ import sys
import tempfile
import pytest
from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
read_and_decrypt_mylogin_cnf, read_config_file,
str_to_bool, strip_matching_quotes)
from mycli.config import (
get_mylogin_cnf_path,
open_mylogin_cnf,
read_and_decrypt_mylogin_cnf,
read_config_file,
str_to_bool,
strip_matching_quotes,
)
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
'mylogin.cnf'))
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "mylogin.cnf"))
def open_bmylogin_cnf(name):
"""Open contents of *name* in a BytesIO buffer."""
with open(name, 'rb') as f:
with open(name, "rb") as f:
buf = BytesIO()
buf.write(f.read())
return buf
def test_read_mylogin_cnf():
"""Tests that a login path file can be read and decrypted."""
mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
@ -28,7 +34,7 @@ def test_read_mylogin_cnf():
assert isinstance(mylogin_cnf, TextIOWrapper)
contents = mylogin_cnf.read()
for word in ('[test]', 'user', 'password', 'host', 'port'):
for word in ("[test]", "user", "password", "host", "port"):
assert word in contents
@ -46,7 +52,7 @@ def test_corrupted_login_key():
buf.seek(4)
# Write null bytes over half the login key
buf.write(b'\0\0\0\0\0\0\0\0\0\0')
buf.write(b"\0\0\0\0\0\0\0\0\0\0")
buf.seek(0)
mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
@ -63,58 +69,58 @@ def test_corrupted_pad():
# Skip option group
len_buf = buf.read(4)
cipher_len, = struct.unpack("<i", len_buf)
(cipher_len,) = struct.unpack("<i", len_buf)
buf.read(cipher_len)
# Corrupt the pad for the user line
len_buf = buf.read(4)
cipher_len, = struct.unpack("<i", len_buf)
(cipher_len,) = struct.unpack("<i", len_buf)
buf.read(cipher_len - 1)
buf.write(b'\0')
buf.write(b"\0")
buf.seek(0)
mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
contents = mylogin_cnf.read()
for word in ('[test]', 'password', 'host', 'port'):
for word in ("[test]", "password", "host", "port"):
assert word in contents
assert 'user' not in contents
assert "user" not in contents
def test_get_mylogin_cnf_path():
"""Tests that the path for .mylogin.cnf is detected."""
original_env = None
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
is_windows = sys.platform == 'win32'
if "MYSQL_TEST_LOGIN_FILE" in os.environ:
original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
is_windows = sys.platform == "win32"
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
if login_cnf_path is not None:
assert login_cnf_path.endswith('.mylogin.cnf')
assert login_cnf_path.endswith(".mylogin.cnf")
if is_windows is True:
assert 'MySQL' in login_cnf_path
assert "MySQL" in login_cnf_path
else:
home_dir = os.path.expanduser('~')
home_dir = os.path.expanduser("~")
assert login_cnf_path.startswith(home_dir)
def test_alternate_get_mylogin_cnf_path():
"""Tests that the alternate path for .mylogin.cnf is detected."""
original_env = None
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
if "MYSQL_TEST_LOGIN_FILE" in os.environ:
original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
_, temp_path = tempfile.mkstemp()
os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
os.environ["MYSQL_TEST_LOGIN_FILE"] = temp_path
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
assert temp_path == login_cnf_path
@ -124,17 +130,17 @@ def test_str_to_bool():
assert str_to_bool(False) is False
assert str_to_bool(True) is True
assert str_to_bool('False') is False
assert str_to_bool('True') is True
assert str_to_bool('TRUE') is True
assert str_to_bool('1') is True
assert str_to_bool('0') is False
assert str_to_bool('on') is True
assert str_to_bool('off') is False
assert str_to_bool('off') is False
assert str_to_bool("False") is False
assert str_to_bool("True") is True
assert str_to_bool("TRUE") is True
assert str_to_bool("1") is True
assert str_to_bool("0") is False
assert str_to_bool("on") is True
assert str_to_bool("off") is False
assert str_to_bool("off") is False
with pytest.raises(ValueError):
str_to_bool('foo')
str_to_bool("foo")
with pytest.raises(TypeError):
str_to_bool(None)
@ -143,19 +149,19 @@ def test_str_to_bool():
def test_read_config_file_list_values_default():
"""Test that reading a config file uses list_values by default."""
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f)
assert config['main']['weather'] == u"cloudy with a chance of meatballs"
assert config["main"]["weather"] == "cloudy with a chance of meatballs"
def test_read_config_file_list_values_off():
"""Test that you can disable list_values when reading a config file."""
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f, list_values=False)
assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
assert config["main"]["weather"] == "'cloudy with a chance of meatballs'"
def test_strip_quotes_with_matching_quotes():
@ -177,7 +183,7 @@ def test_strip_quotes_with_unmatching_quotes():
def test_strip_quotes_with_empty_string():
"""Test that an empty string is handled during unquoting."""
assert '' == strip_matching_quotes('')
assert "" == strip_matching_quotes("")
def test_strip_quotes_with_none():

View file

@ -4,39 +4,32 @@ from mycli.packages.special.utils import format_uptime
def test_u_suggests_databases():
suggestions = suggest_type('\\u ', '\\u ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'database'}])
suggestions = suggest_type("\\u ", "\\u ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
def test_describe_table():
suggestions = suggest_type('\\dt', '\\dt ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("\\dt", "\\dt ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_list_or_show_create_tables():
suggestions = suggest_type('\\dt+', '\\dt+ ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type("\\dt+", "\\dt+ ")
assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_format_uptime():
seconds = 59
assert '59 sec' == format_uptime(seconds)
assert "59 sec" == format_uptime(seconds)
seconds = 120
assert '2 min 0 sec' == format_uptime(seconds)
assert "2 min 0 sec" == format_uptime(seconds)
seconds = 54890
assert '15 hours 14 min 50 sec' == format_uptime(seconds)
assert "15 hours 14 min 50 sec" == format_uptime(seconds)
seconds = 598244
assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds)
seconds = 522600
assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)
assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds)

View file

@ -13,52 +13,62 @@ from textwrap import dedent
from collections import namedtuple
from tempfile import NamedTemporaryFile
from textwrap import dedent
test_dir = os.path.abspath(os.path.dirname(__file__))
project_dir = os.path.dirname(test_dir)
default_config_file = os.path.join(project_dir, 'test', 'myclirc')
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
default_config_file = os.path.join(project_dir, "test", "myclirc")
login_path_file = os.path.join(test_dir, "mylogin.cnf")
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
'--password', PASSWORD, '--myclirc', default_config_file,
'--defaults-file', default_config_file,
'mycli_test_db']
os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file
CLI_ARGS = [
"--user",
USER,
"--host",
HOST,
"--port",
PORT,
"--password",
PASSWORD,
"--myclirc",
default_config_file,
"--defaults-file",
default_config_file,
"mycli_test_db",
]
@dbtest
def test_execute_arg(executor):
run(executor, 'create table test (a text)')
run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql])
assert result.exit_code == 0
assert 'abc' in result.output
assert "abc" in result.output
result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql])
assert result.exit_code == 0
assert 'abc' in result.output
assert "abc" in result.output
expected = 'a\nabc\n'
expected = "a\nabc\n"
assert expected in result.output
@dbtest
def test_execute_arg_with_table(executor):
run(executor, 'create table test (a text)')
run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"])
expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n"
assert result.exit_code == 0
assert expected in result.output
@ -66,12 +76,12 @@ def test_execute_arg_with_table(executor):
@dbtest
def test_execute_arg_with_csv(executor):
run(executor, 'create table test (a text)')
run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"])
expected = '"a"\n"abc"\n'
assert result.exit_code == 0
@ -80,35 +90,29 @@ def test_execute_arg_with_csv(executor):
@dbtest
def test_batch_mode(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
sql = (
'select count(*) from test;\n'
'select * from test limit 1;'
)
sql = "select count(*) from test;\n" "select * from test limit 1;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
assert result.exit_code == 0
assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
assert "count(*)\n3\na\nabc\n" in "".join(result.output)
@dbtest
def test_batch_mode_table(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
sql = (
'select count(*) from test;\n'
'select * from test limit 1;'
)
sql = "select count(*) from test;\n" "select * from test limit 1;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql)
expected = (dedent("""\
expected = dedent("""\
+----------+
| count(*) |
+----------+
@ -118,7 +122,7 @@ def test_batch_mode_table(executor):
| a |
+-----+
| abc |
+-----+"""))
+-----+""")
assert result.exit_code == 0
assert expected in result.output
@ -126,14 +130,13 @@ def test_batch_mode_table(executor):
@dbtest
def test_batch_mode_csv(executor):
run(executor, '''create table test(a text, b text)''')
run(executor,
'''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
run(executor, """create table test(a text, b text)""")
run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""")
sql = 'select * from test;'
sql = "select * from test;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
@ -150,15 +153,15 @@ def test_help_strings_end_with_periods():
"""Make sure click options have help text that end with a period."""
for param in cli.params:
if isinstance(param, click.core.Option):
assert hasattr(param, 'help')
assert param.help.endswith('.')
assert hasattr(param, "help")
assert param.help.endswith(".")
def test_command_descriptions_end_with_periods():
"""Make sure that mycli commands' descriptions end with a period."""
MyCli()
for _, command in SPECIAL_COMMANDS.items():
assert command[3].endswith('.')
assert command[3].endswith(".")
def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
@ -166,23 +169,23 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
clickoutput = ""
m = MyCli(myclirc=default_config_file)
class TestOutput():
class TestOutput:
def get_size(self):
size = namedtuple('Size', 'rows columns')
size = namedtuple("Size", "rows columns")
size.columns, size.rows = terminal_size
return size
class TestExecute():
host = 'test'
user = 'test'
dbname = 'test'
server_info = ServerInfo.from_version_string('unknown')
class TestExecute:
host = "test"
user = "test"
dbname = "test"
server_info = ServerInfo.from_version_string("unknown")
port = 0
def server_type(self):
return ['test']
return ["test"]
class PromptBuffer():
class PromptBuffer:
output = TestOutput()
m.prompt_app = PromptBuffer()
@ -199,8 +202,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
global clickoutput
clickoutput += s + "\n"
monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
monkeypatch.setattr(click, 'secho', secho)
monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
monkeypatch.setattr(click, "secho", secho)
m.output(testdata)
if clickoutput.endswith("\n"):
clickoutput = clickoutput[:-1]
@ -208,59 +211,29 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
def test_conditional_pager(monkeypatch):
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
" ")
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ")
# User didn't set pager, output doesn't fit screen -> pager
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=False,
expect_pager=True
)
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True)
# User didn't set pager, output fits screen -> no pager
output(
monkeypatch,
terminal_size=(20, 20),
testdata=testdata,
explicit_pager=False,
expect_pager=False
)
output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False)
# User manually configured pager, output doesn't fit screen -> pager
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=True,
expect_pager=True
)
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True)
# User manually configured pager, output fit screen -> pager
output(
monkeypatch,
terminal_size=(20, 20),
testdata=testdata,
explicit_pager=True,
expect_pager=True
)
output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True)
SPECIAL_COMMANDS['nopager'].handler()
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=False,
expect_pager=False
)
SPECIAL_COMMANDS['pager'].handler('')
SPECIAL_COMMANDS["nopager"].handler()
output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False)
SPECIAL_COMMANDS["pager"].handler("")
def test_reserved_space_is_integer(monkeypatch):
"""Make sure that reserved space is returned as an integer."""
def stub_terminal_size():
return (5, 5)
with monkeypatch.context() as m:
m.setattr(shutil, 'get_terminal_size', stub_terminal_size)
m.setattr(shutil, "get_terminal_size", stub_terminal_size)
mycli = MyCli()
assert isinstance(mycli.get_reserved_space(), int)
@ -268,18 +241,20 @@ def test_reserved_space_is_integer(monkeypatch):
def test_list_dsn():
runner = CliRunner()
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w",delete=False) as myclirc:
myclirc.write(dedent("""\
with NamedTemporaryFile(mode="w", delete=False) as myclirc:
myclirc.write(
dedent("""\
[alias_dsn]
test = mysql://test/test
"""))
""")
)
myclirc.flush()
args = ['--list-dsn', '--myclirc', myclirc.name]
args = ["--list-dsn", "--myclirc", myclirc.name]
result = runner.invoke(cli, args=args)
assert result.output == "test\n"
result = runner.invoke(cli, args=args + ['--verbose'])
result = runner.invoke(cli, args=args + ["--verbose"])
assert result.output == "test : mysql://test/test\n"
# delete=False means we should try to clean up
try:
if os.path.exists(myclirc.name):
@ -287,41 +262,41 @@ def test_list_dsn():
except Exception as e:
print(f"An error occurred while attempting to delete the file: {e}")
def test_prettify_statement():
statement = 'SELECT 1'
statement = "SELECT 1"
m = MyCli()
pretty_statement = m.handle_prettify_binding(statement)
assert pretty_statement == 'SELECT\n 1;'
assert pretty_statement == "SELECT\n 1;"
def test_unprettify_statement():
statement = 'SELECT\n 1'
statement = "SELECT\n 1"
m = MyCli()
unpretty_statement = m.handle_unprettify_binding(statement)
assert unpretty_statement == 'SELECT 1;'
assert unpretty_statement == "SELECT 1;"
def test_list_ssh_config():
runner = CliRunner()
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
ssh_config.write(dedent("""\
with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
ssh_config.write(
dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
"""))
""")
)
ssh_config.flush()
args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name]
result = runner.invoke(cli, args=args)
assert "test\n" in result.output
result = runner.invoke(cli, args=args + ['--verbose'])
result = runner.invoke(cli, args=args + ["--verbose"])
assert "test : test.example.com\n" in result.output
# delete=False means we should try to clean up
try:
if os.path.exists(ssh_config.name):
@ -343,7 +318,7 @@ def test_dsn(monkeypatch):
pass
class MockMyCli:
config = {'alias_dsn': {}}
config = {"alias_dsn": {}}
def __init__(self, **args):
self.logger = Logger()
@ -357,97 +332,109 @@ def test_dsn(monkeypatch):
pass
import mycli.main
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
runner = CliRunner()
# When a user supplies a DSN as database argument to mycli,
# use these values.
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
)
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "dsn_user" and \
MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
MockMyCli.connect_args["host"] == "dsn_host" and \
MockMyCli.connect_args["port"] == 1 and \
MockMyCli.connect_args["database"] == "dsn_database"
assert (
MockMyCli.connect_args["user"] == "dsn_user"
and MockMyCli.connect_args["passwd"] == "dsn_passwd"
and MockMyCli.connect_args["host"] == "dsn_host"
and MockMyCli.connect_args["port"] == 1
and MockMyCli.connect_args["database"] == "dsn_database"
)
MockMyCli.connect_args = None
# When a use supplies a DSN as database argument to mycli,
# and used command line arguments, use the command line
# arguments.
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
"--user", "arg_user",
"--password", "arg_password",
"--host", "arg_host",
"--port", "3",
"--database", "arg_database",
])
result = runner.invoke(
mycli.main.cli,
args=[
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
"--user",
"arg_user",
"--password",
"arg_password",
"--host",
"arg_host",
"--port",
"3",
"--database",
"arg_database",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "arg_user" and \
MockMyCli.connect_args["passwd"] == "arg_password" and \
MockMyCli.connect_args["host"] == "arg_host" and \
MockMyCli.connect_args["port"] == 3 and \
MockMyCli.connect_args["database"] == "arg_database"
assert (
MockMyCli.connect_args["user"] == "arg_user"
and MockMyCli.connect_args["passwd"] == "arg_password"
and MockMyCli.connect_args["host"] == "arg_host"
and MockMyCli.connect_args["port"] == 3
and MockMyCli.connect_args["database"] == "arg_database"
)
MockMyCli.config = {
'alias_dsn': {
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
}
}
MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn),
# use these values.
result = runner.invoke(cli, args=['--dsn', 'test'])
result = runner.invoke(cli, args=["--dsn", "test"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "alias_dsn_user" and \
MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
MockMyCli.connect_args["host"] == "alias_dsn_host" and \
MockMyCli.connect_args["port"] == 4 and \
MockMyCli.connect_args["database"] == "alias_dsn_database"
assert (
MockMyCli.connect_args["user"] == "alias_dsn_user"
and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd"
and MockMyCli.connect_args["host"] == "alias_dsn_host"
and MockMyCli.connect_args["port"] == 4
and MockMyCli.connect_args["database"] == "alias_dsn_database"
)
MockMyCli.config = {
'alias_dsn': {
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
}
}
MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn)
# and used command line arguments, use the command line arguments.
result = runner.invoke(cli, args=[
'--dsn', 'test', '',
"--user", "arg_user",
"--password", "arg_password",
"--host", "arg_host",
"--port", "5",
"--database", "arg_database",
])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "arg_user" and \
MockMyCli.connect_args["passwd"] == "arg_password" and \
MockMyCli.connect_args["host"] == "arg_host" and \
MockMyCli.connect_args["port"] == 5 and \
MockMyCli.connect_args["database"] == "arg_database"
# Use a DSN without password
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user@dsn_host:6/dsn_database"]
result = runner.invoke(
cli,
args=[
"--dsn",
"test",
"",
"--user",
"arg_user",
"--password",
"arg_password",
"--host",
"arg_host",
"--port",
"5",
"--database",
"arg_database",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "dsn_user" and \
MockMyCli.connect_args["passwd"] is None and \
MockMyCli.connect_args["host"] == "dsn_host" and \
MockMyCli.connect_args["port"] == 6 and \
MockMyCli.connect_args["database"] == "dsn_database"
assert (
MockMyCli.connect_args["user"] == "arg_user"
and MockMyCli.connect_args["passwd"] == "arg_password"
and MockMyCli.connect_args["host"] == "arg_host"
and MockMyCli.connect_args["port"] == 5
and MockMyCli.connect_args["database"] == "arg_database"
)
# Use a DSN without password
result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user@dsn_host:6/dsn_database"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["user"] == "dsn_user"
and MockMyCli.connect_args["passwd"] is None
and MockMyCli.connect_args["host"] == "dsn_host"
and MockMyCli.connect_args["port"] == 6
and MockMyCli.connect_args["database"] == "dsn_database"
)
def test_ssh_config(monkeypatch):
@ -463,7 +450,7 @@ def test_ssh_config(monkeypatch):
pass
class MockMyCli:
config = {'alias_dsn': {}}
config = {"alias_dsn": {}}
def __init__(self, **args):
self.logger = Logger()
@ -477,58 +464,62 @@ def test_ssh_config(monkeypatch):
pass
import mycli.main
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
runner = CliRunner()
# Setup temporary configuration
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
ssh_config.write(dedent("""\
with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
ssh_config.write(
dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
"""))
""")
)
ssh_config.flush()
# When a user supplies a ssh config.
result = runner.invoke(mycli.main.cli, args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test"
])
assert result.exit_code == 0, result.output + \
" " + str(result.exception)
assert \
MockMyCli.connect_args["ssh_user"] == "joe" and \
MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
MockMyCli.connect_args["ssh_port"] == 22222 and \
MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser(
"~") + "/.ssh/gateway"
result = runner.invoke(mycli.main.cli, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["ssh_user"] == "joe"
and MockMyCli.connect_args["ssh_host"] == "test.example.com"
and MockMyCli.connect_args["ssh_port"] == 22222
and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway"
)
# When a user supplies a ssh config host as argument to mycli,
# and used command line arguments, use the command line
# arguments.
result = runner.invoke(mycli.main.cli, args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test",
"--ssh-user", "arg_user",
"--ssh-host", "arg_host",
"--ssh-port", "3",
"--ssh-key-filename", "/path/to/key"
])
assert result.exit_code == 0, result.output + \
" " + str(result.exception)
assert \
MockMyCli.connect_args["ssh_user"] == "arg_user" and \
MockMyCli.connect_args["ssh_host"] == "arg_host" and \
MockMyCli.connect_args["ssh_port"] == 3 and \
MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
result = runner.invoke(
mycli.main.cli,
args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test",
"--ssh-user",
"arg_user",
"--ssh-host",
"arg_host",
"--ssh-port",
"3",
"--ssh-key-filename",
"/path/to/key",
],
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert (
MockMyCli.connect_args["ssh_user"] == "arg_user"
and MockMyCli.connect_args["ssh_host"] == "arg_host"
and MockMyCli.connect_args["ssh_port"] == 3
and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
)
# delete=False means we should try to clean up
try:
if os.path.exists(ssh_config.name):
@ -542,9 +533,7 @@ def test_init_command_arg(executor):
init_command = "set sql_select_limit=1000"
sql = 'show variables like "sql_select_limit";'
runner = CliRunner()
result = runner.invoke(
cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
)
result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
expected = "sql_select_limit\t1000\n"
assert result.exit_code == 0
@ -553,18 +542,13 @@ def test_init_command_arg(executor):
@dbtest
def test_init_command_multiple_arg(executor):
init_command = 'set sql_select_limit=2000; set max_join_size=20000'
sql = (
'show variables like "sql_select_limit";\n'
'show variables like "max_join_size"'
)
init_command = "set sql_select_limit=2000; set max_join_size=20000"
sql = 'show variables like "sql_select_limit";\n' 'show variables like "max_join_size"'
runner = CliRunner()
result = runner.invoke(
cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
)
result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
expected_sql_select_limit = 'sql_select_limit\t2000\n'
expected_max_join_size = 'max_join_size\t20000\n'
expected_sql_select_limit = "sql_select_limit\t2000\n"
expected_max_join_size = "max_join_size\t20000\n"
assert result.exit_code == 0
assert expected_sql_select_limit in result.output

View file

@ -6,56 +6,48 @@ from prompt_toolkit.document import Document
@pytest.fixture
def completer():
import mycli.sqlcompleter as sqlcompleter
return sqlcompleter.SQLCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from unittest.mock import Mock
return Mock()
def test_empty_string_completion(completer, complete_event):
text = ''
text = ""
position = 0
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(map(Completion, completer.all_completions))
def test_select_keyword_completion(completer, complete_event):
text = 'SEL'
position = len('SEL')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([Completion(text='SELECT', start_position=-3)])
text = "SEL"
position = len("SEL")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list([Completion(text="SELECT", start_position=-3)])
def test_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "SELECT MA"
position = len("SELECT MA")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert sorted(x.text for x in result) == ["MASTER", "MAX"]
def test_column_name_completion(completer, complete_event):
text = 'SELECT FROM users'
position = len('SELECT ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "SELECT FROM users"
position = len("SELECT ")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(map(Completion, completer.all_completions))
def test_special_name_completion(completer, complete_event):
text = '\\'
position = len('\\')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
text = "\\"
position = len("\\")
result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
# Special commands will NOT be suggested during naive completion mode.
assert result == set()

View file

@ -1,67 +1,72 @@
import pytest
from mycli.packages.parseutils import (
extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
is_dropping_database)
extract_tables,
query_starts_with,
queries_start_with,
is_destructive,
query_has_where_clause,
is_dropping_database,
)
def test_empty_string():
tables = extract_tables('')
tables = extract_tables("")
assert tables == []
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
assert tables == [(None, 'abc', None)]
tables = extract_tables("select * from abc")
assert tables == [(None, "abc", None)]
def test_simple_select_single_table_schema_qualified():
tables = extract_tables('select * from abc.def')
assert tables == [('abc', 'def', None)]
tables = extract_tables("select * from abc.def")
assert tables == [("abc", "def", None)]
def test_simple_select_multiple_tables():
tables = extract_tables('select * from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
tables = extract_tables("select * from abc, def")
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables('select * from abc.def, ghi.jkl')
assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
tables = extract_tables("select * from abc.def, ghi.jkl")
assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)]
def test_simple_select_with_cols_single_table():
tables = extract_tables('select a,b from abc')
assert tables == [(None, 'abc', None)]
tables = extract_tables("select a,b from abc")
assert tables == [(None, "abc", None)]
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables('select a,b from abc.def')
assert tables == [('abc', 'def', None)]
tables = extract_tables("select a,b from abc.def")
assert tables == [("abc", "def", None)]
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables('select a,b from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
tables = extract_tables("select a,b from abc, def")
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_simple_select_with_cols_multiple_tables_with_schema():
tables = extract_tables('select a,b from abc.def, def.ghi')
assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
tables = extract_tables("select a,b from abc.def, def.ghi")
assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)]
def test_select_with_hanging_comma_single_table():
tables = extract_tables('select a, from abc')
assert tables == [(None, 'abc', None)]
tables = extract_tables("select a, from abc")
assert tables == [(None, "abc", None)]
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables('select a, from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
tables = extract_tables("select a, from abc, def")
assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")]
def test_simple_insert_single_table():
@ -69,97 +74,80 @@ def test_simple_insert_single_table():
# sqlparse mistakenly assigns an alias to the table
# assert tables == [(None, 'abc', None)]
assert tables == [(None, 'abc', 'abc')]
assert tables == [(None, "abc", "abc")]
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == [('abc', 'def', None)]
assert tables == [("abc", "def", None)]
def test_simple_update_table():
tables = extract_tables('update abc set id = 1')
assert tables == [(None, 'abc', None)]
tables = extract_tables("update abc set id = 1")
assert tables == [(None, "abc", None)]
def test_simple_update_table_with_schema():
tables = extract_tables('update abc.def set id = 1')
assert tables == [('abc', 'def', None)]
tables = extract_tables("update abc.def set id = 1")
assert tables == [("abc", "def", None)]
def test_join_table():
tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num")
assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")]
def test_join_table_schema_qualified():
tables = extract_tables(
'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")]
def test_join_as_table():
tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
assert tables == [(None, 'my_table', 'm')]
tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
assert tables == [(None, "my_table", "m")]
def test_query_starts_with():
query = 'USE test;'
assert query_starts_with(query, ('use', )) is True
query = "USE test;"
assert query_starts_with(query, ("use",)) is True
query = 'DROP DATABASE test;'
assert query_starts_with(query, ('use', )) is False
query = "DROP DATABASE test;"
assert query_starts_with(query, ("use",)) is False
def test_query_starts_with_comment():
query = '# comment\nUSE test;'
assert query_starts_with(query, ('use', )) is True
query = "# comment\nUSE test;"
assert query_starts_with(query, ("use",)) is True
def test_queries_start_with():
sql = (
'# comment\n'
'show databases;'
'use foo;'
)
assert queries_start_with(sql, ('show', 'select')) is True
assert queries_start_with(sql, ('use', 'drop')) is True
assert queries_start_with(sql, ('delete', 'update')) is False
sql = "# comment\n" "show databases;" "use foo;"
assert queries_start_with(sql, ("show", "select")) is True
assert queries_start_with(sql, ("use", "drop")) is True
assert queries_start_with(sql, ("delete", "update")) is False
def test_is_destructive():
sql = (
'use test;\n'
'show databases;\n'
'drop database foo;'
)
sql = "use test;\n" "show databases;\n" "drop database foo;"
assert is_destructive(sql) is True
def test_is_destructive_update_with_where_clause():
sql = (
'use test;\n'
'show databases;\n'
'UPDATE test SET x = 1 WHERE id = 1;'
)
sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1 WHERE id = 1;"
assert is_destructive(sql) is False
def test_is_destructive_update_without_where_clause():
sql = (
'use test;\n'
'show databases;\n'
'UPDATE test SET x = 1;'
)
sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1;"
assert is_destructive(sql) is True
@pytest.mark.parametrize(
('sql', 'has_where_clause'),
("sql", "has_where_clause"),
[
('update test set dummy = 1;', False),
('update test set dummy = 1 where id = 1);', True),
("update test set dummy = 1;", False),
("update test set dummy = 1 where id = 1);", True),
],
)
def test_query_has_where_clause(sql, has_where_clause):
@ -167,24 +155,20 @@ def test_query_has_where_clause(sql, has_where_clause):
@pytest.mark.parametrize(
('sql', 'dbname', 'is_dropping'),
("sql", "dbname", "is_dropping"),
[
('select bar from foo', 'foo', False),
('drop database "foo";', '`foo`', True),
('drop schema foo', 'foo', True),
('drop schema foo', 'bar', False),
('drop database bar', 'foo', False),
('drop database foo', None, False),
('drop database foo; create database foo', 'foo', False),
('drop database foo; create database bar', 'foo', True),
('select bar from foo; drop database bazz', 'foo', False),
('select bar from foo; drop database bazz', 'bazz', True),
('-- dropping database \n '
'drop -- really dropping \n '
'schema abc -- now it is dropped',
'abc',
True)
]
("select bar from foo", "foo", False),
('drop database "foo";', "`foo`", True),
("drop schema foo", "foo", True),
("drop schema foo", "bar", False),
("drop database bar", "foo", False),
("drop database foo", None, False),
("drop database foo; create database foo", "foo", False),
("drop database foo; create database bar", "foo", True),
("select bar from foo; drop database bazz", "foo", False),
("select bar from foo; drop database bazz", "bazz", True),
("-- dropping database \n " "drop -- really dropping \n " "schema abc -- now it is dropped", "abc", True),
],
)
def test_is_dropping_database(sql, dbname, is_dropping):
assert is_dropping_database(sql, dbname) == is_dropping

View file

@ -4,8 +4,8 @@ from mycli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
stdin = click.get_text_stream('stdin')
stdin = click.get_text_stream("stdin")
assert stdin.isatty() is False
sql = 'drop database foo;'
sql = "drop database foo;"
assert confirm_destructive_query(sql) is None

View file

@ -43,49 +43,35 @@ def complete_event():
def test_special_name_completion(completer, complete_event):
text = "\\d"
position = len("\\d")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert result == [Completion(text="\\dt", start_position=-2)]
def test_empty_string_completion(completer, complete_event):
text = ""
position = 0
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert (
list(map(Completion, completer.keywords + completer.special_commands)) == result
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert list(map(Completion, completer.keywords + completer.special_commands)) == result
def test_select_keyword_completion(completer, complete_event):
text = "SEL"
position = len("SEL")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list([Completion(text="SELECT", start_position=-3)])
def test_select_star(completer, complete_event):
text = "SELECT * "
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(map(Completion, completer.keywords))
def test_table_completion(completer, complete_event):
text = "SELECT * FROM "
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="users", start_position=0),
@ -99,9 +85,7 @@ def test_table_completion(completer, complete_event):
def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="MAX", start_position=-2),
@ -127,11 +111,7 @@ def test_suggested_column_names(completer, complete_event):
"""
text = "SELECT from users"
position = len("SELECT ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -157,9 +137,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
"""
text = "SELECT MAX( from users"
position = len("SELECT MAX(")
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="*", start_position=0),
@ -181,11 +159,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event):
"""
text = "SELECT users. from users"
position = len("SELECT users.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -207,11 +181,7 @@ def test_suggested_column_names_with_alias(completer, complete_event):
"""
text = "SELECT u. from users u"
position = len("SELECT u.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -234,11 +204,7 @@ def test_suggested_multiple_column_names(completer, complete_event):
"""
text = "SELECT id, from users u"
position = len("SELECT id, ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -264,11 +230,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
"""
text = "SELECT u.id, u. from users u"
position = len("SELECT u.id, u.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -291,11 +253,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
"""
text = "SELECT users.id, users. from users u"
position = len("SELECT users.id, users.")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -310,11 +268,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
def test_suggested_aliases_after_on(completer, complete_event):
text = "SELECT u.name, o.id FROM users u JOIN orders o ON "
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="u", start_position=0),
@ -326,11 +280,7 @@ def test_suggested_aliases_after_on(completer, complete_event):
def test_suggested_aliases_after_on_right_side(completer, complete_event):
text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = "
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="u", start_position=0),
@ -342,11 +292,7 @@ def test_suggested_aliases_after_on_right_side(completer, complete_event):
def test_suggested_tables_after_on(completer, complete_event):
text = "SELECT users.name, orders.id FROM users JOIN orders ON "
position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@ -357,14 +303,8 @@ def test_suggested_tables_after_on(completer, complete_event):
def test_suggested_tables_after_on_right_side(completer, complete_event):
text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
position = len(
"SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
)
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ")
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@ -376,11 +316,7 @@ def test_suggested_tables_after_on_right_side(completer, complete_event):
def test_table_names_after_from(completer, complete_event):
text = "SELECT * FROM "
position = len("SELECT * FROM ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@ -394,29 +330,21 @@ def test_table_names_after_from(completer, complete_event):
def test_auto_escaped_col_names(completer, complete_event):
text = "SELECT from `select`"
position = len("SELECT ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text="*", start_position=0),
Completion(text="id", start_position=0),
Completion(text="`insert`", start_position=0),
Completion(text="`ABC`", start_position=0),
] + list(map(Completion, completer.functions)) + [
Completion(text="select", start_position=0)
] + list(map(Completion, completer.keywords))
] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list(
map(Completion, completer.keywords)
)
def test_un_escaped_table_names(completer, complete_event):
text = "SELECT from réveillé"
position = len("SELECT ")
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@ -464,10 +392,6 @@ def dummy_list_path(dir_name):
)
def test_file_name_completion(completer, complete_event, text, expected):
position = len(text)
result = list(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
expected = list((Completion(txt, pos) for txt, pos in expected))
assert result == expected

View file

@ -17,11 +17,11 @@ def test_set_get_pager():
assert mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager_enabled(False)
assert not mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager('less')
assert os.environ['PAGER'] == "less"
mycli.packages.special.set_pager("less")
assert os.environ["PAGER"] == "less"
mycli.packages.special.set_pager(False)
assert os.environ['PAGER'] == "less"
del os.environ['PAGER']
assert os.environ["PAGER"] == "less"
del os.environ["PAGER"]
mycli.packages.special.set_pager(False)
mycli.packages.special.disable_pager()
assert not mycli.packages.special.is_pager_enabled()
@ -42,45 +42,44 @@ def test_set_get_expanded_output():
def test_editor_command():
assert mycli.packages.special.editor_command(r'hello\e')
assert mycli.packages.special.editor_command(r'\ehello')
assert not mycli.packages.special.editor_command(r'hello')
assert mycli.packages.special.editor_command(r"hello\e")
assert mycli.packages.special.editor_command(r"\ehello")
assert not mycli.packages.special.editor_command(r"hello")
assert mycli.packages.special.get_filename(r'\e filename') == "filename"
assert mycli.packages.special.get_filename(r"\e filename") == "filename"
os.environ['EDITOR'] = 'true'
os.environ['VISUAL'] = 'true'
os.environ["EDITOR"] = "true"
os.environ["VISUAL"] = "true"
# Set the editor to Notepad on Windows
if os.name != 'nt':
mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
if os.name != "nt":
mycli.packages.special.open_external_editor(sql=r"select 1") == "select 1"
else:
pytest.skip('Skipping on Windows platform.')
pytest.skip("Skipping on Windows platform.")
def test_tee_command():
mycli.packages.special.write_tee(u"hello world") # write without file set
mycli.packages.special.write_tee("hello world") # write without file set
# keep Windows from locking the file with delete=False
with tempfile.NamedTemporaryFile(delete=False) as f:
mycli.packages.special.execute(None, u"tee " + f.name)
mycli.packages.special.write_tee(u"hello world")
if os.name=='nt':
mycli.packages.special.execute(None, "tee " + f.name)
mycli.packages.special.write_tee("hello world")
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"tee -o " + f.name)
mycli.packages.special.write_tee(u"hello world")
mycli.packages.special.execute(None, "tee -o " + f.name)
mycli.packages.special.write_tee("hello world")
f.seek(0)
if os.name=='nt':
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"notee")
mycli.packages.special.write_tee(u"hello world")
mycli.packages.special.execute(None, "notee")
mycli.packages.special.write_tee("hello world")
f.seek(0)
if os.name=='nt':
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
@ -92,52 +91,49 @@ def test_tee_command():
os.remove(f.name)
except Exception as e:
print(f"An error occurred while attempting to delete the file: {e}")
def test_tee_command_error():
with pytest.raises(TypeError):
mycli.packages.special.execute(None, 'tee')
mycli.packages.special.execute(None, "tee")
with pytest.raises(OSError):
with tempfile.NamedTemporaryFile() as f:
os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
mycli.packages.special.execute(None, 'tee {}'.format(f.name))
mycli.packages.special.execute(None, "tee {}".format(f.name))
@dbtest
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
def test_favorite_query():
with db_connection().cursor() as cur:
query = u'select ""'
mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
assert next(mycli.packages.special.execute(
cur, u'\\f check'))[0] == "> " + query
query = 'select ""'
mycli.packages.special.execute(cur, "\\fs check {0}".format(query))
assert next(mycli.packages.special.execute(cur, "\\f check"))[0] == "> " + query
def test_once_command():
with pytest.raises(TypeError):
mycli.packages.special.execute(None, u"\\once")
mycli.packages.special.execute(None, "\\once")
with pytest.raises(OSError):
mycli.packages.special.execute(None, u"\\once /proc/access-denied")
mycli.packages.special.execute(None, "\\once /proc/access-denied")
mycli.packages.special.write_once(u"hello world") # write without file set
mycli.packages.special.write_once("hello world") # write without file set
# keep Windows from locking the file with delete=False
with tempfile.NamedTemporaryFile(delete=False) as f:
mycli.packages.special.execute(None, u"\\once " + f.name)
mycli.packages.special.write_once(u"hello world")
if os.name=='nt':
mycli.packages.special.execute(None, "\\once " + f.name)
mycli.packages.special.write_once("hello world")
if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"\\once -o " + f.name)
mycli.packages.special.write_once(u"hello world line 1")
mycli.packages.special.write_once(u"hello world line 2")
mycli.packages.special.execute(None, "\\once -o " + f.name)
mycli.packages.special.write_once("hello world line 1")
mycli.packages.special.write_once("hello world line 2")
f.seek(0)
if os.name=='nt':
if os.name == "nt":
assert f.read() == b"hello world line 1\r\nhello world line 2\r\n"
else:
assert f.read() == b"hello world line 1\nhello world line 2\n"
@ -151,52 +147,47 @@ def test_once_command():
def test_pipe_once_command():
with pytest.raises(IOError):
mycli.packages.special.execute(None, u"\\pipe_once")
mycli.packages.special.execute(None, "\\pipe_once")
with pytest.raises(OSError):
mycli.packages.special.execute(
None, u"\\pipe_once /proc/access-denied")
mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied")
if os.name == 'nt':
mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
mycli.packages.special.write_once(u"hello world")
if os.name == "nt":
mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
mycli.packages.special.write_once("hello world")
mycli.packages.special.unset_pipe_once_if_written()
else:
mycli.packages.special.execute(None, u"\\pipe_once wc")
mycli.packages.special.write_once(u"hello world")
mycli.packages.special.unset_pipe_once_if_written()
# how to assert on wc output?
with tempfile.NamedTemporaryFile() as f:
mycli.packages.special.execute(None, "\\pipe_once tee " + f.name)
mycli.packages.special.write_pipe_once("hello world")
mycli.packages.special.unset_pipe_once_if_written()
f.seek(0)
assert f.read() == b"hello world\n"
def test_parseargfile():
"""Test that parseargfile expands the user directory."""
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
'mode': 'a'}
if os.name=='nt':
assert expected == mycli.packages.special.iocommands.parseargfile(
'~\\filename')
else:
assert expected == mycli.packages.special.iocommands.parseargfile(
'~/filename')
expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "a"}
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
'mode': 'w'}
if os.name=='nt':
assert expected == mycli.packages.special.iocommands.parseargfile(
'-o ~\\filename')
if os.name == "nt":
assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename")
else:
assert expected == mycli.packages.special.iocommands.parseargfile(
'-o ~/filename')
assert expected == mycli.packages.special.iocommands.parseargfile("~/filename")
expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "w"}
if os.name == "nt":
assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename")
else:
assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename")
def test_parseargfile_no_file():
"""Test that parseargfile raises a TypeError if there is no filename."""
with pytest.raises(TypeError):
mycli.packages.special.iocommands.parseargfile('')
mycli.packages.special.iocommands.parseargfile("")
with pytest.raises(TypeError):
mycli.packages.special.iocommands.parseargfile('-o ')
mycli.packages.special.iocommands.parseargfile("-o ")
@dbtest
@ -205,11 +196,9 @@ def test_watch_query_iteration():
the desired query and returns the given results."""
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
expected_title = '> {0!s}'.format(query)
expected_title = "> {0!s}".format(query)
with db_connection().cursor() as cur:
result = next(mycli.packages.special.iocommands.watch_query(
arg=query, cur=cur
))
result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur))
assert result[0] == expected_title
assert result[2][0] == expected_value
@ -230,14 +219,12 @@ def test_watch_query_full():
wait_interval = 1
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
expected_title = '> {0!s}'.format(query)
expected_title = "> {0!s}".format(query)
expected_results = 4
ctrl_c_process = send_ctrl_c(wait_interval)
with db_connection().cursor() as cur:
results = list(
result for result in mycli.packages.special.iocommands.watch_query(
arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
)
result for result in mycli.packages.special.iocommands.watch_query(arg="{0!s} {1!s}".format(watch_seconds, query), cur=cur)
)
ctrl_c_process.join(1)
assert len(results) == expected_results
@ -247,14 +234,12 @@ def test_watch_query_full():
@dbtest
@patch('click.clear')
@patch("click.clear")
def test_watch_query_clear(clear_mock):
"""Test that the screen is cleared with the -c flag of `watch` command
before execute the query."""
with db_connection().cursor() as cur:
watch_gen = mycli.packages.special.iocommands.watch_query(
arg='0.1 -c select 1;', cur=cur
)
watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur)
assert not clear_mock.called
next(watch_gen)
assert clear_mock.called
@ -271,19 +256,20 @@ def test_watch_query_bad_arguments():
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
with pytest.raises(ProgrammingError):
next(watch_query('a select 1;', cur=cur))
next(watch_query("a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('-a select 1;', cur=cur))
next(watch_query("-a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('1 -a select 1;', cur=cur))
next(watch_query("1 -a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('-c -a select 1;', cur=cur))
next(watch_query("-c -a select 1;", cur=cur))
@dbtest
@patch('click.clear')
@patch("click.clear")
def test_watch_query_interval_clear(clear_mock):
"""Test `watch` command with interval and clear flag."""
def test_asserts(gen):
clear_mock.reset_mock()
start = time()
@ -296,46 +282,32 @@ def test_watch_query_interval_clear(clear_mock):
seconds = 1.0
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
cur=cur))
test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
cur=cur))
test_asserts(watch_query("{0!s} -c select 1;".format(seconds), cur=cur))
test_asserts(watch_query("-c {0!s} select 1;".format(seconds), cur=cur))
def test_split_sql_by_delimiter():
for delimiter_str in (';', '$', '😀'):
for delimiter_str in (";", "$", "😀"):
mycli.packages.special.set_delimiter(delimiter_str)
sql_input = "select 1{} select \ufffc2".format(delimiter_str)
queries = (
"select 1",
"select \ufffc2"
)
for query, parsed_query in zip(
queries, mycli.packages.special.split_queries(sql_input)):
assert(query == parsed_query)
queries = ("select 1", "select \ufffc2")
for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
assert query == parsed_query
def test_switch_delimiter_within_query():
mycli.packages.special.set_delimiter(';')
mycli.packages.special.set_delimiter(";")
sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
queries = (
"select 1",
"delimiter $$ select 2 $$ select 3 $$",
"select 2",
"select 3"
)
for query, parsed_query in zip(
queries,
mycli.packages.special.split_queries(sql_input)):
assert(query == parsed_query)
queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$", "select 2", "select 3")
for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
assert query == parsed_query
def test_set_delimiter():
for delim in ('foo', 'bar'):
for delim in ("foo", "bar"):
mycli.packages.special.set_delimiter(delim)
assert mycli.packages.special.get_current_delimiter() == delim
def teardown_function():
mycli.packages.special.set_delimiter(';')
mycli.packages.special.set_delimiter(";")

View file

@ -7,14 +7,11 @@ from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output
def assert_result_equal(result, title=None, rows=None, headers=None,
status=None, auto_status=True, assert_contains=False):
def assert_result_equal(result, title=None, rows=None, headers=None, status=None, auto_status=True, assert_contains=False):
"""Assert that an sqlexecute.run() result matches the expected values."""
if status is None and auto_status and rows:
status = '{} row{} in set'.format(
len(rows), 's' if len(rows) > 1 else '')
fields = {'title': title, 'rows': rows, 'headers': headers,
'status': status}
status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "")
fields = {"title": title, "rows": rows, "headers": headers, "status": status}
if assert_contains:
# Do a loose match on the results using the *in* operator.
@ -28,34 +25,35 @@ def assert_result_equal(result, title=None, rows=None, headers=None,
@dbtest
def test_conn(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
results = run(executor, '''select * from test''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
results = run(executor, """select * from test""")
assert_result_equal(results, headers=['a'], rows=[('abc',)])
assert_result_equal(results, headers=["a"], rows=[("abc",)])
@dbtest
def test_bools(executor):
run(executor, '''create table test(a boolean)''')
run(executor, '''insert into test values(True)''')
results = run(executor, '''select * from test''')
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
results = run(executor, """select * from test""")
assert_result_equal(results, headers=['a'], rows=[(1,)])
assert_result_equal(results, headers=["a"], rows=[(1,)])
@dbtest
def test_binary(executor):
run(executor, '''create table bt(geom linestring NOT NULL)''')
run(executor, "INSERT INTO bt VALUES "
"(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
results = run(executor, '''select * from bt''')
run(executor, """create table bt(geom linestring NOT NULL)""")
run(executor, "INSERT INTO bt VALUES " "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
results = run(executor, """select * from bt""")
geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
b'\xac\xdeC@')
geom = (
b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n"
b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9"
b"\xac\xdeC@"
)
assert_result_equal(results, headers=['geom'], rows=[(geom,)])
assert_result_equal(results, headers=["geom"], rows=[(geom,)])
@dbtest
@ -63,49 +61,48 @@ def test_table_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
assert set(executor.tables()) == set([('a',), ('b',)])
assert set(executor.table_columns()) == set(
[('a', 'x'), ('a', 'y'), ('b', 'z')])
assert set(executor.tables()) == set([("a",), ("b",)])
assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")])
@dbtest
def test_database_list(executor):
databases = executor.databases()
assert 'mycli_test_db' in databases
assert "mycli_test_db" in databases
@dbtest
def test_invalid_syntax(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, 'invalid syntax!')
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
run(executor, "invalid syntax!")
assert "You have an error in your SQL syntax;" in str(excinfo.value)
@dbtest
def test_invalid_column_name(executor):
with pytest.raises(pymysql.err.OperationalError) as excinfo:
run(executor, 'select invalid command')
run(executor, "select invalid command")
assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
@dbtest
def test_unicode_support_in_output(executor):
run(executor, "create table unicodechars(t text)")
run(executor, u"insert into unicodechars (t) values ('é')")
run(executor, "insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
results = run(executor, u"select * from unicodechars")
assert_result_equal(results, headers=['t'], rows=[(u'é',)])
results = run(executor, "select * from unicodechars")
assert_result_equal(results, headers=["t"], rows=[("é",)])
@dbtest
def test_multiple_queries_same_line(executor):
results = run(executor, "select 'foo'; select 'bar'")
expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
'status': '1 row in set'},
{'title': None, 'headers': ['bar'], 'rows': [('bar',)],
'status': '1 row in set'}]
expected = [
{"title": None, "headers": ["foo"], "rows": [("foo",)], "status": "1 row in set"},
{"title": None, "headers": ["bar"], "rows": [("bar",)], "status": "1 row in set"},
]
assert expected == results
@ -113,7 +110,7 @@ def test_multiple_queries_same_line(executor):
def test_multiple_queries_same_line_syntaxerror(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, "select 'foo'; invalid syntax")
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
assert "You have an error in your SQL syntax;" in str(excinfo.value)
@dbtest
@ -125,15 +122,13 @@ def test_favorite_query(executor):
run(executor, "insert into test values('def')")
results = run(executor, "\\fs test-a select * from test where a like 'a%'")
assert_result_equal(results, status='Saved.')
assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-a")
assert_result_equal(results,
title="> select * from test where a like 'a%'",
headers=['a'], rows=[('abc',)], auto_status=False)
assert_result_equal(results, title="> select * from test where a like 'a%'", headers=["a"], rows=[("abc",)], auto_status=False)
results = run(executor, "\\fd test-a")
assert_result_equal(results, status='test-a: Deleted')
assert_result_equal(results, status="test-a: Deleted")
@dbtest
@ -144,158 +139,147 @@ def test_favorite_query_multiple_statement(executor):
run(executor, "insert into test values('abc')")
run(executor, "insert into test values('def')")
results = run(executor,
"\\fs test-ad select * from test where a like 'a%'; "
"select * from test where a like 'd%'")
assert_result_equal(results, status='Saved.')
results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " "select * from test where a like 'd%'")
assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-ad")
expected = [{'title': "> select * from test where a like 'a%'",
'headers': ['a'], 'rows': [('abc',)], 'status': None},
{'title': "> select * from test where a like 'd%'",
'headers': ['a'], 'rows': [('def',)], 'status': None}]
expected = [
{"title": "> select * from test where a like 'a%'", "headers": ["a"], "rows": [("abc",)], "status": None},
{"title": "> select * from test where a like 'd%'", "headers": ["a"], "rows": [("def",)], "status": None},
]
assert expected == results
results = run(executor, "\\fd test-ad")
assert_result_equal(results, status='test-ad: Deleted')
assert_result_equal(results, status="test-ad: Deleted")
@dbtest
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
def test_favorite_query_expanded_output(executor):
set_expanded_output(False)
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
results = run(executor, "\\fs test-ae select * from test")
assert_result_equal(results, status='Saved.')
assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-ae \\G")
assert is_expanded_output() is True
assert_result_equal(results, title='> select * from test',
headers=['a'], rows=[('abc',)], auto_status=False)
assert_result_equal(results, title="> select * from test", headers=["a"], rows=[("abc",)], auto_status=False)
set_expanded_output(False)
results = run(executor, "\\fd test-ae")
assert_result_equal(results, status='test-ae: Deleted')
assert_result_equal(results, status="test-ae: Deleted")
@dbtest
def test_special_command(executor):
results = run(executor, '\\?')
assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
headers='Command', assert_contains=True,
auto_status=False)
results = run(executor, "\\?")
assert_result_equal(results, rows=("quit", "\\q", "Quit."), headers="Command", assert_contains=True, auto_status=False)
@dbtest
def test_cd_command_without_a_folder_name(executor):
results = run(executor, 'system cd')
assert_result_equal(results, status='No folder name was provided.')
results = run(executor, "system cd")
assert_result_equal(results, status="No folder name was provided.")
@dbtest
def test_system_command_not_found(executor):
results = run(executor, 'system xyz')
if os.name=='nt':
assert_result_equal(results, status='OSError: The system cannot find the file specified',
assert_contains=True)
results = run(executor, "system xyz")
if os.name == "nt":
assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True)
else:
assert_result_equal(results, status='OSError: No such file or directory',
assert_contains=True)
assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True)
@dbtest
def test_system_command_output(executor):
eol = os.linesep
test_dir = os.path.abspath(os.path.dirname(__file__))
test_file_path = os.path.join(test_dir, 'test.txt')
results = run(executor, 'system cat {0}'.format(test_file_path))
assert_result_equal(results, status=f'mycli rocks!{eol}')
test_file_path = os.path.join(test_dir, "test.txt")
results = run(executor, "system cat {0}".format(test_file_path))
assert_result_equal(results, status=f"mycli rocks!{eol}")
@dbtest
def test_cd_command_current_dir(executor):
test_path = os.path.abspath(os.path.dirname(__file__))
run(executor, 'system cd {0}'.format(test_path))
run(executor, "system cd {0}".format(test_path))
assert os.getcwd() == test_path
@dbtest
def test_unicode_support(executor):
results = run(executor, u"SELECT '日本語' AS japanese;")
assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
results = run(executor, "SELECT '日本語' AS japanese;")
assert_result_equal(results, headers=["japanese"], rows=[("日本語",)])
@dbtest
def test_timestamp_null(executor):
run(executor, '''create table ts_null(a timestamp null)''')
run(executor, '''insert into ts_null values(null)''')
results = run(executor, '''select * from ts_null''')
assert_result_equal(results, headers=['a'],
rows=[(None,)])
run(executor, """create table ts_null(a timestamp null)""")
run(executor, """insert into ts_null values(null)""")
results = run(executor, """select * from ts_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_datetime_null(executor):
run(executor, '''create table dt_null(a datetime null)''')
run(executor, '''insert into dt_null values(null)''')
results = run(executor, '''select * from dt_null''')
assert_result_equal(results, headers=['a'],
rows=[(None,)])
run(executor, """create table dt_null(a datetime null)""")
run(executor, """insert into dt_null values(null)""")
results = run(executor, """select * from dt_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_date_null(executor):
run(executor, '''create table date_null(a date null)''')
run(executor, '''insert into date_null values(null)''')
results = run(executor, '''select * from date_null''')
assert_result_equal(results, headers=['a'], rows=[(None,)])
run(executor, """create table date_null(a date null)""")
run(executor, """insert into date_null values(null)""")
results = run(executor, """select * from date_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_time_null(executor):
run(executor, '''create table time_null(a time null)''')
run(executor, '''insert into time_null values(null)''')
results = run(executor, '''select * from time_null''')
assert_result_equal(results, headers=['a'], rows=[(None,)])
run(executor, """create table time_null(a time null)""")
run(executor, """insert into time_null values(null)""")
results = run(executor, """select * from time_null""")
assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_multiple_results(executor):
query = '''CREATE PROCEDURE dmtest()
query = """CREATE PROCEDURE dmtest()
BEGIN
SELECT 1;
SELECT 2;
END'''
END"""
executor.conn.cursor().execute(query)
results = run(executor, 'call dmtest;')
results = run(executor, "call dmtest;")
expected = [
{'title': None, 'rows': [(1,)], 'headers': ['1'],
'status': '1 row in set'},
{'title': None, 'rows': [(2,)], 'headers': ['2'],
'status': '1 row in set'}
{"title": None, "rows": [(1,)], "headers": ["1"], "status": "1 row in set"},
{"title": None, "rows": [(2,)], "headers": ["2"], "status": "1 row in set"},
]
assert results == expected
@pytest.mark.parametrize(
'version_string, species, parsed_version_string, version',
"version_string, species, parsed_version_string, version",
(
('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100),
('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200),
('5.7.32-35', 'Percona', '5.7.32', 50732),
('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
('unexpected version string', None, '', 0),
('', None, '', 0),
(None, None, '', 0),
)
("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100),
("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200),
("5.7.32-35", "Percona", "5.7.32", 50732),
("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732),
("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016),
("5.1.5a-alpha", "MySQL", "5.1.5", 50105),
("unexpected version string", None, "", 0),
("", None, "", 0),
(None, None, "", 0),
),
)
def test_version_parsing(version_string, species, parsed_version_string, version):
server_info = ServerInfo.from_version_string(version_string)

View file

@ -2,8 +2,6 @@
from textwrap import dedent
from mycli.packages.tabular_output import sql_format
from cli_helpers.tabular_output import TabularOutputFormatter
from .utils import USER, PASSWORD, HOST, PORT, dbtest
@ -23,20 +21,17 @@ def mycli():
@dbtest
def test_sql_output(mycli):
"""Test the sql output adapter."""
headers = ['letters', 'number', 'optional', 'float', 'binary']
headers = ["letters", "number", "optional", "float", "binary"]
class FakeCursor(object):
def __init__(self):
self.data = [
('abc', 1, None, 10.0, b'\xAA'),
('d', 456, '1', 0.5, b'\xAA\xBB')
]
self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")]
self.description = [
(None, FIELD_TYPE.VARCHAR),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.FLOAT),
(None, FIELD_TYPE.BLOB)
(None, FIELD_TYPE.BLOB),
]
def __iter__(self):
@ -52,12 +47,11 @@ def test_sql_output(mycli):
return self.description
# Test sql-update output format
assert list(mycli.change_table_format("sql-update")) == \
[(None, None, None, 'Changed table format to sql-update')]
assert list(mycli.change_table_format("sql-update")) == [(None, None, None, "Changed table format to sql-update")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
actual = "\n".join(output)
assert actual == dedent('''\
assert actual == dedent("""\
UPDATE `DUAL` SET
`number` = 1
, `optional` = NULL
@ -69,13 +63,12 @@ def test_sql_output(mycli):
, `optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
WHERE `letters` = 'd';''')
WHERE `letters` = 'd';""")
# Test sql-update-2 output format
assert list(mycli.change_table_format("sql-update-2")) == \
[(None, None, None, 'Changed table format to sql-update-2')]
assert list(mycli.change_table_format("sql-update-2")) == [(None, None, None, "Changed table format to sql-update-2")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
UPDATE `DUAL` SET
`optional` = NULL
, `float` = 10.0e0
@ -85,34 +78,31 @@ def test_sql_output(mycli):
`optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
WHERE `letters` = 'd' AND `number` = 456;''')
WHERE `letters` = 'd' AND `number` = 456;""")
# Test sql-insert output format (without table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
;""")
# Test sql-insert output format (with table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = "SELECT * FROM `table`"
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
;""")
# Test sql-insert output format (with database + table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = "SELECT * FROM `database`.`table`"
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
assert "\n".join(output) == dedent("""\
INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
;""")

View file

@ -9,20 +9,18 @@ import pytest
from mycli.main import special
PASSWORD = os.getenv('PYTEST_PASSWORD')
USER = os.getenv('PYTEST_USER', 'root')
HOST = os.getenv('PYTEST_HOST', 'localhost')
PORT = int(os.getenv('PYTEST_PORT', 3306))
CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
SSH_USER = os.getenv('PYTEST_SSH_USER', None)
SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
PASSWORD = os.getenv("PYTEST_PASSWORD")
USER = os.getenv("PYTEST_USER", "root")
HOST = os.getenv("PYTEST_HOST", "localhost")
PORT = int(os.getenv("PYTEST_PORT", 3306))
CHARSET = os.getenv("PYTEST_CHARSET", "utf8")
SSH_USER = os.getenv("PYTEST_SSH_USER", None)
SSH_HOST = os.getenv("PYTEST_SSH_HOST", None)
SSH_PORT = os.getenv("PYTEST_SSH_PORT", 22)
def db_connection(dbname=None):
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
password=PASSWORD, charset=CHARSET,
local_infile=False)
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARSET, local_infile=False)
conn.autocommit = True
return conn
@ -30,20 +28,18 @@ def db_connection(dbname=None):
try:
db_connection()
CAN_CONNECT_TO_DB = True
except:
except Exception:
CAN_CONNECT_TO_DB = False
dbtest = pytest.mark.skipif(
not CAN_CONNECT_TO_DB,
reason="Need a mysql instance at localhost accessible by user 'root'")
dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Need a mysql instance at localhost accessible by user 'root'")
def create_db(dbname):
with db_connection().cursor() as cur:
try:
cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''')
cur.execute('''CREATE DATABASE mycli_test_db''')
except:
cur.execute("""DROP DATABASE IF EXISTS mycli_test_db""")
cur.execute("""CREATE DATABASE mycli_test_db""")
except Exception:
pass
@ -53,8 +49,7 @@ def run(executor, sql, rows_as_list=True):
for title, rows, headers, status in executor.run(sql):
rows = list(rows) if (rows_as_list and rows) else rows
result.append({'title': title, 'rows': rows, 'headers': headers,
'status': status})
result.append({"title": title, "rows": rows, "headers": headers, "status": status})
return result
@ -87,8 +82,6 @@ def send_ctrl_c(wait_seconds):
Returns the `multiprocessing.Process` created.
"""
ctrl_c_process = multiprocessing.Process(
target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
)
ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds))
ctrl_c_process.start()
return ctrl_c_process