1
0
Fork 0

Adding upstream version 1.23.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 18:53:36 +01:00
parent f253096a15
commit 94e3fc38e7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 10761 additions and 0 deletions

0
test/__init__.py Normal file
View file

29
test/conftest.py Normal file
View file

@ -0,0 +1,29 @@
import pytest
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
import mycli.sqlexecute
@pytest.fixture(scope="function")
def connection():
create_db('_test_db')
connection = db_connection('_test_db')
yield connection
connection.close()
@pytest.fixture
def cursor(connection):
with connection.cursor() as cur:
return cur
@pytest.fixture
def executor(connection):
return mycli.sqlexecute.SQLExecute(
database='_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
)

View file

View file

@ -0,0 +1,12 @@
Feature: auto_vertical mode:
on, off
Scenario: auto_vertical on with small query
When we run dbcli with --auto-vertical-output
and we execute a small query
then we see small results in horizontal format
Scenario: auto_vertical on with large query
When we run dbcli with --auto-vertical-output
and we execute a large query
then we see large results in vertical format

View file

@ -0,0 +1,19 @@
Feature: run the cli,
call the help command,
exit the cli
Scenario: run "\?" command
When we send "\?" command
then we see help output
Scenario: run source command
When we send source command
then we see help output
Scenario: check our application_name
When we run query to check application_name
then we see found
Scenario: run the cli and exit
When we send "ctrl + d"
then dbcli exits

View file

@ -0,0 +1,30 @@
Feature: manipulate databases:
create, drop, connect, disconnect
Scenario: create and drop temporary database
When we create database
then we see database created
when we drop database
then we confirm the destructive warning
then we see database dropped
when we connect to dbserver
then we see database connected
Scenario: connect and disconnect from test database
When we connect to test database
then we see database connected
when we connect to dbserver
then we see database connected
Scenario: connect and disconnect from quoted test database
When we connect to quoted test database
then we see database connected
Scenario: create and drop default database
When we create database
then we see database created
when we connect to tmp database
then we see database connected
when we drop database
then we confirm the destructive warning
then we see database dropped and no default database

View file

@ -0,0 +1,49 @@
Feature: manipulate tables:
create, insert, update, select, delete from, drop
Scenario: create, insert, select from, update, drop table
When we connect to test database
then we see database connected
when we create table
then we see table created
when we insert into table
then we see record inserted
when we update table
then we see record updated
when we select from table
then we see data selected
when we delete from table
then we confirm the destructive warning
then we see record deleted
when we drop table
then we confirm the destructive warning
then we see table dropped
when we connect to dbserver
then we see database connected
Scenario: select null values
When we connect to test database
then we see database connected
when we select null
then we see null selected
Scenario: confirm destructive query
When we query "create table foo(x integer);"
and we query "delete from foo;"
and we answer the destructive warning with "y"
then we see text "Your call!"
Scenario: decline destructive query
When we query "delete from foo;"
and we answer the destructive warning with "n"
then we see text "Wise choice!"
Scenario: no destructive warning if disabled in config
When we run dbcli with --no-warn
and we query "create table blabla(x integer);"
and we query "delete from blabla;"
Then we see text "Query OK"
Scenario: confirm destructive query with invalid response
When we query "delete from foo;"
then we answer the destructive warning with invalid "1" and see text "is not a valid boolean"

93
test/features/db_utils.py Normal file
View file

@ -0,0 +1,93 @@
import pymysql
def create_db(hostname='localhost', port=3306, username=None,
password=None, dbname=None):
"""Create test database.
:param hostname: string
:param port: int
:param username: string
:param password: string
:param dbname: string
:return:
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
cr.execute('drop database if exists ' + dbname)
cr.execute('create database ' + dbname)
cn.close()
cn = create_cn(hostname, port, password, username, dbname)
return cn
def create_cn(hostname, port, password, username, dbname):
"""Open connection to database.
:param hostname:
:param port:
:param password:
:param username:
:param dbname: string
:return: psycopg2.connection
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
return cn
def drop_db(hostname='localhost', port=3306, username=None,
password=None, dbname=None):
"""Drop database.
:param hostname: string
:param port: int
:param username: string
:param password: string
:param dbname: string
"""
cn = pymysql.connect(
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
cr.execute('drop database if exists ' + dbname)
close_cn(cn)
def close_cn(cn=None):
"""Close connection.
:param connection: pymysql.connection
"""
if cn:
cn.close()

View file

@ -0,0 +1,140 @@
import os
import sys
from tempfile import mkstemp
import db_utils as dbutils
import fixture_utils as fixutils
import pexpect
from steps.wrappers import run_cli, wait_prompt
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
def before_all(context):
"""Set env parameters."""
os.environ['LINES'] = "100"
os.environ['COLUMNS'] = "100"
os.environ['EDITOR'] = 'ex'
os.environ['LC_ALL'] = 'en_US.UTF-8'
os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
os.environ['MYCLI_HISTFILE'] = os.devnull
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
'.coveragerc')
context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get(
'my_test_db', None) or "mycli_behave_tests"
db_name_full = '{0}_{1}'.format(db_name, vi)
# Store get params from config/environment variables
context.conf = {
'host': context.config.userdata.get(
'my_test_host',
os.getenv('PYTEST_HOST', 'localhost')
),
'port': context.config.userdata.get(
'my_test_port',
int(os.getenv('PYTEST_PORT', '3306'))
),
'user': context.config.userdata.get(
'my_test_user',
os.getenv('PYTEST_USER', 'root')
),
'pass': context.config.userdata.get(
'my_test_pass',
os.getenv('PYTEST_PASSWORD', None)
),
'cli_command': context.config.userdata.get(
'my_cli_command', None) or
sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
'dbname': db_name,
'dbname_tmp': db_name_full + '_tmp',
'vi': vi,
'pager_boundary': '---boundary---',
}
_, my_cnf = mkstemp()
with open(my_cnf, 'w') as f:
f.write(
'[client]\n'
'pager={0} {1} {2}\n'.format(
sys.executable, os.path.join(context.package_root,
'test/features/wrappager.py'),
context.conf['pager_boundary'])
)
context.conf['defaults-file'] = my_cnf
context.conf['myclirc'] = os.path.join(context.package_root, 'test',
'myclirc')
context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
context.conf['user'],
context.conf['pass'],
context.conf['dbname'])
context.fixture_data = fixutils.read_fixture_files()
def after_all(context):
"""Unset env parameters."""
dbutils.close_cn(context.cn)
dbutils.drop_db(context.conf['host'], context.conf['port'],
context.conf['user'], context.conf['pass'],
context.conf['dbname'])
# Restore env vars.
#for k, v in context.pgenv.items():
# if k in os.environ and v is None:
# del os.environ[k]
# elif v:
# os.environ[k] = v
def before_step(context, _):
context.atprompt = False
def before_scenario(context, _):
with open(test_log_file, 'w') as f:
f.write('')
run_cli(context)
wait_prompt(context)
def after_scenario(context, _):
"""Cleans up after each test complete."""
with open(test_log_file) as f:
for line in f:
if 'error' in line.lower():
raise RuntimeError(f'Error in log file: {line}')
if hasattr(context, 'cli') and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
user = context.conf['user']
host = context.conf['host']
dbname = context.currentdb
context.cli.expect_exact(
'{0}@{1}:{2}>'.format(
user, host, dbname
),
timeout=5
)
context.cli.sendcontrol('c')
context.cli.sendcontrol('d')
context.cli.expect_exact(pexpect.EOF, timeout=5)
# TODO: uncomment to debug a failure
# def after_step(context, step):
# if step.status == "failed":
# import ipdb; ipdb.set_trace()

View file

@ -0,0 +1,24 @@
+--------------------------+-----------------------------------------------+
| Command | Description |
|--------------------------+-----------------------------------------------|
| \# | Refresh auto-completions. |
| \? | Show Help. |
| \c[onnect] database_name | Change to a new database. |
| \d [pattern] | List or describe tables, views and sequences. |
| \dT[S+] [pattern] | List data types |
| \df[+] [pattern] | List functions. |
| \di[+] [pattern] | List indexes. |
| \dn[+] [pattern] | List schemas. |
| \ds[+] [pattern] | List sequences. |
| \dt[+] [pattern] | List tables. |
| \du[+] [pattern] | List roles. |
| \dv[+] [pattern] | List views. |
| \e [file] | Edit the query with external editor. |
| \l | List databases. |
| \n[+] [name] | List or execute named queries. |
| \nd [name [query]] | Delete a named query. |
| \ns name query | Save a named query. |
| \refresh | Refresh auto-completions. |
| \timing | Toggle timing of commands. |
| \x | Toggle expanded output. |
+--------------------------+-----------------------------------------------+

View file

@ -0,0 +1,31 @@
+-------------+----------------------------+------------------------------------------------------------+
| Command | Shortcut | Description |
+-------------+----------------------------+------------------------------------------------------------+
| \G | \G | Display current query results vertically. |
| \clip | \clip | Copy query to the system clipboard. |
| \dt | \dt[+] [table] | List or describe tables. |
| \e | \e | Edit command with editor (uses $EDITOR). |
| \f | \f [name [args..]] | List or execute favorite queries. |
| \fd | \fd [name] | Delete a favorite query. |
| \fs | \fs name query | Save a favorite query. |
| \l | \l | List databases. |
| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). |
| \pipe_once | \| command | Send next result to a subprocess. |
| \timing | \t | Toggle timing of commands. |
| connect | \r | Reconnect to the database. Optional database argument. |
| exit | \q | Exit. |
| help | \? | Show this help. |
| nopager | \n | Disable pager, print to stdout. |
| notee | notee | Stop writing results to an output file. |
| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
| prompt | \R | Change prompt format. |
| quit | \q | Quit. |
| rehash | \# | Refresh auto-completions. |
| source | \. filename | Execute commands from file. |
| status | \s | Get status information from the server. |
| system | system [command] | Execute a system shell commmand. |
| tableformat | \T | Change the table format used to output results. |
| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
| use | \u | Change to a new database. |
| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
+-------------+----------------------------+------------------------------------------------------------+

View file

@ -0,0 +1,29 @@
import os
import io
def read_fixture_lines(filename):
"""Read lines of text from file.
:param filename: string name
:return: list of strings
"""
lines = []
for line in open(filename):
lines.append(line.strip())
return lines
def read_fixture_files():
"""Read all files inside fixture_data directory."""
fixture_dict = {}
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, 'fixture_data/')
for filename in os.listdir(fixture_dir):
if filename not in ['.', '..']:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)
return fixture_dict

View file

@ -0,0 +1,47 @@
Feature: I/O commands
Scenario: edit sql in file with external editor
When we start external editor providing a file name
and we type "select * from abc" in the editor
and we exit the editor
then we see dbcli prompt
and we see "select * from abc" in prompt
Scenario: tee output from query
When we tee output
and we wait for prompt
and we select "select 123456"
and we wait for prompt
and we notee output
and we wait for prompt
then we see 123456 in tee output
Scenario: set delimiter
When we query "delimiter $"
then delimiter is set to "$"
Scenario: set delimiter twice
When we query "delimiter $"
and we query "delimiter ]]"
then delimiter is set to "]]"
Scenario: set delimiter and query on same line
When we query "select 123; delimiter $ select 456 $ delimiter %"
then we see result "123"
and we see result "456"
and delimiter is set to "%"
Scenario: send output to file
When we query "\o /tmp/output1.sql"
and we query "select 123"
and we query "system cat /tmp/output1.sql"
then we see result "123"
Scenario: send output to file two times
When we query "\o /tmp/output1.sql"
and we query "select 123"
and we query "\o /tmp/output2.sql"
and we query "select 456"
and we query "system cat /tmp/output2.sql"
then we see result "456"

View file

@ -0,0 +1,24 @@
Feature: named queries:
save, use and delete named queries
Scenario: save, use and delete named queries
When we connect to test database
then we see database connected
when we save a named query
then we see the named query saved
when we use a named query
then we see the named query executed
when we delete a named query
then we see the named query deleted
Scenario: save, use and delete named queries with parameters
When we connect to test database
then we see database connected
when we save a named query with parameters
then we see the named query saved
when we use named query with parameters
then we see the named query with parameters executed
when we use named query with too few parameters
then we see the named query with parameters fail with missing parameters
when we use named query with too many parameters
then we see the named query with parameters fail with extra parameters

View file

@ -0,0 +1,7 @@
Feature: Special commands
@wip
Scenario: run refresh command
When we refresh completions
and we wait for prompt
then we see completions refresh started

View file

View file

@ -0,0 +1,45 @@
from textwrap import dedent
from behave import then, when
import wrappers
@when('we run dbcli with {arg}')
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=arg.split('='))
@when('we execute a small query')
def step_execute_small_query(context):
context.cli.sendline('select 1')
@when('we execute a large query')
def step_execute_large_query(context):
context.cli.sendline(
'select {}'.format(','.join([str(n) for n in range(1, 50)])))
@then('we see small results in horizontal format')
def step_see_small_results(context):
wrappers.expect_pager(context, dedent("""\
+---+\r
| 1 |\r
+---+\r
| 1 |\r
+---+\r
\r
"""), timeout=5)
wrappers.expect_exact(context, '1 row in set', timeout=2)
@then('we see large results in vertical format')
def step_see_large_results(context):
rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
expected = ('***************************[ 1. row ]'
'***************************\r\n' +
'{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
wrappers.expect_pager(context, expected, timeout=10)
wrappers.expect_exact(context, '1 row in set', timeout=2)

View file

@ -0,0 +1,100 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
from behave import when
from textwrap import dedent
import tempfile
import wrappers
@when('we run dbcli')
def step_run_cli(context):
wrappers.run_cli(context)
@when('we wait for prompt')
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@when('we send "ctrl + d"')
def step_ctrl_d(context):
"""Send Ctrl + D to hopefully exit."""
context.cli.sendcontrol('d')
context.exit_sent = True
@when('we send "\?" command')
def step_send_help(context):
"""Send \?
to see help.
"""
context.cli.sendline('\\?')
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
@when(u'we send source command')
def step_send_source_command(context):
with tempfile.NamedTemporaryFile() as f:
f.write(b'\?')
f.flush()
context.cli.sendline('\. {0}'.format(f.name))
wrappers.expect_exact(
context, context.conf['pager_boundary'] + '\r\n', timeout=5)
@when(u'we run query to check application_name')
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
)
@then(u'we see found')
def step_see_found(context):
wrappers.expect_exact(
context,
context.conf['pager_boundary'] + '\r' + dedent('''
+-------+\r
| found |\r
+-------+\r
| found |\r
+-------+\r
\r
''') + context.conf['pager_boundary'],
timeout=5
)
@then(u'we confirm the destructive warning')
def step_confirm_destructive_command(context):
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline('y')
@when(u'we answer the destructive warning with "{confirmation}"')
def step_confirm_destructive_command(context, confirmation):
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline(confirmation)
@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
def step_confirm_destructive_command(context, confirmation, text):
"""Confirm destructive command."""
wrappers.expect_exact(
context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
context.cli.sendline(confirmation)
wrappers.expect_exact(context, text, timeout=2)
# we must exit the Click loop, or the feature will hang
context.cli.sendline('n')

View file

@ -0,0 +1,115 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
import pexpect
import wrappers
from behave import when, then
@when('we create database')
def step_db_create(context):
"""Send create database."""
context.cli.sendline('create database {0};'.format(
context.conf['dbname_tmp']))
context.response = {
'database_name': context.conf['dbname_tmp']
}
@when('we drop database')
def step_db_drop(context):
"""Send drop database."""
context.cli.sendline('drop database {0};'.format(
context.conf['dbname_tmp']))
@when('we connect to test database')
def step_db_connect_test(context):
"""Send connect to database."""
db_name = context.conf['dbname']
context.currentdb = db_name
context.cli.sendline('use {0};'.format(db_name))
@when('we connect to quoted test database')
def step_db_connect_quoted_tmp(context):
"""Send connect to database."""
db_name = context.conf['dbname']
context.currentdb = db_name
context.cli.sendline('use `{0}`;'.format(db_name))
@when('we connect to tmp database')
def step_db_connect_tmp(context):
"""Send connect to database."""
db_name = context.conf['dbname_tmp']
context.currentdb = db_name
context.cli.sendline('use {0}'.format(db_name))
@when('we connect to dbserver')
def step_db_connect_dbserver(context):
"""Send connect to database."""
context.currentdb = 'mysql'
context.cli.sendline('use mysql')
@then('dbcli exits')
def step_wait_exit(context):
"""Make sure the cli exits."""
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
@then('we see dbcli prompt')
def step_see_prompt(context):
"""Wait to see the prompt."""
user = context.conf['user']
host = context.conf['host']
dbname = context.currentdb
wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
@then('we see help output')
def step_see_help(context):
for expected_line in context.fixture_data['help_commands.txt']:
wrappers.expect_exact(context, expected_line, timeout=1)
@then('we see database created')
def step_see_db_created(context):
"""Wait to see create database output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
@then('we see database dropped')
def step_see_db_dropped(context):
"""Wait to see drop database output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
@then('we see database dropped and no default database')
def step_see_db_dropped_no_default(context):
"""Wait to see drop database output."""
user = context.conf['user']
host = context.conf['host']
database = '(none)'
context.currentdb = None
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
@then('we see database connected')
def step_see_db_connected(context):
"""Wait to see drop database output."""
wrappers.expect_exact(
context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, '"', timeout=2)
wrappers.expect_exact(context, ' as user "{0}"'.format(
context.conf['user']), timeout=2)

View file

@ -0,0 +1,112 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
import wrappers
from behave import when, then
from textwrap import dedent
@when('we create table')
def step_create_table(context):
"""Send create table."""
context.cli.sendline('create table a(x text);')
@when('we insert into table')
def step_insert_into_table(context):
"""Send insert into table."""
context.cli.sendline('''insert into a(x) values('xxx');''')
@when('we update table')
def step_update_table(context):
"""Send insert into table."""
context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
@when('we select from table')
def step_select_from_table(context):
"""Send select from table."""
context.cli.sendline('select * from a;')
@when('we delete from table')
def step_delete_from_table(context):
"""Send deete from table."""
context.cli.sendline('''delete from a where x = 'yyy';''')
@when('we drop table')
def step_drop_table(context):
"""Send drop table."""
context.cli.sendline('drop table a;')
@then('we see table created')
def step_see_table_created(context):
"""Wait to see create table output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
@then('we see record inserted')
def step_see_record_inserted(context):
"""Wait to see insert output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
@then('we see record updated')
def step_see_record_updated(context):
"""Wait to see update output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
@then('we see data selected')
def step_see_data_selected(context):
"""Wait to see select output."""
wrappers.expect_pager(
context, dedent("""\
+-----+\r
| x |\r
+-----+\r
| yyy |\r
+-----+\r
\r
"""), timeout=2)
wrappers.expect_exact(context, '1 row in set', timeout=2)
@then('we see record deleted')
def step_see_data_deleted(context):
"""Wait to see delete output."""
wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
@then('we see table dropped')
def step_see_table_dropped(context):
"""Wait to see drop output."""
wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
@when('we select null')
def step_select_null(context):
"""Send select null."""
context.cli.sendline('select null;')
@then('we see null selected')
def step_see_null_selected(context):
"""Wait to see null output."""
wrappers.expect_pager(
context, dedent("""\
+--------+\r
| NULL |\r
+--------+\r
| <null> |\r
+--------+\r
\r
"""), timeout=2)
wrappers.expect_exact(context, '1 row in set', timeout=2)

View file

@ -0,0 +1,105 @@
import os
import wrappers
from behave import when, then
from textwrap import dedent
@when('we start external editor providing a file name')
def step_edit_file(context):
"""Edit file with external editor."""
context.editor_file_name = os.path.join(
context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline('\e {0}'.format(
os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
wrappers.expect_exact(context, '\r\n:', timeout=2)
@when('we type "{query}" in the editor')
def step_edit_type_sql(context, query):
context.cli.sendline('i')
context.cli.sendline(query)
context.cli.sendline('.')
wrappers.expect_exact(context, '\r\n:', timeout=2)
@when('we exit the editor')
def step_edit_quit(context):
context.cli.sendline('x')
wrappers.expect_exact(context, "written", timeout=2)
@then('we see "{query}" in prompt')
def step_edit_done_sql(context, query):
for match in query.split(' '):
wrappers.expect_exact(context, match, timeout=5)
# Cleanup the command line.
context.cli.sendcontrol('c')
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
@when(u'we tee output')
def step_tee_ouptut(context):
context.tee_file_name = os.path.join(
context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline('tee {0}'.format(
os.path.basename(context.tee_file_name)))
@when(u'we select "select {param}"')
def step_query_select_number(context, param):
context.cli.sendline(u'select {}'.format(param))
wrappers.expect_pager(context, dedent(u"""\
+{dashes}+\r
| {param} |\r
+{dashes}+\r
| {param} |\r
+{dashes}+\r
\r
""".format(param=param, dashes='-' * (len(param) + 2))
), timeout=5)
wrappers.expect_exact(context, '1 row in set', timeout=2)
@then(u'we see result "{result}"')
def step_see_result(context, result):
wrappers.expect_exact(
context,
u"| {} |".format(result),
timeout=2
)
@when(u'we query "{query}"')
def step_query(context, query):
context.cli.sendline(query)
@when(u'we notee output')
def step_notee_output(context):
context.cli.sendline('notee')
@then(u'we see 123456 in tee output')
def step_see_123456_in_ouput(context):
with open(context.tee_file_name) as f:
assert '123456' in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
@then(u'delimiter is set to "{delimiter}"')
def delimiter_is_set(context, delimiter):
wrappers.expect_exact(
context,
u'Changed delimiter to {}'.format(delimiter),
timeout=2
)

View file

@ -0,0 +1,90 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
import wrappers
from behave import when, then
@when('we save a named query')
def step_save_named_query(context):
"""Send \fs command."""
context.cli.sendline('\\fs foo SELECT 12345')
@when('we use a named query')
def step_use_named_query(context):
"""Send \f command."""
context.cli.sendline('\\f foo')
@when('we delete a named query')
def step_delete_named_query(context):
"""Send \fd command."""
context.cli.sendline('\\fd foo')
@then('we see the named query saved')
def step_see_named_query_saved(context):
"""Wait to see query saved."""
wrappers.expect_exact(context, 'Saved.', timeout=2)
@then('we see the named query executed')
def step_see_named_query_executed(context):
"""Wait to see select output."""
wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
@then('we see the named query deleted')
def step_see_named_query_deleted(context):
"""Wait to see query deleted."""
wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
@when('we save a named query with parameters')
def step_save_named_query_with_parameters(context):
"""Send \fs command for query with parameters."""
context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
@when('we use named query with parameters')
def step_use_named_query_with_parameters(context):
"""Send \f command with parameters."""
context.cli.sendline('\\f foo_args 101 second "third value"')
@then('we see the named query with parameters executed')
def step_see_named_query_with_parameters_executed(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'SELECT 101, "second", "third value"', timeout=2)
@when('we use named query with too few parameters')
def step_use_named_query_with_too_few_parameters(context):
"""Send \f command with missing parameters."""
context.cli.sendline('\\f foo_args 101')
@then('we see the named query with parameters fail with missing parameters')
def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'missing substitution for $2 in query:', timeout=2)
@when('we use named query with too many parameters')
def step_use_named_query_with_too_many_parameters(context):
"""Send \f command with extra parameters."""
context.cli.sendline('\\f foo_args 101 102 103 104')
@then('we see the named query with parameters fail with extra parameters')
def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
"""Wait to see select output."""
wrappers.expect_exact(
context, 'query does not have substitution parameter $4:', timeout=2)

View file

@ -0,0 +1,27 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
import wrappers
from behave import when, then
@when('we refresh completions')
def step_refresh_completions(context):
"""Send refresh command."""
context.cli.sendline('rehash')
@then('we see text "{text}"')
def step_see_text(context, text):
"""Wait to see given text message."""
wrappers.expect_exact(context, text, timeout=2)
@then('we see completions refresh started')
def step_see_refresh_started(context):
"""Wait to see refresh output."""
wrappers.expect_exact(
context, 'Auto-completion refresh started in the background.', timeout=2)

View file

@ -0,0 +1,94 @@
import re
import pexpect
import sys
import textwrap
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def expect_exact(context, expected, timeout):
timedout = False
try:
context.cli.expect_exact(expected, timeout=timeout)
except pexpect.exceptions.TIMEOUT:
timedout = True
if timedout:
# Strip color codes out of the output.
actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
'', context.cli.before)
raise Exception(
textwrap.dedent('''\
Expected:
---
{0!r}
---
Actual:
---
{1!r}
---
Full log:
---
{2!r}
---
''').format(
expected,
actual,
context.logfile.getvalue()
)
)
def expect_pager(context, expected, timeout):
expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
context.conf['pager_boundary'], expected), timeout=timeout)
def run_cli(context, run_args=None):
"""Run the process using pexpect."""
run_args = run_args or []
if context.conf.get('host', None):
run_args.extend(('-h', context.conf['host']))
if context.conf.get('user', None):
run_args.extend(('-u', context.conf['user']))
if context.conf.get('pass', None):
run_args.extend(('-p', context.conf['pass']))
if context.conf.get('dbname', None):
run_args.extend(('-D', context.conf['dbname']))
if context.conf.get('defaults-file', None):
run_args.extend(('--defaults-file', context.conf['defaults-file']))
if context.conf.get('myclirc', None):
run_args.extend(('--myclirc', context.conf['myclirc']))
try:
cli_cmd = context.conf['cli_command']
except KeyError:
cli_cmd = (
'{0!s} -c "'
'import coverage ; '
'coverage.process_startup(); '
'import mycli.main; '
'mycli.main.cli()'
'"'
).format(sys.executable)
cmd_parts = [cli_cmd] + run_args
cmd = ' '.join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = context.conf['dbname']
def wait_prompt(context, prompt=None):
"""Make sure prompt is displayed."""
if prompt is None:
user = context.conf['user']
host = context.conf['host']
dbname = context.currentdb
prompt = '{0}@{1}:{2}>'.format(
user, host, dbname),
expect_exact(context, prompt, timeout=5)
context.atprompt = True

16
test/features/wrappager.py Executable file
View file

@ -0,0 +1,16 @@
#!/usr/bin/env python
import sys
def wrappager(boundary):
print(boundary)
while 1:
buf = sys.stdin.read(2048)
if not buf:
break
sys.stdout.write(buf)
print(boundary)
if __name__ == "__main__":
wrappager(sys.argv[1])

12
test/myclirc Normal file
View file

@ -0,0 +1,12 @@
# vi: ft=dosini
# This file is loaded after mycli/myclirc and should override only those
# variables needed for testing.
# To see what every variable does see mycli/myclirc
[main]
log_file = ~/.mycli.test.log
log_level = DEBUG
prompt = '\t \u@\h:\d> '
less_chatty = True

BIN
test/mylogin.cnf Normal file

Binary file not shown.

1
test/test.txt Normal file
View file

@ -0,0 +1 @@
mycli rocks!

27
test/test_clistyle.py Normal file
View file

@ -0,0 +1,27 @@
"""Test the mycli.clistyle module."""
import pytest
from pygments.style import Style
from pygments.token import Token
from mycli.clistyle import style_factory
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory():
"""Test that a Pygments Style class is created."""
header = 'bold underline #ansired'
cli_style = {'Token.Output.Header': header}
style = style_factory('default', cli_style)
assert isinstance(style(), Style)
assert Token.Output.Header in style.styles
assert header == style.styles[Token.Output.Header]
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory_unknown_name():
"""Test that an unrecognized name will not throw an error."""
style = style_factory('foobar', {})
assert isinstance(style(), Style)

View file

@ -0,0 +1,537 @@
from mycli.packages.completion_engine import suggest_type
import pytest
def sorted_dicts(dicts):
"""input is a list of dicts."""
return sorted(tuple(x.items()) for x in dicts)
def test_select_suggests_cols_with_visible_table_scope():
suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
def test_select_suggests_cols_with_qualified_table_scope():
suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [('sch', 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE (',
'SELECT * FROM tabl WHERE foo = ',
'SELECT * FROM tabl WHERE bar OR ',
'SELECT * FROM tabl WHERE foo = 1 AND ',
'SELECT * FROM tabl WHERE (bar > 10 AND ',
'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
'SELECT * FROM tabl WHERE 10 < ',
'SELECT * FROM tabl WHERE foo BETWEEN ',
'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
])
def test_where_suggests_columns_functions(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
@pytest.mark.parametrize('expression', [
'SELECT * FROM tabl WHERE foo IN (',
'SELECT * FROM tabl WHERE foo IN (bar, ',
])
def test_where_in_suggests_columns(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
def test_where_equals_any_suggests_columns_or_keywords():
text = 'SELECT * FROM tabl WHERE foo = ANY('
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'}])
def test_lparen_suggests_cols():
suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
def test_operand_inside_function_suggests_cols1():
suggestion = suggest_type(
'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
def test_operand_inside_function_suggests_cols2():
suggestion = suggest_type(
'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type('SELECT ', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': []},
{'type': 'column', 'tables': []},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
@pytest.mark.parametrize('expression', [
'SELECT * FROM ',
'INSERT INTO ',
'COPY ',
'UPDATE ',
'DESCRIBE ',
'DESC ',
'EXPLAIN ',
'SELECT * FROM foo JOIN ',
])
def test_expression_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
@pytest.mark.parametrize('expression', [
'SELECT * FROM sch.',
'INSERT INTO sch.',
'COPY sch.',
'UPDATE sch.',
'DESCRIBE sch.',
'DESC sch.',
'EXPLAIN sch.',
'SELECT * FROM foo JOIN sch.',
])
def test_expression_suggests_qualified_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 'sch'},
{'type': 'view', 'schema': 'sch'}])
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 'sch'}])
def test_distinct_suggests_cols():
suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
assert suggestions == [{'type': 'column', 'tables': []}]
def test_col_comma_suggests_cols():
suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tbl']},
{'type': 'column', 'tables': [(None, 'tbl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type('SELECT a, b FROM tbl1, ',
'SELECT a, b FROM tbl1, ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
def test_insert_into_lparen_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
def test_insert_into_lparen_partial_text_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
def test_insert_into_lparen_comma_suggests_cols():
suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
def test_partially_typed_col_name_suggests_col_names():
suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
'SELECT * FROM tabl WHERE col_n')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['tabl']},
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'table', 'schema': 'tabl'},
{'type': 'view', 'schema': 'tabl'},
{'type': 'function', 'schema': 'tabl'}])
def test_dot_suggests_cols_of_an_alias():
suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
'SELECT t1.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 't1'},
{'type': 'view', 'schema': 't1'},
{'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
{'type': 'function', 'schema': 't1'}])
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
'SELECT t1.a, t2.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
{'type': 'table', 'schema': 't2'},
{'type': 'view', 'schema': 't2'},
{'type': 'function', 'schema': 't2'}])
@pytest.mark.parametrize('expression', [
'SELECT * FROM (',
'SELECT * FROM foo WHERE EXISTS (',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
'SELECT 1 AS',
])
def test_sub_select_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == [{'type': 'keyword'}]
@pytest.mark.parametrize('expression', [
'SELECT * FROM (S',
'SELECT * FROM foo WHERE EXISTS (S',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
])
def test_sub_select_partial_text_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == [{'type': 'keyword'}]
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
suggestions = suggest_type(q, q)
assert suggestions == [
{'type': 'column', 'tables': [(None, 'foo', 'f')]},
{'type': 'table', 'schema': 'f'},
{'type': 'view', 'schema': 'f'},
{'type': 'function', 'schema': 'f'}]
@pytest.mark.parametrize('expression', [
'SELECT * FROM (SELECT * FROM ',
'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
])
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
def test_sub_select_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
'SELECT * FROM (SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['abc']},
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
@pytest.mark.xfail
def test_sub_select_multiple_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
'SELECT * FROM (SELECT a, ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function', 'schema': []}])
def test_sub_select_dot_col_name_completion():
suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
'SELECT * FROM (SELECT t.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', 't')]},
{'type': 'table', 'schema': 't'},
{'type': 'view', 'schema': 't'},
{'type': 'function', 'schema': 't'}])
@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
suggestion = suggest_type(text, text)
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
])
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', 'a')]},
{'type': 'table', 'schema': 'a'},
{'type': 'view', 'schema': 'a'},
{'type': 'function', 'schema': 'a'}])
@pytest.mark.parametrize('sql', [
'SELECT * FROM abc a JOIN def d ON a.id = d.',
'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
])
def test_join_alias_dot_suggests_cols2(sql):
suggestions = suggest_type(sql, sql)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'def', 'd')]},
{'type': 'table', 'schema': 'd'},
{'type': 'view', 'schema': 'd'},
{'type': 'function', 'schema': 'd'}])
@pytest.mark.parametrize('sql', [
'select a.x, b.y from abc a join bcd b on ',
'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
])
def test_on_suggests_aliases(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
@pytest.mark.parametrize('sql', [
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
])
def test_on_suggests_tables(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
@pytest.mark.parametrize('sql', [
'select a.x, b.y from abc a join bcd b on a.id = ',
'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
])
def test_on_suggests_aliases_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
@pytest.mark.parametrize('sql', [
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
])
def test_on_suggests_tables_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
@pytest.mark.parametrize('col_list', ['', 'col1, '])
def test_join_using_suggests_common_columns(col_list):
text = 'select * from abc inner join def using (' + col_list
assert suggest_type(text, text) == [
{'type': 'column',
'tables': [(None, 'abc', None), (None, 'def', None)],
'drop_unique': True}]
def test_2_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ',
'select * from a; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type('select * from a; select from b',
'select * from a; select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['b']},
{'type': 'column', 'tables': [(None, 'b', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
# Should work even if first statement is invalid
suggestions = suggest_type('select * from; select * from ',
'select * from; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
def test_2_statements_1st_current():
suggestions = suggest_type('select * from ; select * from b',
'select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type('select from a; select * from b',
'select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['a']},
{'type': 'column', 'tables': [(None, 'a', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
def test_3_statements_2nd_current():
suggestions = suggest_type('select * from a; select * from ; select * from c',
'select * from a; select * from ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
suggestions = suggest_type('select * from a; select from b; select * from c',
'select * from a; select ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'alias', 'aliases': ['b']},
{'type': 'column', 'tables': [(None, 'b', None)]},
{'type': 'function', 'schema': []},
{'type': 'keyword'},
])
def test_create_db_with_template():
suggestions = suggest_type('create database foo with template ',
'create database foo with template ')
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
assert sorted_dicts(suggestions) == \
sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
def test_specials_not_included_after_initial_token():
suggestions = suggest_type('create table foo (dt d',
'create table foo (dt d')
assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
def test_drop_schema_qualified_table_suggests_only_tables():
text = 'DROP TABLE schema_name.table_name'
suggestions = suggest_type(text, text)
assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
def test_handle_pre_completion_comma_gracefully(text):
suggestions = suggest_type(text, text)
assert iter(suggestions)
def test_cross_join():
text = 'select * from v1 cross join v2 JOIN v1.id, '
suggestions = suggest_type(text, text)
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
@pytest.mark.parametrize('expression', [
'SELECT 1 AS ',
'SELECT 1 FROM tabl AS ',
])
def test_after_as(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set()
@pytest.mark.parametrize('expression', [
'\\. ',
'select 1; \\. ',
'select 1;\\. ',
'select 1 ; \\. ',
'source ',
'truncate table test; source ',
'truncate table test ; source ',
'truncate table test;source ',
])
def test_source_is_file(expression):
suggestions = suggest_type(expression, expression)
assert suggestions == [{'type': 'file_name'}]
@pytest.mark.parametrize("expression", [
"\\f ",
])
def test_favorite_name_suggestion(expression):
suggestions = suggest_type(expression, expression)
assert suggestions == [{'type': 'favoritequery'}]
def test_order_by():
text = 'select * from foo order by '
suggestions = suggest_type(text, text)
assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]

View file

@ -0,0 +1,88 @@
import time
import pytest
from mock import Mock, patch
@pytest.fixture
def refresher():
from mycli.completion_refresher import CompletionRefresher
return CompletionRefresher()
def test_ctor(refresher):
"""Refresher object should contain a few handlers.
:param refresher:
:return:
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
'special_commands', 'show_commands']
assert expected_handlers == actual_handlers
def test_refresh_called_once(refresher):
"""
:param refresher:
:return:
"""
callbacks = Mock()
sqlexecute = Mock()
with patch.object(refresher, '_bg_refresh') as bg_refresh:
actual = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
assert actual[0][3] == 'Auto-completion refresh started in the background.'
bg_refresh.assert_called_with(sqlexecute, callbacks, {})
def test_refresh_called_twice(refresher):
"""If refresh is called a second time, it should be restarted.
:param refresher:
:return:
"""
callbacks = Mock()
sqlexecute = Mock()
def dummy_bg_refresh(*args):
time.sleep(3) # seconds
refresher._bg_refresh = dummy_bg_refresh
actual1 = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
assert actual1[0][3] == 'Auto-completion refresh started in the background.'
actual2 = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
assert actual2[0][3] == 'Auto-completion refresh restarted.'
def test_refresh_with_callbacks(refresher):
"""Callbacks must be called.
:param refresher:
"""
callbacks = [Mock()]
sqlexecute_class = Mock()
sqlexecute = Mock()
with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert (callbacks[0].call_count == 1)

196
test/test_config.py Normal file
View file

@ -0,0 +1,196 @@
"""Unit tests for the mycli.config module."""
from io import BytesIO, StringIO, TextIOWrapper
import os
import struct
import sys
import tempfile
import pytest
from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
read_and_decrypt_mylogin_cnf, read_config_file,
str_to_bool, strip_matching_quotes)
LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
'mylogin.cnf'))
def open_bmylogin_cnf(name):
"""Open contents of *name* in a BytesIO buffer."""
with open(name, 'rb') as f:
buf = BytesIO()
buf.write(f.read())
return buf
def test_read_mylogin_cnf():
"""Tests that a login path file can be read and decrypted."""
mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
assert isinstance(mylogin_cnf, TextIOWrapper)
contents = mylogin_cnf.read()
for word in ('[test]', 'user', 'password', 'host', 'port'):
assert word in contents
def test_decrypt_blank_mylogin_cnf():
"""Test that a blank login path file is handled correctly."""
mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO())
assert mylogin_cnf is None
def test_corrupted_login_key():
"""Test that a corrupted login path key is handled correctly."""
buf = open_bmylogin_cnf(LOGIN_PATH_FILE)
# Skip past the unused bytes
buf.seek(4)
# Write null bytes over half the login key
buf.write(b'\0\0\0\0\0\0\0\0\0\0')
buf.seek(0)
mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
assert mylogin_cnf is None
def test_corrupted_pad():
"""Tests that a login path file with a corrupted pad is partially read."""
buf = open_bmylogin_cnf(LOGIN_PATH_FILE)
# Skip past the login key
buf.seek(24)
# Skip option group
len_buf = buf.read(4)
cipher_len, = struct.unpack("<i", len_buf)
buf.read(cipher_len)
# Corrupt the pad for the user line
len_buf = buf.read(4)
cipher_len, = struct.unpack("<i", len_buf)
buf.read(cipher_len - 1)
buf.write(b'\0')
buf.seek(0)
mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
contents = mylogin_cnf.read()
for word in ('[test]', 'password', 'host', 'port'):
assert word in contents
assert 'user' not in contents
def test_get_mylogin_cnf_path():
"""Tests that the path for .mylogin.cnf is detected."""
original_env = None
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
is_windows = sys.platform == 'win32'
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
if login_cnf_path is not None:
assert login_cnf_path.endswith('.mylogin.cnf')
if is_windows is True:
assert 'MySQL' in login_cnf_path
else:
home_dir = os.path.expanduser('~')
assert login_cnf_path.startswith(home_dir)
def test_alternate_get_mylogin_cnf_path():
"""Tests that the alternate path for .mylogin.cnf is detected."""
original_env = None
if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
_, temp_path = tempfile.mkstemp()
os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
assert temp_path == login_cnf_path
def test_str_to_bool():
"""Tests that str_to_bool function converts values correctly."""
assert str_to_bool(False) is False
assert str_to_bool(True) is True
assert str_to_bool('False') is False
assert str_to_bool('True') is True
assert str_to_bool('TRUE') is True
assert str_to_bool('1') is True
assert str_to_bool('0') is False
assert str_to_bool('on') is True
assert str_to_bool('off') is False
assert str_to_bool('off') is False
with pytest.raises(ValueError):
str_to_bool('foo')
with pytest.raises(TypeError):
str_to_bool(None)
def test_read_config_file_list_values_default():
"""Test that reading a config file uses list_values by default."""
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f)
assert config['main']['weather'] == u"cloudy with a chance of meatballs"
def test_read_config_file_list_values_off():
"""Test that you can disable list_values when reading a config file."""
f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f, list_values=False)
assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
def test_strip_quotes_with_matching_quotes():
"""Test that a string with matching quotes is unquoted."""
s = "May the force be with you."
assert s == strip_matching_quotes('"{}"'.format(s))
assert s == strip_matching_quotes("'{}'".format(s))
def test_strip_quotes_with_unmatching_quotes():
"""Test that a string with unmatching quotes is not unquoted."""
s = "May the force be with you."
assert '"' + s == strip_matching_quotes('"{}'.format(s))
assert s + "'" == strip_matching_quotes("{}'".format(s))
def test_strip_quotes_with_empty_string():
"""Test that an empty string is handled during unquoting."""
assert '' == strip_matching_quotes('')
def test_strip_quotes_with_none():
"""Test that None is handled during unquoting."""
assert None is strip_matching_quotes(None)
def test_strip_quotes_with_quotes():
"""Test that strings with quotes in them are handled during unquoting."""
s1 = 'Darth Vader said, "Luke, I am your father."'
assert s1 == strip_matching_quotes(s1)
s2 = '"Darth Vader said, "Luke, I am your father.""'
assert s2[1:-1] == strip_matching_quotes(s2)

42
test/test_dbspecial.py Normal file
View file

@ -0,0 +1,42 @@
from mycli.packages.completion_engine import suggest_type
from .test_completion_engine import sorted_dicts
from mycli.packages.special.utils import format_uptime
def test_u_suggests_databases():
suggestions = suggest_type('\\u ', '\\u ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'database'}])
def test_describe_table():
suggestions = suggest_type('\\dt', '\\dt ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
def test_list_or_show_create_tables():
suggestions = suggest_type('\\dt+', '\\dt+ ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'view', 'schema': []},
{'type': 'schema'}])
def test_format_uptime():
seconds = 59
assert '59 sec' == format_uptime(seconds)
seconds = 120
assert '2 min 0 sec' == format_uptime(seconds)
seconds = 54890
assert '15 hours 14 min 50 sec' == format_uptime(seconds)
seconds = 598244
assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
seconds = 522600
assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)

528
test/test_main.py Normal file
View file

@ -0,0 +1,528 @@
import os
import click
from click.testing import CliRunner
from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
from textwrap import dedent
from collections import namedtuple
from tempfile import NamedTemporaryFile
from textwrap import dedent
test_dir = os.path.abspath(os.path.dirname(__file__))
project_dir = os.path.dirname(test_dir)
default_config_file = os.path.join(project_dir, 'test', 'myclirc')
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
'--password', PASSWORD, '--myclirc', default_config_file,
'--defaults-file', default_config_file,
'_test_db']
@dbtest
def test_execute_arg(executor):
run(executor, 'create table test (a text)')
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
assert result.exit_code == 0
assert 'abc' in result.output
result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
assert result.exit_code == 0
assert 'abc' in result.output
expected = 'a\nabc\n'
assert expected in result.output
@dbtest
def test_execute_arg_with_table(executor):
run(executor, 'create table test (a text)')
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
assert result.exit_code == 0
assert expected in result.output
@dbtest
def test_execute_arg_with_csv(executor):
run(executor, 'create table test (a text)')
run(executor, 'insert into test values("abc")')
sql = 'select * from test;'
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
expected = '"a"\n"abc"\n'
assert result.exit_code == 0
assert expected in "".join(result.output)
@dbtest
def test_batch_mode(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
sql = (
'select count(*) from test;\n'
'select * from test limit 1;'
)
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
assert result.exit_code == 0
assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
@dbtest
def test_batch_mode_table(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
sql = (
'select count(*) from test;\n'
'select * from test limit 1;'
)
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
expected = (dedent("""\
+----------+
| count(*) |
+----------+
| 3 |
+----------+
+-----+
| a |
+-----+
| abc |
+-----+"""))
assert result.exit_code == 0
assert expected in result.output
@dbtest
def test_batch_mode_csv(executor):
run(executor, '''create table test(a text, b text)''')
run(executor,
'''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
sql = 'select * from test;'
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
assert result.exit_code == 0
assert expected in "".join(result.output)
def test_thanks_picker_utf8():
author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
name = thanks_picker((author_file, sponsor_file))
assert name and isinstance(name, str)
def test_help_strings_end_with_periods():
"""Make sure click options have help text that end with a period."""
for param in cli.params:
if isinstance(param, click.core.Option):
assert hasattr(param, 'help')
assert param.help.endswith('.')
def test_command_descriptions_end_with_periods():
"""Make sure that mycli commands' descriptions end with a period."""
MyCli()
for _, command in SPECIAL_COMMANDS.items():
assert command[3].endswith('.')
def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
global clickoutput
clickoutput = ""
m = MyCli(myclirc=default_config_file)
class TestOutput():
def get_size(self):
size = namedtuple('Size', 'rows columns')
size.columns, size.rows = terminal_size
return size
class TestExecute():
host = 'test'
user = 'test'
dbname = 'test'
port = 0
def server_type(self):
return ['test']
class PromptBuffer():
output = TestOutput()
m.prompt_app = PromptBuffer()
m.sqlexecute = TestExecute()
m.explicit_pager = explicit_pager
def echo_via_pager(s):
assert expect_pager
global clickoutput
clickoutput += "".join(s)
def secho(s):
assert not expect_pager
global clickoutput
clickoutput += s + "\n"
monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
monkeypatch.setattr(click, 'secho', secho)
m.output(testdata)
if clickoutput.endswith("\n"):
clickoutput = clickoutput[:-1]
assert clickoutput == "\n".join(testdata)
def test_conditional_pager(monkeypatch):
testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
" ")
# User didn't set pager, output doesn't fit screen -> pager
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=False,
expect_pager=True
)
# User didn't set pager, output fits screen -> no pager
output(
monkeypatch,
terminal_size=(20, 20),
testdata=testdata,
explicit_pager=False,
expect_pager=False
)
# User manually configured pager, output doesn't fit screen -> pager
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=True,
expect_pager=True
)
# User manually configured pager, output fit screen -> pager
output(
monkeypatch,
terminal_size=(20, 20),
testdata=testdata,
explicit_pager=True,
expect_pager=True
)
SPECIAL_COMMANDS['nopager'].handler()
output(
monkeypatch,
terminal_size=(5, 10),
testdata=testdata,
explicit_pager=False,
expect_pager=False
)
SPECIAL_COMMANDS['pager'].handler('')
def test_reserved_space_is_integer():
"""Make sure that reserved space is returned as an integer."""
def stub_terminal_size():
return (5, 5)
old_func = click.get_terminal_size
click.get_terminal_size = stub_terminal_size
mycli = MyCli()
assert isinstance(mycli.get_reserved_space(), int)
click.get_terminal_size = old_func
def test_list_dsn():
runner = CliRunner()
with NamedTemporaryFile(mode="w") as myclirc:
myclirc.write(dedent("""\
[alias_dsn]
test = mysql://test/test
"""))
myclirc.flush()
args = ['--list-dsn', '--myclirc', myclirc.name]
result = runner.invoke(cli, args=args)
assert result.output == "test\n"
result = runner.invoke(cli, args=args + ['--verbose'])
assert result.output == "test : mysql://test/test\n"
def test_list_ssh_config():
runner = CliRunner()
with NamedTemporaryFile(mode="w") as ssh_config:
ssh_config.write(dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
"""))
ssh_config.flush()
args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
result = runner.invoke(cli, args=args)
assert "test\n" in result.output
result = runner.invoke(cli, args=args + ['--verbose'])
assert "test : test.example.com\n" in result.output
def test_dsn(monkeypatch):
# Setup classes to mock mycli.main.MyCli
class Formatter:
format_name = None
class Logger:
def debug(self, *args, **args_dict):
pass
def warning(self, *args, **args_dict):
pass
class MockMyCli:
config = {'alias_dsn': {}}
def __init__(self, **args):
self.logger = Logger()
self.destructive_warning = False
self.formatter = Formatter()
def connect(self, **args):
MockMyCli.connect_args = args
def run_query(self, query, new_line=True):
pass
import mycli.main
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
runner = CliRunner()
# When a user supplies a DSN as database argument to mycli,
# use these values.
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "dsn_user" and \
MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
MockMyCli.connect_args["host"] == "dsn_host" and \
MockMyCli.connect_args["port"] == 1 and \
MockMyCli.connect_args["database"] == "dsn_database"
MockMyCli.connect_args = None
# When a use supplies a DSN as database argument to mycli,
# and used command line arguments, use the command line
# arguments.
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
"--user", "arg_user",
"--password", "arg_password",
"--host", "arg_host",
"--port", "3",
"--database", "arg_database",
])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "arg_user" and \
MockMyCli.connect_args["passwd"] == "arg_password" and \
MockMyCli.connect_args["host"] == "arg_host" and \
MockMyCli.connect_args["port"] == 3 and \
MockMyCli.connect_args["database"] == "arg_database"
MockMyCli.config = {
'alias_dsn': {
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
}
}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn),
# use these values.
result = runner.invoke(cli, args=['--dsn', 'test'])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "alias_dsn_user" and \
MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
MockMyCli.connect_args["host"] == "alias_dsn_host" and \
MockMyCli.connect_args["port"] == 4 and \
MockMyCli.connect_args["database"] == "alias_dsn_database"
MockMyCli.config = {
'alias_dsn': {
'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
}
}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn)
# and used command line arguments, use the command line arguments.
result = runner.invoke(cli, args=[
'--dsn', 'test', '',
"--user", "arg_user",
"--password", "arg_password",
"--host", "arg_host",
"--port", "5",
"--database", "arg_database",
])
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "arg_user" and \
MockMyCli.connect_args["passwd"] == "arg_password" and \
MockMyCli.connect_args["host"] == "arg_host" and \
MockMyCli.connect_args["port"] == 5 and \
MockMyCli.connect_args["database"] == "arg_database"
# Use a DSN without password
result = runner.invoke(mycli.main.cli, args=[
"mysql://dsn_user@dsn_host:6/dsn_database"]
)
assert result.exit_code == 0, result.output + " " + str(result.exception)
assert \
MockMyCli.connect_args["user"] == "dsn_user" and \
MockMyCli.connect_args["passwd"] is None and \
MockMyCli.connect_args["host"] == "dsn_host" and \
MockMyCli.connect_args["port"] == 6 and \
MockMyCli.connect_args["database"] == "dsn_database"
def test_ssh_config(monkeypatch):
# Setup classes to mock mycli.main.MyCli
class Formatter:
format_name = None
class Logger:
def debug(self, *args, **args_dict):
pass
def warning(self, *args, **args_dict):
pass
class MockMyCli:
config = {'alias_dsn': {}}
def __init__(self, **args):
self.logger = Logger()
self.destructive_warning = False
self.formatter = Formatter()
def connect(self, **args):
MockMyCli.connect_args = args
def run_query(self, query, new_line=True):
pass
import mycli.main
monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
runner = CliRunner()
# Setup temporary configuration
with NamedTemporaryFile(mode="w") as ssh_config:
ssh_config.write(dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
"""))
ssh_config.flush()
# When a user supplies a ssh config.
result = runner.invoke(mycli.main.cli, args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test"
])
assert result.exit_code == 0, result.output + \
" " + str(result.exception)
assert \
MockMyCli.connect_args["ssh_user"] == "joe" and \
MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
MockMyCli.connect_args["ssh_port"] == 22222 and \
MockMyCli.connect_args["ssh_key_filename"] == os.getenv(
"HOME") + "/.ssh/gateway"
# When a user supplies a ssh config host as argument to mycli,
# and used command line arguments, use the command line
# arguments.
result = runner.invoke(mycli.main.cli, args=[
"--ssh-config-path",
ssh_config.name,
"--ssh-config-host",
"test",
"--ssh-user", "arg_user",
"--ssh-host", "arg_host",
"--ssh-port", "3",
"--ssh-key-filename", "/path/to/key"
])
assert result.exit_code == 0, result.output + \
" " + str(result.exception)
assert \
MockMyCli.connect_args["ssh_user"] == "arg_user" and \
MockMyCli.connect_args["ssh_host"] == "arg_host" and \
MockMyCli.connect_args["ssh_port"] == 3 and \
MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
@dbtest
def test_init_command_arg(executor):
init_command = "set sql_select_limit=1000"
sql = 'show variables like "sql_select_limit";'
runner = CliRunner()
result = runner.invoke(
cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
)
expected = "sql_select_limit\t1000\n"
assert result.exit_code == 0
assert expected in result.output
@dbtest
def test_init_command_multiple_arg(executor):
init_command = 'set sql_select_limit=2000; set max_join_size=20000'
sql = (
'show variables like "sql_select_limit";\n'
'show variables like "max_join_size"'
)
runner = CliRunner()
result = runner.invoke(
cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
)
expected_sql_select_limit = 'sql_select_limit\t2000\n'
expected_max_join_size = 'max_join_size\t20000\n'
assert result.exit_code == 0
assert expected_sql_select_limit in result.output
assert expected_max_join_size in result.output

View file

@ -0,0 +1,63 @@
import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
@pytest.fixture
def completer():
import mycli.sqlcompleter as sqlcompleter
return sqlcompleter.SQLCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_empty_string_completion(completer, complete_event):
text = ''
position = 0
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list(map(Completion, sorted(completer.all_completions)))
def test_select_keyword_completion(completer, complete_event):
text = 'SEL'
position = len('SEL')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([Completion(text='SELECT', start_position=-3)])
def test_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='MASTER', start_position=-2),
Completion(text='MAX', start_position=-2)])
def test_column_name_completion(completer, complete_event):
text = 'SELECT FROM users'
position = len('SELECT ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list(map(Completion, sorted(completer.all_completions)))
def test_special_name_completion(completer, complete_event):
text = '\\'
position = len('\\')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
# Special commands will NOT be suggested during naive completion mode.
assert result == set()

190
test/test_parseutils.py Normal file
View file

@ -0,0 +1,190 @@
import pytest
from mycli.packages.parseutils import (
extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
is_dropping_database)
def test_empty_string():
tables = extract_tables('')
assert tables == []
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
assert tables == [(None, 'abc', None)]
def test_simple_select_single_table_schema_qualified():
tables = extract_tables('select * from abc.def')
assert tables == [('abc', 'def', None)]
def test_simple_select_multiple_tables():
tables = extract_tables('select * from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables('select * from abc.def, ghi.jkl')
assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
def test_simple_select_with_cols_single_table():
tables = extract_tables('select a,b from abc')
assert tables == [(None, 'abc', None)]
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables('select a,b from abc.def')
assert tables == [('abc', 'def', None)]
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables('select a,b from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
def test_simple_select_with_cols_multiple_tables_with_schema():
tables = extract_tables('select a,b from abc.def, def.ghi')
assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
def test_select_with_hanging_comma_single_table():
tables = extract_tables('select a, from abc')
assert tables == [(None, 'abc', None)]
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables('select a, from abc, def')
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
def test_simple_insert_single_table():
tables = extract_tables('insert into abc (id, name) values (1, "def")')
# sqlparse mistakenly assigns an alias to the table
# assert tables == [(None, 'abc', None)]
assert tables == [(None, 'abc', 'abc')]
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == [('abc', 'def', None)]
def test_simple_update_table():
tables = extract_tables('update abc set id = 1')
assert tables == [(None, 'abc', None)]
def test_simple_update_table_with_schema():
tables = extract_tables('update abc.def set id = 1')
assert tables == [('abc', 'def', None)]
def test_join_table():
tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
def test_join_table_schema_qualified():
tables = extract_tables(
'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
def test_join_as_table():
tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
assert tables == [(None, 'my_table', 'm')]
def test_query_starts_with():
query = 'USE test;'
assert query_starts_with(query, ('use', )) is True
query = 'DROP DATABASE test;'
assert query_starts_with(query, ('use', )) is False
def test_query_starts_with_comment():
query = '# comment\nUSE test;'
assert query_starts_with(query, ('use', )) is True
def test_queries_start_with():
sql = (
'# comment\n'
'show databases;'
'use foo;'
)
assert queries_start_with(sql, ('show', 'select')) is True
assert queries_start_with(sql, ('use', 'drop')) is True
assert queries_start_with(sql, ('delete', 'update')) is False
def test_is_destructive():
sql = (
'use test;\n'
'show databases;\n'
'drop database foo;'
)
assert is_destructive(sql) is True
def test_is_destructive_update_with_where_clause():
sql = (
'use test;\n'
'show databases;\n'
'UPDATE test SET x = 1 WHERE id = 1;'
)
assert is_destructive(sql) is False
def test_is_destructive_update_without_where_clause():
sql = (
'use test;\n'
'show databases;\n'
'UPDATE test SET x = 1;'
)
assert is_destructive(sql) is True
@pytest.mark.parametrize(
('sql', 'has_where_clause'),
[
('update test set dummy = 1;', False),
('update test set dummy = 1 where id = 1);', True),
],
)
def test_query_has_where_clause(sql, has_where_clause):
assert query_has_where_clause(sql) is has_where_clause
@pytest.mark.parametrize(
('sql', 'dbname', 'is_dropping'),
[
('select bar from foo', 'foo', False),
('drop database "foo";', '`foo`', True),
('drop schema foo', 'foo', True),
('drop schema foo', 'bar', False),
('drop database bar', 'foo', False),
('drop database foo', None, False),
('drop database foo; create database foo', 'foo', False),
('drop database foo; create database bar', 'foo', True),
('select bar from foo; drop database bazz', 'foo', False),
('select bar from foo; drop database bazz', 'bazz', True),
('-- dropping database \n '
'drop -- really dropping \n '
'schema abc -- now it is dropped',
'abc',
True)
]
)
def test_is_dropping_database(sql, dbname, is_dropping):
assert is_dropping_database(sql, dbname) == is_dropping

38
test/test_plan.wiki Normal file
View file

@ -0,0 +1,38 @@
= Gross Checks =
* [ ] Check connecting to a local database.
* [ ] Check connecting to a remote database.
* [ ] Check connecting to a database with a user/password.
* [ ] Check connecting to a non-existent database.
* [ ] Test changing the database.
== PGExecute ==
* [ ] Test successful execution given a cursor.
* [ ] Test unsuccessful execution with a syntax error.
* [ ] Test a series of executions with the same cursor without failure.
* [ ] Test a series of executions with the same cursor with failure.
* [ ] Test passing in a special command.
== Naive Autocompletion ==
* [ ] Input empty string, ask for completions - Everything.
* [ ] Input partial prefix, ask for completions - Stars with prefix.
* [ ] Input fully autocompleted string, ask for completions - Only full match
* [ ] Input non-existent prefix, ask for completions - nothing
* [ ] Input lowercase prefix - case insensitive completions
== Smart Autocompletion ==
* [ ] Input empty string and check if only keywords are returned.
* [ ] Input SELECT prefix and check if only columns and '*' are returned.
* [ ] Input SELECT blah - only keywords are returned.
* [ ] Input SELECT * FROM - Table names only
== PGSpecial ==
* [ ] Test \d
* [ ] Test \d tablename
* [ ] Test \d tablena*
* [ ] Test \d non-existent-tablename
* [ ] Test \d index
* [ ] Test \d sequence
* [ ] Test \d view
== Exceptionals ==
* [ ] Test the 'use' command to change db.

11
test/test_prompt_utils.py Normal file
View file

@ -0,0 +1,11 @@
import click
from mycli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
stdin = click.get_text_stream('stdin')
assert stdin.isatty() is False
sql = 'drop database foo;'
assert confirm_destructive_query(sql) is None

View file

@ -0,0 +1,385 @@
import pytest
from mock import patch
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
import mycli.packages.special.main as special
metadata = {
'users': ['id', 'email', 'first_name', 'last_name'],
'orders': ['id', 'ordered_date', 'status'],
'select': ['id', 'insert', 'ABC'],
'réveillé': ['id', 'insert', 'ABC']
}
@pytest.fixture
def completer():
import mycli.sqlcompleter as sqlcompleter
comp = sqlcompleter.SQLCompleter(smart_completion=True)
tables, columns = [], []
for table, cols in metadata.items():
tables.append((table,))
columns.extend([(table, col) for col in cols])
comp.set_dbname('test')
comp.extend_schemata('test')
comp.extend_relations(tables, kind='tables')
comp.extend_columns(columns, kind='tables')
comp.extend_special_commands(special.COMMANDS)
return comp
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_special_name_completion(completer, complete_event):
text = '\\d'
position = len('\\d')
result = completer.get_completions(
Document(text=text, cursor_position=position),
complete_event)
assert result == [Completion(text='\\dt', start_position=-2)]
def test_empty_string_completion(completer, complete_event):
text = ''
position = 0
result = list(
completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert list(map(Completion, sorted(completer.keywords) +
sorted(completer.special_commands))) == result
def test_select_keyword_completion(completer, complete_event):
text = 'SEL'
position = len('SEL')
result = completer.get_completions(
Document(text=text, cursor_position=position),
complete_event)
assert list(result) == list([Completion(text='SELECT', start_position=-3)])
def test_table_completion(completer, complete_event):
text = 'SELECT * FROM '
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert list(result) == list([
Completion(text='`réveillé`', start_position=0),
Completion(text='`select`', start_position=0),
Completion(text='orders', start_position=0),
Completion(text='users', start_position=0),
])
def test_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert list(result) == list([Completion(text='MAX', start_position=-2),
Completion(text='MASTER', start_position=-2),
])
def test_suggested_column_names(completer, complete_event):
"""Suggest column and function names when selecting from table.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT from users'
position = len('SELECT ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0),
] +
list(map(Completion, completer.functions)) +
[Completion(text='users', start_position=0)] +
list(map(Completion, completer.keywords)))
def test_suggested_column_names_in_function(completer, complete_event):
"""Suggest column and function names when selecting multiple columns from
table.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT MAX( from users'
position = len('SELECT MAX(')
result = completer.get_completions(
Document(text=text, cursor_position=position),
complete_event)
assert list(result) == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0)])
def test_suggested_column_names_with_table_dot(completer, complete_event):
"""Suggest column names on table name and dot.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT users. from users'
position = len('SELECT users.')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0)])
def test_suggested_column_names_with_alias(completer, complete_event):
"""Suggest column names on table alias and dot.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT u. from users u'
position = len('SELECT u.')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0)])
def test_suggested_multiple_column_names(completer, complete_event):
"""Suggest column and function names when selecting multiple columns from
table.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT id, from users u'
position = len('SELECT id, ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0)] +
list(map(Completion, completer.functions)) +
[Completion(text='u', start_position=0)] +
list(map(Completion, completer.keywords)))
def test_suggested_multiple_column_names_with_alias(completer, complete_event):
"""Suggest column names on table alias and dot when selecting multiple
columns from table.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT u.id, u. from users u'
position = len('SELECT u.id, u.')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0)])
def test_suggested_multiple_column_names_with_dot(completer, complete_event):
"""Suggest column names on table names and dot when selecting multiple
columns from table.
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT users.id, users. from users u'
position = len('SELECT users.id, users.')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='id', start_position=0),
Completion(text='last_name', start_position=0)])
def test_suggested_aliases_after_on(completer, complete_event):
text = 'SELECT u.name, o.id FROM users u JOIN orders o ON '
position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='o', start_position=0),
Completion(text='u', start_position=0)])
def test_suggested_aliases_after_on_right_side(completer, complete_event):
text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = '
position = len(
'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='o', start_position=0),
Completion(text='u', start_position=0)])
def test_suggested_tables_after_on(completer, complete_event):
text = 'SELECT users.name, orders.id FROM users JOIN orders ON '
position = len('SELECT users.name, orders.id FROM users JOIN orders ON ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='orders', start_position=0),
Completion(text='users', start_position=0)])
def test_suggested_tables_after_on_right_side(completer, complete_event):
text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = '
position = len(
'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='orders', start_position=0),
Completion(text='users', start_position=0)])
def test_table_names_after_from(completer, complete_event):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='`réveillé`', start_position=0),
Completion(text='`select`', start_position=0),
Completion(text='orders', start_position=0),
Completion(text='users', start_position=0),
])
def test_auto_escaped_col_names(completer, complete_event):
text = 'SELECT from `select`'
position = len('SELECT ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == [
Completion(text='*', start_position=0),
Completion(text='`ABC`', start_position=0),
Completion(text='`insert`', start_position=0),
Completion(text='id', start_position=0),
] + \
list(map(Completion, completer.functions)) + \
[Completion(text='`select`', start_position=0)] + \
list(map(Completion, completer.keywords))
def test_un_escaped_table_names(completer, complete_event):
text = 'SELECT from réveillé'
position = len('SELECT ')
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert result == list([
Completion(text='*', start_position=0),
Completion(text='`ABC`', start_position=0),
Completion(text='`insert`', start_position=0),
Completion(text='id', start_position=0),
] +
list(map(Completion, completer.functions)) +
[Completion(text='réveillé', start_position=0)] +
list(map(Completion, completer.keywords)))
def dummy_list_path(dir_name):
dirs = {
'/': [
'dir1',
'file1.sql',
'file2.sql',
],
'/dir1': [
'subdir1',
'subfile1.sql',
'subfile2.sql',
],
'/dir1/subdir1': [
'lastfile.sql',
],
}
return dirs.get(dir_name, [])
@patch('mycli.packages.filepaths.list_path', new=dummy_list_path)
@pytest.mark.parametrize('text,expected', [
# ('source ', [('~', 0),
# ('/', 0),
# ('.', 0),
# ('..', 0)]),
('source /', [('dir1', 0),
('file1.sql', 0),
('file2.sql', 0)]),
('source /dir1/', [('subdir1', 0),
('subfile1.sql', 0),
('subfile2.sql', 0)]),
('source /dir1/subdir1/', [('lastfile.sql', 0)]),
])
def test_file_name_completion(completer, complete_event, text, expected):
position = len(text)
result = list(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
expected = list((Completion(txt, pos) for txt, pos in expected))
assert result == expected

View file

@ -0,0 +1,287 @@
import os
import stat
import tempfile
from time import time
from mock import patch
import pytest
from pymysql import ProgrammingError
import mycli.packages.special
from .utils import dbtest, db_connection, send_ctrl_c
def test_set_get_pager():
mycli.packages.special.set_pager_enabled(True)
assert mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager_enabled(False)
assert not mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager('less')
assert os.environ['PAGER'] == "less"
mycli.packages.special.set_pager(False)
assert os.environ['PAGER'] == "less"
del os.environ['PAGER']
mycli.packages.special.set_pager(False)
mycli.packages.special.disable_pager()
assert not mycli.packages.special.is_pager_enabled()
def test_set_get_timing():
mycli.packages.special.set_timing_enabled(True)
assert mycli.packages.special.is_timing_enabled()
mycli.packages.special.set_timing_enabled(False)
assert not mycli.packages.special.is_timing_enabled()
def test_set_get_expanded_output():
mycli.packages.special.set_expanded_output(True)
assert mycli.packages.special.is_expanded_output()
mycli.packages.special.set_expanded_output(False)
assert not mycli.packages.special.is_expanded_output()
def test_editor_command():
assert mycli.packages.special.editor_command(r'hello\e')
assert mycli.packages.special.editor_command(r'\ehello')
assert not mycli.packages.special.editor_command(r'hello')
assert mycli.packages.special.get_filename(r'\e filename') == "filename"
os.environ['EDITOR'] = 'true'
os.environ['VISUAL'] = 'true'
mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
def test_tee_command():
mycli.packages.special.write_tee(u"hello world") # write without file set
with tempfile.NamedTemporaryFile() as f:
mycli.packages.special.execute(None, u"tee " + f.name)
mycli.packages.special.write_tee(u"hello world")
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"tee -o " + f.name)
mycli.packages.special.write_tee(u"hello world")
f.seek(0)
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"notee")
mycli.packages.special.write_tee(u"hello world")
f.seek(0)
assert f.read() == b"hello world\n"
def test_tee_command_error():
with pytest.raises(TypeError):
mycli.packages.special.execute(None, 'tee')
with pytest.raises(OSError):
with tempfile.NamedTemporaryFile() as f:
os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
mycli.packages.special.execute(None, 'tee {}'.format(f.name))
@dbtest
def test_favorite_query():
with db_connection().cursor() as cur:
query = u'select ""'
mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
assert next(mycli.packages.special.execute(
cur, u'\\f check'))[0] == "> " + query
def test_once_command():
with pytest.raises(TypeError):
mycli.packages.special.execute(None, u"\\once")
with pytest.raises(OSError):
mycli.packages.special.execute(None, u"\\once /proc/access-denied")
mycli.packages.special.write_once(u"hello world") # write without file set
with tempfile.NamedTemporaryFile() as f:
mycli.packages.special.execute(None, u"\\once " + f.name)
mycli.packages.special.write_once(u"hello world")
assert f.read() == b"hello world\n"
mycli.packages.special.execute(None, u"\\once -o " + f.name)
mycli.packages.special.write_once(u"hello world line 1")
mycli.packages.special.write_once(u"hello world line 2")
f.seek(0)
assert f.read() == b"hello world line 1\nhello world line 2\n"
def test_pipe_once_command():
with pytest.raises(IOError):
mycli.packages.special.execute(None, u"\\pipe_once")
with pytest.raises(OSError):
mycli.packages.special.execute(
None, u"\\pipe_once /proc/access-denied")
mycli.packages.special.execute(None, u"\\pipe_once wc")
mycli.packages.special.write_once(u"hello world")
mycli.packages.special.unset_pipe_once_if_written()
# how to assert on wc output?
def test_parseargfile():
"""Test that parseargfile expands the user directory."""
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
'mode': 'a'}
assert expected == mycli.packages.special.iocommands.parseargfile(
'~/filename')
expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
'mode': 'w'}
assert expected == mycli.packages.special.iocommands.parseargfile(
'-o ~/filename')
def test_parseargfile_no_file():
"""Test that parseargfile raises a TypeError if there is no filename."""
with pytest.raises(TypeError):
mycli.packages.special.iocommands.parseargfile('')
with pytest.raises(TypeError):
mycli.packages.special.iocommands.parseargfile('-o ')
@dbtest
def test_watch_query_iteration():
"""Test that a single iteration of the result of `watch_query` executes
the desired query and returns the given results."""
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
expected_title = '> {0!s}'.format(query)
with db_connection().cursor() as cur:
result = next(mycli.packages.special.iocommands.watch_query(
arg=query, cur=cur
))
assert result[0] == expected_title
assert result[2][0] == expected_value
@dbtest
def test_watch_query_full():
"""Test that `watch_query`:
* Returns the expected results.
* Executes the defined times inside the given interval, in this case with
a 0.3 seconds wait, it should execute 4 times inside a 1 seconds
interval.
* Stops at Ctrl-C
"""
watch_seconds = 0.3
wait_interval = 1
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
expected_title = '> {0!s}'.format(query)
expected_results = 4
ctrl_c_process = send_ctrl_c(wait_interval)
with db_connection().cursor() as cur:
results = list(
result for result in mycli.packages.special.iocommands.watch_query(
arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
)
)
ctrl_c_process.join(1)
assert len(results) == expected_results
for result in results:
assert result[0] == expected_title
assert result[2][0] == expected_value
@dbtest
@patch('click.clear')
def test_watch_query_clear(clear_mock):
"""Test that the screen is cleared with the -c flag of `watch` command
before execute the query."""
with db_connection().cursor() as cur:
watch_gen = mycli.packages.special.iocommands.watch_query(
arg='0.1 -c select 1;', cur=cur
)
assert not clear_mock.called
next(watch_gen)
assert clear_mock.called
clear_mock.reset_mock()
next(watch_gen)
assert clear_mock.called
clear_mock.reset_mock()
@dbtest
def test_watch_query_bad_arguments():
"""Test different incorrect combinations of arguments for `watch`
command."""
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
with pytest.raises(ProgrammingError):
next(watch_query('a select 1;', cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('-a select 1;', cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('1 -a select 1;', cur=cur))
with pytest.raises(ProgrammingError):
next(watch_query('-c -a select 1;', cur=cur))
@dbtest
@patch('click.clear')
def test_watch_query_interval_clear(clear_mock):
"""Test `watch` command with interval and clear flag."""
def test_asserts(gen):
clear_mock.reset_mock()
start = time()
next(gen)
assert clear_mock.called
next(gen)
exec_time = time() - start
assert exec_time > seconds and exec_time < (seconds + seconds)
seconds = 1.0
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
cur=cur))
test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
cur=cur))
def test_split_sql_by_delimiter():
for delimiter_str in (';', '$', '😀'):
mycli.packages.special.set_delimiter(delimiter_str)
sql_input = "select 1{} select \ufffc2".format(delimiter_str)
queries = (
"select 1",
"select \ufffc2"
)
for query, parsed_query in zip(
queries, mycli.packages.special.split_queries(sql_input)):
assert(query == parsed_query)
def test_switch_delimiter_within_query():
mycli.packages.special.set_delimiter(';')
sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
queries = (
"select 1",
"delimiter $$ select 2 $$ select 3 $$",
"select 2",
"select 3"
)
for query, parsed_query in zip(
queries,
mycli.packages.special.split_queries(sql_input)):
assert(query == parsed_query)
def test_set_delimiter():
for delim in ('foo', 'bar'):
mycli.packages.special.set_delimiter(delim)
assert mycli.packages.special.get_current_delimiter() == delim
def teardown_function():
mycli.packages.special.set_delimiter(';')

272
test/test_sqlexecute.py Normal file
View file

@ -0,0 +1,272 @@
import os
import pytest
import pymysql
from .utils import run, dbtest, set_expanded_output, is_expanded_output
def assert_result_equal(result, title=None, rows=None, headers=None,
status=None, auto_status=True, assert_contains=False):
"""Assert that an sqlexecute.run() result matches the expected values."""
if status is None and auto_status and rows:
status = '{} row{} in set'.format(
len(rows), 's' if len(rows) > 1 else '')
fields = {'title': title, 'rows': rows, 'headers': headers,
'status': status}
if assert_contains:
# Do a loose match on the results using the *in* operator.
for key, field in fields.items():
if field:
assert field in result[0][key]
else:
# Do an exact match on the fields.
assert result == [fields]
@dbtest
def test_conn(executor):
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
results = run(executor, '''select * from test''')
assert_result_equal(results, headers=['a'], rows=[('abc',)])
@dbtest
def test_bools(executor):
run(executor, '''create table test(a boolean)''')
run(executor, '''insert into test values(True)''')
results = run(executor, '''select * from test''')
assert_result_equal(results, headers=['a'], rows=[(1,)])
@dbtest
def test_binary(executor):
run(executor, '''create table bt(geom linestring NOT NULL)''')
run(executor, "INSERT INTO bt VALUES "
"(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
results = run(executor, '''select * from bt''')
geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
b'\xac\xdeC@')
assert_result_equal(results, headers=['geom'], rows=[(geom,)])
@dbtest
def test_table_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
assert set(executor.tables()) == set([('a',), ('b',)])
assert set(executor.table_columns()) == set(
[('a', 'x'), ('a', 'y'), ('b', 'z')])
@dbtest
def test_database_list(executor):
databases = executor.databases()
assert '_test_db' in databases
@dbtest
def test_invalid_syntax(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, 'invalid syntax!')
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
@dbtest
def test_invalid_column_name(executor):
with pytest.raises(pymysql.err.OperationalError) as excinfo:
run(executor, 'select invalid command')
assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
@dbtest
def test_unicode_support_in_output(executor):
run(executor, "create table unicodechars(t text)")
run(executor, u"insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
results = run(executor, u"select * from unicodechars")
assert_result_equal(results, headers=['t'], rows=[(u'é',)])
@dbtest
def test_multiple_queries_same_line(executor):
results = run(executor, "select 'foo'; select 'bar'")
expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
'status': '1 row in set'},
{'title': None, 'headers': ['bar'], 'rows': [('bar',)],
'status': '1 row in set'}]
assert expected == results
@dbtest
def test_multiple_queries_same_line_syntaxerror(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, "select 'foo'; invalid syntax")
assert 'You have an error in your SQL syntax;' in str(excinfo.value)
@dbtest
def test_favorite_query(executor):
set_expanded_output(False)
run(executor, "create table test(a text)")
run(executor, "insert into test values('abc')")
run(executor, "insert into test values('def')")
results = run(executor, "\\fs test-a select * from test where a like 'a%'")
assert_result_equal(results, status='Saved.')
results = run(executor, "\\f test-a")
assert_result_equal(results,
title="> select * from test where a like 'a%'",
headers=['a'], rows=[('abc',)], auto_status=False)
results = run(executor, "\\fd test-a")
assert_result_equal(results, status='test-a: Deleted')
@dbtest
def test_favorite_query_multiple_statement(executor):
set_expanded_output(False)
run(executor, "create table test(a text)")
run(executor, "insert into test values('abc')")
run(executor, "insert into test values('def')")
results = run(executor,
"\\fs test-ad select * from test where a like 'a%'; "
"select * from test where a like 'd%'")
assert_result_equal(results, status='Saved.')
results = run(executor, "\\f test-ad")
expected = [{'title': "> select * from test where a like 'a%'",
'headers': ['a'], 'rows': [('abc',)], 'status': None},
{'title': "> select * from test where a like 'd%'",
'headers': ['a'], 'rows': [('def',)], 'status': None}]
assert expected == results
results = run(executor, "\\fd test-ad")
assert_result_equal(results, status='test-ad: Deleted')
@dbtest
def test_favorite_query_expanded_output(executor):
set_expanded_output(False)
run(executor, '''create table test(a text)''')
run(executor, '''insert into test values('abc')''')
results = run(executor, "\\fs test-ae select * from test")
assert_result_equal(results, status='Saved.')
results = run(executor, "\\f test-ae \\G")
assert is_expanded_output() is True
assert_result_equal(results, title='> select * from test',
headers=['a'], rows=[('abc',)], auto_status=False)
set_expanded_output(False)
results = run(executor, "\\fd test-ae")
assert_result_equal(results, status='test-ae: Deleted')
@dbtest
def test_special_command(executor):
results = run(executor, '\\?')
assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
headers='Command', assert_contains=True,
auto_status=False)
@dbtest
def test_cd_command_without_a_folder_name(executor):
results = run(executor, 'system cd')
assert_result_equal(results, status='No folder name was provided.')
@dbtest
def test_system_command_not_found(executor):
results = run(executor, 'system xyz')
assert_result_equal(results, status='OSError: No such file or directory',
assert_contains=True)
@dbtest
def test_system_command_output(executor):
test_dir = os.path.abspath(os.path.dirname(__file__))
test_file_path = os.path.join(test_dir, 'test.txt')
results = run(executor, 'system cat {0}'.format(test_file_path))
assert_result_equal(results, status='mycli rocks!\n')
@dbtest
def test_cd_command_current_dir(executor):
test_path = os.path.abspath(os.path.dirname(__file__))
run(executor, 'system cd {0}'.format(test_path))
assert os.getcwd() == test_path
@dbtest
def test_unicode_support(executor):
results = run(executor, u"SELECT '日本語' AS japanese;")
assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
@dbtest
def test_timestamp_null(executor):
run(executor, '''create table ts_null(a timestamp null)''')
run(executor, '''insert into ts_null values(null)''')
results = run(executor, '''select * from ts_null''')
assert_result_equal(results, headers=['a'],
rows=[(None,)])
@dbtest
def test_datetime_null(executor):
run(executor, '''create table dt_null(a datetime null)''')
run(executor, '''insert into dt_null values(null)''')
results = run(executor, '''select * from dt_null''')
assert_result_equal(results, headers=['a'],
rows=[(None,)])
@dbtest
def test_date_null(executor):
run(executor, '''create table date_null(a date null)''')
run(executor, '''insert into date_null values(null)''')
results = run(executor, '''select * from date_null''')
assert_result_equal(results, headers=['a'], rows=[(None,)])
@dbtest
def test_time_null(executor):
run(executor, '''create table time_null(a time null)''')
run(executor, '''insert into time_null values(null)''')
results = run(executor, '''select * from time_null''')
assert_result_equal(results, headers=['a'], rows=[(None,)])
@dbtest
def test_multiple_results(executor):
query = '''CREATE PROCEDURE dmtest()
BEGIN
SELECT 1;
SELECT 2;
END'''
executor.conn.cursor().execute(query)
results = run(executor, 'call dmtest;')
expected = [
{'title': None, 'rows': [(1,)], 'headers': ['1'],
'status': '1 row in set'},
{'title': None, 'rows': [(2,)], 'headers': ['2'],
'status': '1 row in set'}
]
assert results == expected

118
test/test_tabular_output.py Normal file
View file

@ -0,0 +1,118 @@
"""Test the sql output adapter."""
from textwrap import dedent
from mycli.packages.tabular_output import sql_format
from cli_helpers.tabular_output import TabularOutputFormatter
from .utils import USER, PASSWORD, HOST, PORT, dbtest
import pytest
from mycli.main import MyCli
from pymysql.constants import FIELD_TYPE
@pytest.fixture
def mycli():
cli = MyCli()
cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None)
return cli
@dbtest
def test_sql_output(mycli):
"""Test the sql output adapter."""
headers = ['letters', 'number', 'optional', 'float', 'binary']
class FakeCursor(object):
def __init__(self):
self.data = [
('abc', 1, None, 10.0, b'\xAA'),
('d', 456, '1', 0.5, b'\xAA\xBB')
]
self.description = [
(None, FIELD_TYPE.VARCHAR),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.FLOAT),
(None, FIELD_TYPE.BLOB)
]
def __iter__(self):
return self
def __next__(self):
if self.data:
return self.data.pop(0)
else:
raise StopIteration()
def description(self):
return self.description
# Test sql-update output format
assert list(mycli.change_table_format("sql-update")) == \
[(None, None, None, 'Changed table format to sql-update')]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
actual = "\n".join(output)
assert actual == dedent('''\
UPDATE `DUAL` SET
`number` = 1
, `optional` = NULL
, `float` = 10.0e0
, `binary` = X'aa'
WHERE `letters` = 'abc';
UPDATE `DUAL` SET
`number` = 456
, `optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
WHERE `letters` = 'd';''')
# Test sql-update-2 output format
assert list(mycli.change_table_format("sql-update-2")) == \
[(None, None, None, 'Changed table format to sql-update-2')]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
UPDATE `DUAL` SET
`optional` = NULL
, `float` = 10.0e0
, `binary` = X'aa'
WHERE `letters` = 'abc' AND `number` = 1;
UPDATE `DUAL` SET
`optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
WHERE `letters` = 'd' AND `number` = 456;''')
# Test sql-insert output format (without table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
# Test sql-insert output format (with table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
mycli.formatter.query = "SELECT * FROM `table`"
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')
# Test sql-insert output format (with database + table name)
assert list(mycli.change_table_format("sql-insert")) == \
[(None, None, None, 'Changed table format to sql-insert')]
mycli.formatter.query = "SELECT * FROM `database`.`table`"
output = mycli.format_output(None, FakeCursor(), headers)
assert "\n".join(output) == dedent('''\
INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
;''')

94
test/utils.py Normal file
View file

@ -0,0 +1,94 @@
import os
import time
import signal
import platform
import multiprocessing
import pymysql
import pytest
from mycli.main import special
PASSWORD = os.getenv('PYTEST_PASSWORD')
USER = os.getenv('PYTEST_USER', 'root')
HOST = os.getenv('PYTEST_HOST', 'localhost')
PORT = int(os.getenv('PYTEST_PORT', 3306))
CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
SSH_USER = os.getenv('PYTEST_SSH_USER', None)
SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
def db_connection(dbname=None):
conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
password=PASSWORD, charset=CHARSET,
local_infile=False)
conn.autocommit = True
return conn
try:
db_connection()
CAN_CONNECT_TO_DB = True
except:
CAN_CONNECT_TO_DB = False
dbtest = pytest.mark.skipif(
not CAN_CONNECT_TO_DB,
reason="Need a mysql instance at localhost accessible by user 'root'")
def create_db(dbname):
with db_connection().cursor() as cur:
try:
cur.execute('''DROP DATABASE IF EXISTS _test_db''')
cur.execute('''CREATE DATABASE _test_db''')
except:
pass
def run(executor, sql, rows_as_list=True):
"""Return string output for the sql to be run."""
result = []
for title, rows, headers, status in executor.run(sql):
rows = list(rows) if (rows_as_list and rows) else rows
result.append({'title': title, 'rows': rows, 'headers': headers,
'status': status})
return result
def set_expanded_output(is_expanded):
"""Pass-through for the tests."""
return special.set_expanded_output(is_expanded)
def is_expanded_output():
"""Pass-through for the tests."""
return special.is_expanded_output()
def send_ctrl_c_to_pid(pid, wait_seconds):
"""Sends a Ctrl-C like signal to the given `pid` after `wait_seconds`
seconds."""
time.sleep(wait_seconds)
system_name = platform.system()
if system_name == "Windows":
os.kill(pid, signal.CTRL_C_EVENT)
else:
os.kill(pid, signal.SIGINT)
def send_ctrl_c(wait_seconds):
"""Create a process that sends a Ctrl-C like signal to the current process
after `wait_seconds` seconds.
Returns the `multiprocessing.Process` created.
"""
ctrl_c_process = multiprocessing.Process(
target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
)
ctrl_c_process.start()
return ctrl_c_process