Adding upstream version 3.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-09 19:48:22 +01:00
parent f2184ff4ed
commit ec5391b244
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
104 changed files with 15144 additions and 0 deletions

tests/conftest.py Normal file (+52)

@ -0,0 +1,52 @@
import os
import pytest
from utils import (
POSTGRES_HOST,
POSTGRES_PORT,
POSTGRES_USER,
POSTGRES_PASSWORD,
create_db,
db_connection,
drop_tables,
)
import pgcli.pgexecute
@pytest.fixture(scope="function")
def connection():
create_db("_test_db")
connection = db_connection("_test_db")
yield connection
drop_tables(connection)
connection.close()
@pytest.fixture
def cursor(connection):
with connection.cursor() as cur:
return cur
@pytest.fixture
def executor(connection):
return pgcli.pgexecute.PGExecute(
database="_test_db",
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
dsn=None,
)
@pytest.fixture
def exception_formatter():
return lambda e: str(e)
@pytest.fixture(scope="session", autouse=True)
def temp_config(tmpdir_factory):
# This fixture runs once at the start of the test session.
# Use a temporary directory as config home so the user's own config is not picked up.
os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))

@ -0,0 +1,12 @@
Feature: auto_vertical mode:
on, off
Scenario: auto_vertical on with small query
When we run dbcli with --auto-vertical-output
and we execute a small query
then we see small results in horizontal format
Scenario: auto_vertical on with large query
When we run dbcli with --auto-vertical-output
and we execute a large query
then we see large results in vertical format

@ -0,0 +1,58 @@
Feature: run the cli,
call the help command,
exit the cli
Scenario: run "\?" command
When we send "\?" command
then we see help output
Scenario: run source command
When we send source command
then we see help output
Scenario: run partial select command
When we send partial select command
then we see error message
then we see dbcli prompt
Scenario: check our application_name
When we run query to check application_name
then we see found
Scenario: run the cli and exit
When we send "ctrl + d"
then dbcli exits
Scenario: list databases
When we list databases
then we see list of databases
Scenario: run the cli with --username
When we launch dbcli using --username
and we send "\?" command
then we see help output
Scenario: run the cli with --user
When we launch dbcli using --user
and we send "\?" command
then we see help output
Scenario: run the cli with --port
When we launch dbcli using --port
and we send "\?" command
then we see help output
Scenario: run the cli with --password
When we launch dbcli using --password
then we send password
and we see dbcli prompt
when we send "\?" command
then we see help output
@wip
Scenario: run the cli with dsn and password
When we launch dbcli using dsn_password
then we send password
and we see dbcli prompt
when we send "\?" command
then we see help output

@ -0,0 +1,17 @@
Feature: manipulate databases:
create, drop, connect, disconnect
Scenario: create and drop temporary database
When we create database
then we see database created
when we drop database
then we confirm the destructive warning
then we see database dropped
when we connect to dbserver
then we see database connected
Scenario: connect and disconnect from test database
When we connect to test database
then we see database connected
when we connect to dbserver
then we see database connected

@ -0,0 +1,22 @@
Feature: manipulate tables:
create, insert, update, select, delete from, drop
Scenario: create, insert, select from, update, drop table
When we connect to test database
then we see database connected
when we create table
then we see table created
when we insert into table
then we see record inserted
when we update table
then we see record updated
when we select from table
then we see data selected
when we delete from table
then we confirm the destructive warning
then we see record deleted
when we drop table
then we confirm the destructive warning
then we see table dropped
when we connect to dbserver
then we see database connected

@ -0,0 +1,78 @@
from psycopg2 import connect
from psycopg2.extensions import AsIs
def create_db(
hostname="localhost", username=None, password=None, dbname=None, port=None
):
"""Create test database.
:param hostname: string
:param username: string
:param password: string
:param dbname: string
:param port: int
:return:
"""
cn = create_cn(hostname, password, username, "postgres", port)
# ISOLATION_LEVEL_AUTOCOMMIT = 0
# Needed for DB creation.
cn.set_isolation_level(0)
with cn.cursor() as cr:
cr.execute("drop database if exists %s", (AsIs(dbname),))
cr.execute("create database %s", (AsIs(dbname),))
cn.close()
cn = create_cn(hostname, password, username, dbname, port)
return cn
def create_cn(hostname, password, username, dbname, port):
"""
Open connection to database.
:param hostname:
:param password:
:param username:
:param dbname: string
:return: psycopg2.connection
"""
cn = connect(
host=hostname, user=username, database=dbname, password=password, port=port
)
print("Created connection: {0}.".format(cn.dsn))
return cn
def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None):
"""
Drop database.
:param hostname: string
:param username: string
:param password: string
:param dbname: string
"""
cn = create_cn(hostname, password, username, "postgres", port)
# ISOLATION_LEVEL_AUTOCOMMIT = 0
# Needed for DB drop.
cn.set_isolation_level(0)
with cn.cursor() as cr:
cr.execute("drop database if exists %s", (AsIs(dbname),))
close_cn(cn)
def close_cn(cn=None):
"""
Close connection.
:param cn: psycopg2.connection
"""
if cn:
cn.close()
print("Closed connection: {0}.".format(cn.dsn))

@ -0,0 +1,192 @@
import copy
import os
import sys
import db_utils as dbutils
import fixture_utils as fixutils
import pexpect
import tempfile
import shutil
import signal
from steps import wrappers
def before_all(context):
"""Set env parameters."""
env_old = copy.deepcopy(dict(os.environ))
os.environ["LINES"] = "100"
os.environ["COLUMNS"] = "100"
os.environ["PAGER"] = "cat"
os.environ["EDITOR"] = "ex"
os.environ["VISUAL"] = "ex"
os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
)
fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data")
print("package root:", context.package_root)
print("fixture dir:", fixture_dir)
os.environ["COVERAGE_PROCESS_START"] = os.path.join(
context.package_root, ".coveragerc"
)
context.exit_sent = False
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
db_name_full = "{0}_{1}".format(db_name, vi)
# Store parameters read from config.
context.conf = {
"host": context.config.userdata.get(
"pg_test_host", os.getenv("PGHOST", "localhost")
),
"user": context.config.userdata.get(
"pg_test_user", os.getenv("PGUSER", "postgres")
),
"pass": context.config.userdata.get(
"pg_test_pass", os.getenv("PGPASSWORD", None)
),
"port": context.config.userdata.get(
"pg_test_port", os.getenv("PGPORT", "5432")
),
"cli_command": (
context.config.userdata.get("pg_cli_command", None)
or '{python} -c "{startup}"'.format(
python=sys.executable,
startup="; ".join(
[
"import coverage",
"coverage.process_startup()",
"import pgcli.main",
"pgcli.main.cli()",
]
),
)
),
"dbname": db_name_full,
"dbname_tmp": db_name_full + "_tmp",
"vi": vi,
"pager_boundary": "---boundary---",
}
os.environ["PAGER"] = "{0} {1} {2}".format(
sys.executable,
os.path.join(context.package_root, "tests/features/wrappager.py"),
context.conf["pager_boundary"],
)
# Store old env vars.
context.pgenv = {
"PGDATABASE": os.environ.get("PGDATABASE", None),
"PGUSER": os.environ.get("PGUSER", None),
"PGHOST": os.environ.get("PGHOST", None),
"PGPASSWORD": os.environ.get("PGPASSWORD", None),
"PGPORT": os.environ.get("PGPORT", None),
"XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None),
"PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None),
}
# Set new env vars.
os.environ["PGDATABASE"] = context.conf["dbname"]
os.environ["PGUSER"] = context.conf["user"]
os.environ["PGHOST"] = context.conf["host"]
os.environ["PGPORT"] = context.conf["port"]
os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf")
if context.conf["pass"]:
os.environ["PGPASSWORD"] = context.conf["pass"]
else:
if "PGPASSWORD" in os.environ:
del os.environ["PGPASSWORD"]
context.cn = dbutils.create_db(
context.conf["host"],
context.conf["user"],
context.conf["pass"],
context.conf["dbname"],
context.conf["port"],
)
context.fixture_data = fixutils.read_fixture_files()
# use temporary directory as config home
context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_")
os.environ["XDG_CONFIG_HOME"] = context.env_config_home
show_env_changes(env_old, dict(os.environ))
def show_env_changes(env_old, env_new):
"""Print out all test-specific env values."""
print("--- os.environ changed values: ---")
all_keys = set(list(env_old.keys()) + list(env_new.keys()))
for k in sorted(all_keys):
old_value = env_old.get(k, "")
new_value = env_new.get(k, "")
if new_value and old_value != new_value:
print('{}="{}"'.format(k, new_value))
print("-" * 20)
def after_all(context):
"""
Unset env parameters.
"""
dbutils.close_cn(context.cn)
dbutils.drop_db(
context.conf["host"],
context.conf["user"],
context.conf["pass"],
context.conf["dbname"],
context.conf["port"],
)
# Remove temp config directory.
shutil.rmtree(context.env_config_home)
# Restore env vars.
for k, v in context.pgenv.items():
if k in os.environ and v is None:
del os.environ[k]
elif v:
os.environ[k] = v
def before_step(context, _):
context.atprompt = False
def before_scenario(context, scenario):
if scenario.name == "list databases":
# not using the cli for that
return
wrappers.run_cli(context)
wrappers.wait_prompt(context)
def after_scenario(context, scenario):
"""Cleans up after each scenario completes."""
if hasattr(context, "cli") and context.cli and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
dbname = context.currentdb
context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
try:
context.cli.expect_exact(pexpect.EOF, timeout=15)
except pexpect.TIMEOUT:
print("--- after_scenario {}: kill cli".format(scenario.name))
context.cli.kill(signal.SIGKILL)
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
context.tmpfile_sql_help.close()
context.tmpfile_sql_help = None
# # TODO: uncomment to debug a failure
# def after_step(context, step):
# if step.status == "failed":
# import pdb; pdb.set_trace()

@ -0,0 +1,29 @@
Feature: expanded mode:
on, off, auto
Scenario: expanded on
When we prepare the test data
and we set expanded on
and we select from table
then we see expanded data selected
when we drop table
then we confirm the destructive warning
then we see table dropped
Scenario: expanded off
When we prepare the test data
and we set expanded off
and we select from table
then we see nonexpanded data selected
when we drop table
then we confirm the destructive warning
then we see table dropped
Scenario: expanded auto
When we prepare the test data
and we set expanded auto
and we select from table
then we see auto data selected
when we drop table
then we confirm the destructive warning
then we see table dropped

@ -0,0 +1,25 @@
+--------------------------+------------------------------------------------+
| Command | Description |
|--------------------------+------------------------------------------------|
| \# | Refresh auto-completions. |
| \? | Show Help. |
| \T [format] | Change the table format used to output results |
| \c[onnect] database_name | Change to a new database. |
| \d [pattern] | List or describe tables, views and sequences. |
| \dT[S+] [pattern] | List data types |
| \df[+] [pattern] | List functions. |
| \di[+] [pattern] | List indexes. |
| \dn[+] [pattern] | List schemas. |
| \ds[+] [pattern] | List sequences. |
| \dt[+] [pattern] | List tables. |
| \du[+] [pattern] | List roles. |
| \dv[+] [pattern] | List views. |
| \e [file] | Edit the query with external editor. |
| \l | List databases. |
| \n[+] [name] | List or execute named queries. |
| \nd [name [query]] | Delete a named query. |
| \ns name query | Save a named query. |
| \refresh | Refresh auto-completions. |
| \timing | Toggle timing of commands. |
| \x | Toggle expanded output. |
+--------------------------+------------------------------------------------+

@ -0,0 +1,64 @@
Command
Description
\#
Refresh auto-completions.
\?
Show Commands.
\T [format]
Change the table format used to output results
\c[onnect] database_name
Change to a new database.
\copy [tablename] to/from [filename]
Copy data between a file and a table.
\d[+] [pattern]
List or describe tables, views and sequences.
\dT[S+] [pattern]
List data types
\db[+] [pattern]
List tablespaces.
\df[+] [pattern]
List functions.
\di[+] [pattern]
List indexes.
\dm[+] [pattern]
List materialized views.
\dn[+] [pattern]
List schemas.
\ds[+] [pattern]
List sequences.
\dt[+] [pattern]
List tables.
\du[+] [pattern]
List roles.
\dv[+] [pattern]
List views.
\dx[+] [pattern]
List extensions.
\e [file]
Edit the query with external editor.
\h
Show SQL syntax and help.
\i filename
Execute commands from file.
\l
List databases.
\n[+] [name] [param1 param2 ...]
List or execute named queries.
\nd [name]
Delete a named query.
\ns name query
Save a named query.
\o [filename]
Send all query results to file.
\pager [command]
Set PAGER. Print the query results via PAGER.
\pset [key] [value]
A limited version of traditional \pset
\refresh
Refresh auto-completions.
\sf[+] FUNCNAME
Show a function's definition.
\timing
Toggle timing of commands.
\x
Toggle expanded output.

@ -0,0 +1,4 @@
[mock_postgres]
dbname=postgres
host=localhost
user=postgres
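How this fixture is wired (a sketch, not part of the diff): environment.py above points PGSERVICEFILE at this file, so the dsn_password scenario's service=mock_postgres DSN resolves to the host/dbname/user listed here via libpq's connection-service lookup. The same lookup can be exercised directly with psycopg2; the relative path below is assumed from this changeset's layout.

import os
import psycopg2

# Point libpq at the fixture service file (path assumed from the diff layout).
os.environ["PGSERVICEFILE"] = "tests/features/fixture_data/mock_pg_service.conf"

# libpq fills in host/dbname/user for the "mock_postgres" service; a password
# would still come from PGPASSWORD, ~/.pgpass, or an interactive prompt.
conn = psycopg2.connect("service=mock_postgres")
print(conn.dsn)
conn.close()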

@ -0,0 +1,28 @@
import os
import codecs
def read_fixture_lines(filename):
"""
Read lines of text from file.
:param filename: string name
:return: list of strings
"""
lines = []
for line in codecs.open(filename, "rb", encoding="utf-8"):
lines.append(line.strip())
return lines
def read_fixture_files():
"""Read all files inside fixture_data directory."""
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, "fixture_data/")
print("reading fixture data: {}".format(fixture_dir))
fixture_dict = {}
for filename in os.listdir(fixture_dir):
if filename not in [".", ".."]:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)
return fixture_dict
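For reference (not part of the diff): the returned dict maps bare file names under fixture_data/ to lists of stripped lines, so given the fixture files added in this changeset it contains entries like the following.

fixtures = read_fixture_files()
# Keys are file names, values are lists of stripped lines.
assert "help_commands.txt" in fixtures
assert fixtures["mock_pg_service.conf"][0] == "[mock_postgres]"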

@ -0,0 +1,17 @@
Feature: I/O commands
Scenario: edit sql in file with external editor
When we start external editor providing a file name
and we type sql in the editor
and we exit the editor
then we see dbcli prompt
and we see the sql in prompt
Scenario: tee output from query
When we tee output
and we wait for prompt
and we query "select 123456"
and we wait for prompt
and we stop teeing output
and we wait for prompt
then we see 123456 in tee output

@ -0,0 +1,10 @@
Feature: named queries:
save, use and delete named queries
Scenario: save, use and delete named queries
When we connect to test database
then we see database connected
when we save a named query
then we see the named query saved
when we delete a named query
then we see the named query deleted

@ -0,0 +1,6 @@
Feature: Special commands
Scenario: run refresh command
When we refresh completions
and we wait for prompt
then we see completions refresh started

@ -0,0 +1,99 @@
from textwrap import dedent
from behave import then, when
import wrappers
@when("we run dbcli with {arg}")
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=arg.split("="))
@when("we execute a small query")
def step_execute_small_query(context):
context.cli.sendline("select 1")
@when("we execute a large query")
def step_execute_large_query(context):
context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
@then("we see small results in horizontal format")
def step_see_small_results(context):
wrappers.expect_pager(
context,
dedent(
"""\
+------------+\r
| ?column? |\r
|------------|\r
| 1 |\r
+------------+\r
SELECT 1\r
"""
),
timeout=5,
)
@then("we see large results in vertical format")
def step_see_large_results(context):
wrappers.expect_pager(
context,
dedent(
"""\
-[ RECORD 1 ]-------------------------\r
?column? | 1\r
?column? | 2\r
?column? | 3\r
?column? | 4\r
?column? | 5\r
?column? | 6\r
?column? | 7\r
?column? | 8\r
?column? | 9\r
?column? | 10\r
?column? | 11\r
?column? | 12\r
?column? | 13\r
?column? | 14\r
?column? | 15\r
?column? | 16\r
?column? | 17\r
?column? | 18\r
?column? | 19\r
?column? | 20\r
?column? | 21\r
?column? | 22\r
?column? | 23\r
?column? | 24\r
?column? | 25\r
?column? | 26\r
?column? | 27\r
?column? | 28\r
?column? | 29\r
?column? | 30\r
?column? | 31\r
?column? | 32\r
?column? | 33\r
?column? | 34\r
?column? | 35\r
?column? | 36\r
?column? | 37\r
?column? | 38\r
?column? | 39\r
?column? | 40\r
?column? | 41\r
?column? | 42\r
?column? | 43\r
?column? | 44\r
?column? | 45\r
?column? | 46\r
?column? | 47\r
?column? | 48\r
?column? | 49\r
SELECT 1\r
"""
),
timeout=5,
)

@ -0,0 +1,147 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
import pexpect
import subprocess
import tempfile
from behave import when, then
from textwrap import dedent
import wrappers
@when("we list databases")
def step_list_databases(context):
cmd = ["pgcli", "--list"]
context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root)
@then("we see list of databases")
def step_see_list_databases(context):
assert b"List of databases" in context.cmd_output
assert b"postgres" in context.cmd_output
context.cmd_output = None
@when("we run dbcli")
def step_run_cli(context):
wrappers.run_cli(context)
@when("we launch dbcli using {arg}")
def step_run_cli_using_arg(context, arg):
prompt_check = False
currentdb = None
if arg == "--username":
arg = "--username={}".format(context.conf["user"])
if arg == "--user":
arg = "--user={}".format(context.conf["user"])
if arg == "--port":
arg = "--port={}".format(context.conf["port"])
if arg == "--password":
arg = "--password"
prompt_check = False
# This uses the mock_pg_service.conf file in fixtures folder.
if arg == "dsn_password":
arg = "service=mock_postgres --password"
prompt_check = False
currentdb = "postgres"
wrappers.run_cli(
context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb
)
@when("we wait for prompt")
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@when('we send "ctrl + d"')
def step_ctrl_d(context):
"""
Send Ctrl + D to hopefully exit.
"""
# turn off pager before exiting
context.cli.sendline("\pset pager off")
wrappers.wait_prompt(context)
context.cli.sendcontrol("d")
context.cli.expect(pexpect.EOF, timeout=15)
context.exit_sent = True
@when('we send "\?" command')
def step_send_help(context):
"""
Send \? to see help.
"""
context.cli.sendline("\?")
@when("we send partial select command")
def step_send_partial_select_command(context):
"""
Send `SELECT a` to see completion.
"""
context.cli.sendline("SELECT a")
@then("we see error message")
def step_see_error_message(context):
wrappers.expect_exact(context, 'column "a" does not exist', timeout=2)
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
context.tmpfile_sql_help.write(b"\?")
context.tmpfile_sql_help.flush()
context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when("we run query to check application_name")
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;"
)
@then("we see found")
def step_see_found(context):
wrappers.expect_exact(
context,
context.conf["pager_boundary"]
+ "\r"
+ dedent(
"""
+------------+\r
| ?column? |\r
|------------|\r
| found |\r
+------------+\r
SELECT 1\r
"""
)
+ context.conf["pager_boundary"],
timeout=5,
)
@then("we confirm the destructive warning")
def step_confirm_destructive_command(context):
"""Confirm destructive command."""
wrappers.expect_exact(
context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):",
timeout=2,
)
context.cli.sendline("y")
@then("we send password")
def step_send_password(context):
wrappers.expect_exact(context, "Password for", timeout=5)
context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER")

@ -0,0 +1,93 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
import pexpect
from behave import when, then
import wrappers
@when("we create database")
def step_db_create(context):
"""
Send create database.
"""
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.response = {"database_name": context.conf["dbname_tmp"]}
@when("we drop database")
def step_db_drop(context):
"""
Send drop database.
"""
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
@when("we connect to test database")
def step_db_connect_test(context):
"""
Send connect to database.
"""
db_name = context.conf["dbname"]
context.cli.sendline("\\connect {0}".format(db_name))
@when("we connect to dbserver")
def step_db_connect_dbserver(context):
"""
Send connect to database.
"""
context.cli.sendline("\\connect postgres")
context.currentdb = "postgres"
@then("dbcli exits")
def step_wait_exit(context):
"""
Make sure the cli exits.
"""
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
@then("we see dbcli prompt")
def step_see_prompt(context):
"""
Wait to see the prompt.
"""
db_name = getattr(context, "currentdb", context.conf["dbname"])
wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
context.atprompt = True
@then("we see help output")
def step_see_help(context):
for expected_line in context.fixture_data["help_commands.txt"]:
wrappers.expect_exact(context, expected_line, timeout=2)
@then("we see database created")
def step_see_db_created(context):
"""
Wait to see create database output.
"""
wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5)
@then("we see database dropped")
def step_see_db_dropped(context):
"""
Wait to see drop database output.
"""
wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2)
@then("we see database connected")
def step_see_db_connected(context):
"""
Wait to see the connection confirmation message.
"""
wrappers.expect_exact(context, "You are now connected to database", timeout=2)

@ -0,0 +1,118 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
from behave import when, then
from textwrap import dedent
import wrappers
@when("we create table")
def step_create_table(context):
"""
Send create table.
"""
context.cli.sendline("create table a(x text);")
@when("we insert into table")
def step_insert_into_table(context):
"""
Send insert into table.
"""
context.cli.sendline("""insert into a(x) values('xxx');""")
@when("we update table")
def step_update_table(context):
"""
Send update table.
"""
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
@when("we select from table")
def step_select_from_table(context):
"""
Send select from table.
"""
context.cli.sendline("select * from a;")
@when("we delete from table")
def step_delete_from_table(context):
"""
Send delete from table.
"""
context.cli.sendline("""delete from a where x = 'yyy';""")
@when("we drop table")
def step_drop_table(context):
"""
Send drop table.
"""
context.cli.sendline("drop table a;")
@then("we see table created")
def step_see_table_created(context):
"""
Wait to see create table output.
"""
wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
@then("we see record inserted")
def step_see_record_inserted(context):
"""
Wait to see insert output.
"""
wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
@then("we see record updated")
def step_see_record_updated(context):
"""
Wait to see update output.
"""
wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2)
@then("we see data selected")
def step_see_data_selected(context):
"""
Wait to see select output.
"""
wrappers.expect_pager(
context,
dedent(
"""\
+-----+\r
| x |\r
|-----|\r
| yyy |\r
+-----+\r
SELECT 1\r
"""
),
timeout=1,
)
@then("we see record deleted")
def step_see_data_deleted(context):
"""
Wait to see delete output.
"""
wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2)
@then("we see table dropped")
def step_see_table_dropped(context):
"""
Wait to see drop output.
"""
wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2)

@ -0,0 +1,70 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
from behave import when, then
from textwrap import dedent
import wrappers
@when("we prepare the test data")
def step_prepare_data(context):
"""Create table, insert a record."""
context.cli.sendline("drop table if exists a;")
wrappers.expect_exact(
context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):",
timeout=2,
)
context.cli.sendline("y")
wrappers.wait_prompt(context)
context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));")
wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""")
wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
@when("we set expanded {mode}")
def step_set_expanded(context, mode):
"""Set expanded to mode."""
context.cli.sendline("\\" + "x {}".format(mode))
wrappers.expect_exact(context, "Expanded display is", timeout=2)
wrappers.wait_prompt(context)
@then("we see {which} data selected")
def step_see_data(context, which):
"""Select data from expanded test table."""
if which == "expanded":
wrappers.expect_pager(
context,
dedent(
"""\
-[ RECORD 1 ]-------------------------\r
x | 1\r
y | 1.0\r
z | 1.0000\r
SELECT 1\r
"""
),
timeout=1,
)
else:
wrappers.expect_pager(
context,
dedent(
"""\
+-----+-----+--------+\r
| x | y | z |\r
|-----+-----+--------|\r
| 1 | 1.0 | 1.0000 |\r
+-----+-----+--------+\r
SELECT 1\r
"""
),
timeout=1,
)

@ -0,0 +1,80 @@
import os
import os.path
from behave import when, then
import wrappers
@when("we start external editor providing a file name")
def step_edit_file(context):
"""Edit file with external editor."""
context.editor_file_name = os.path.join(
context.package_root, "test_file_{0}.sql".format(context.conf["vi"])
)
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
)
wrappers.expect_exact(context, ":", timeout=2)
@when("we type sql in the editor")
def step_edit_type_sql(context):
context.cli.sendline("i")
context.cli.sendline("select * from abc")
context.cli.sendline(".")
wrappers.expect_exact(context, ":", timeout=2)
@when("we exit the editor")
def step_edit_quit(context):
context.cli.sendline("x")
wrappers.expect_exact(context, "written", timeout=2)
@then("we see the sql in prompt")
def step_edit_done_sql(context):
for match in "select * from abc".split(" "):
wrappers.expect_exact(context, match, timeout=1)
# Cleanup the command line.
context.cli.sendcontrol("c")
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.atprompt = True
@when("we tee output")
def step_tee_output(context):
context.tee_file_name = os.path.join(
context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])
)
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Writing to file", timeout=5)
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Time", timeout=5)
@when('we query "select 123456"')
def step_query_select_123456(context):
context.cli.sendline("select 123456")
@when("we stop teeing output")
def step_notee_output(context):
context.cli.sendline("\o")
wrappers.expect_exact(context, "Time", timeout=5)
@then("we see 123456 in tee output")
def step_see_123456_in_output(context):
with open(context.tee_file_name) as f:
assert "123456" in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.atprompt = True

@ -0,0 +1,57 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
from behave import when, then
import wrappers
@when("we save a named query")
def step_save_named_query(context):
"""
Send \ns command
"""
context.cli.sendline("\\ns foo SELECT 12345")
@when("we use a named query")
def step_use_named_query(context):
"""
Send \n command
"""
context.cli.sendline("\\n foo")
@when("we delete a named query")
def step_delete_named_query(context):
"""
Send \nd command
"""
context.cli.sendline("\\nd foo")
@then("we see the named query saved")
def step_see_named_query_saved(context):
"""
Wait to see query saved.
"""
wrappers.expect_exact(context, "Saved.", timeout=2)
@then("we see the named query executed")
def step_see_named_query_executed(context):
"""
Wait to see select output.
"""
wrappers.expect_exact(context, "12345", timeout=1)
wrappers.expect_exact(context, "SELECT 1", timeout=1)
@then("we see the named query deleted")
def step_see_named_query_deleted(context):
"""
Wait to see query deleted.
"""
wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1)

@ -0,0 +1,26 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
from behave import when, then
import wrappers
@when("we refresh completions")
def step_refresh_completions(context):
"""
Send refresh command.
"""
context.cli.sendline("\\refresh")
@then("we see completions refresh started")
def step_see_refresh_started(context):
"""
Wait to see refresh output.
"""
wrappers.expect_pager(
context, "Auto-completion refresh started in the background.\r\n", timeout=2
)

@ -0,0 +1,67 @@
import re
import pexpect
from pgcli.main import COLOR_CODE_REGEX
import textwrap
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def expect_exact(context, expected, timeout):
timedout = False
try:
context.cli.expect_exact(expected, timeout=timeout)
except pexpect.TIMEOUT:
timedout = True
if timedout:
# Strip color codes out of the output.
actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
raise Exception(
textwrap.dedent(
"""\
Expected:
---
{0!r}
---
Actual:
---
{1!r}
---
Full log:
---
{2!r}
---
"""
).format(expected, actual, context.logfile.getvalue())
)
def expect_pager(context, expected, timeout):
expect_exact(
context,
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
timeout=timeout,
)
def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
"""Run the process using pexpect."""
run_args = run_args or []
cli_cmd = context.conf.get("cli_command")
cmd_parts = [cli_cmd] + run_args
cmd = " ".join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = currentdb or context.conf["dbname"]
context.cli.sendline("\pset pager always")
if prompt_check:
wait_prompt(context)
def wait_prompt(context):
"""Make sure prompt is displayed."""
expect_exact(context, "{0}> ".format(context.conf["dbname"]), timeout=5)

tests/features/wrappager.py Executable file (+16)

@ -0,0 +1,16 @@
#!/usr/bin/env python
import sys
def wrappager(boundary):
print(boundary)
while 1:
buf = sys.stdin.read(2048)
if not buf:
break
sys.stdout.write(buf)
print(boundary)
if __name__ == "__main__":
wrappager(sys.argv[1])
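For illustration (not part of the diff): environment.py composes the PAGER value from the Python interpreter, this script, and a boundary marker, so all pager output seen by the feature tests is wrapped in boundary lines. Driving the script directly shows the same behaviour; the relative path is assumed from this changeset's layout and the boundary string matches the one in environment.py.

import subprocess
import sys

out = subprocess.run(
    [sys.executable, "tests/features/wrappager.py", "---boundary---"],
    input="hello\n",
    capture_output=True,
    text=True,
).stdout
assert out == "---boundary---\nhello\n---boundary---\n"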

tests/metadata.py Normal file (+255)

@ -0,0 +1,255 @@
from functools import partial
from itertools import product
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from mock import Mock
import pytest
parametrize = pytest.mark.parametrize
qual = ["if_more_than_one_table", "always"]
no_qual = ["if_more_than_one_table", "never"]
def escape(name):
if not name.islower() or name in ("select", "localtimestamp"):
return '"' + name + '"'
return name
def completion(display_meta, text, pos=0):
return Completion(text, start_position=pos, display_meta=display_meta)
def function(text, pos=0, display=None):
return Completion(
text, display=display or text, start_position=pos, display_meta="function"
)
def get_result(completer, text, position=None):
position = len(text) if position is None else position
return completer.get_completions(
Document(text=text, cursor_position=position), Mock()
)
def result_set(completer, text, position=None):
return set(get_result(completer, text, position))
# The code below is equivalent to
# def schema(text, pos=0):
# return completion('schema', text, pos)
# and so on
schema = partial(completion, "schema")
table = partial(completion, "table")
view = partial(completion, "view")
column = partial(completion, "column")
keyword = partial(completion, "keyword")
datatype = partial(completion, "datatype")
alias = partial(completion, "table alias")
name_join = partial(completion, "name join")
fk_join = partial(completion, "fk join")
join = partial(completion, "join")
def wildcard_expansion(cols, pos=-1):
return Completion(cols, start_position=pos, display_meta="columns", display="*")
class MetaData(object):
def __init__(self, metadata):
self.metadata = metadata
def builtin_functions(self, pos=0):
return [function(f, pos) for f in self.completer.functions]
def builtin_datatypes(self, pos=0):
return [datatype(dt, pos) for dt in self.completer.datatypes]
def keywords(self, pos=0):
return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]
def specials(self, pos=0):
return [
Completion(text=k, start_position=pos, display_meta=v.description)
for k, v in self.completer.pgspecial.commands.items()
]
def columns(self, tbl, parent="public", typ="tables", pos=0):
if typ == "functions":
fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
cols = fun[1]
else:
cols = self.metadata[typ][parent][tbl]
return [column(escape(col), pos) for col in cols]
def datatypes(self, parent="public", pos=0):
return [
datatype(escape(x), pos)
for x in self.metadata.get("datatypes", {}).get(parent, [])
]
def tables(self, parent="public", pos=0):
return [
table(escape(x), pos)
for x in self.metadata.get("tables", {}).get(parent, [])
]
def views(self, parent="public", pos=0):
return [
view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])
]
def functions(self, parent="public", pos=0):
return [
function(
escape(x[0])
+ "("
+ ", ".join(
arg_name + " := "
for (arg_name, arg_mode) in zip(x[1], x[3])
if arg_mode in ("b", "i")
)
+ ")",
pos,
escape(x[0])
+ "("
+ ", ".join(
arg_name
for (arg_name, arg_mode) in zip(x[1], x[3])
if arg_mode in ("b", "i")
)
+ ")",
)
for x in self.metadata.get("functions", {}).get(parent, [])
]
def schemas(self, pos=0):
schemas = set(sch for schs in self.metadata.values() for sch in schs)
return [schema(escape(s), pos=pos) for s in schemas]
def functions_and_keywords(self, parent="public", pos=0):
return (
self.functions(parent, pos)
+ self.builtin_functions(pos)
+ self.keywords(pos)
)
# Note that the filtering parameters here only apply to the columns
def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0):
return self.functions_and_keywords(pos=pos) + self.columns(
tbl, parent, typ, pos
)
def from_clause_items(self, parent="public", pos=0):
return (
self.functions(parent, pos)
+ self.views(parent, pos)
+ self.tables(parent, pos)
)
def schemas_and_from_clause_items(self, parent="public", pos=0):
return self.from_clause_items(parent, pos) + self.schemas(pos)
def types(self, parent="public", pos=0):
return self.datatypes(parent, pos) + self.tables(parent, pos)
@property
def completer(self):
return self.get_completer()
def get_completers(self, casing):
"""
Returns a function that takes three bools `casing`, `filtr`, `aliasing` and
the list `qualify`, all defaulting to None, and that in turn returns a list
of completers.
These parameters specify the allowed values for the corresponding
completer parameters, `None` meaning any, i.e. (None, None, None, None)
results in all 24 possible completers, whereas e.g.
(True, False, True, ['never']) results in the one completer with
casing, without `search_path` filtering of objects, with table
aliasing, and without column qualification.
"""
def _cfg(_casing, filtr, aliasing, qualify):
cfg = {"settings": {}}
if _casing:
cfg["casing"] = casing
cfg["settings"]["search_path_filter"] = filtr
cfg["settings"]["generate_aliases"] = aliasing
cfg["settings"]["qualify_columns"] = qualify
return cfg
def _cfgs(casing, filtr, aliasing, qualify):
casings = [True, False] if casing is None else [casing]
filtrs = [True, False] if filtr is None else [filtr]
aliases = [True, False] if aliasing is None else [aliasing]
qualifys = qualify or ["always", "if_more_than_one_table", "never"]
return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)]
def completers(casing=None, filtr=None, aliasing=None, qualify=None):
get_comp = self.get_completer
return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)]
return completers
def _make_col(self, sch, tbl, col):
defaults = self.metadata.get("defaults", {}).get(sch, {})
return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col)))
def get_completer(self, settings=None, casing=None):
metadata = self.metadata
from pgcli.pgcompleter import PGCompleter
from pgspecial import PGSpecial
comp = PGCompleter(
smart_completion=True, settings=settings, pgspecial=PGSpecial()
)
schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []
for sch, tbls in metadata["tables"].items():
schemata.append(sch)
for tbl, cols in tbls.items():
tables.append((sch, tbl))
# Let all columns be text columns
tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols])
for sch, tbls in metadata.get("views", {}).items():
for tbl, cols in tbls.items():
views.append((sch, tbl))
# Let all columns be text columns
view_cols.extend([self._make_col(sch, tbl, col) for col in cols])
functions = [
FunctionMetadata(sch, *func_meta, arg_defaults=None)
for sch, funcs in metadata["functions"].items()
for func_meta in funcs
]
datatypes = [
(sch, typ)
for sch, datatypes in metadata["datatypes"].items()
for typ in datatypes
]
foreignkeys = [
ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks
]
comp.extend_schemata(schemata)
comp.extend_relations(tables, kind="tables")
comp.extend_relations(views, kind="views")
comp.extend_columns(tbl_cols, kind="tables")
comp.extend_columns(view_cols, kind="views")
comp.extend_functions(functions)
comp.extend_datatypes(datatypes)
comp.extend_foreignkeys(foreignkeys)
comp.set_search_path(["public"])
comp.extend_casing(casing or [])
return comp
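To make the intent of this helper concrete (a sketch, not part of the diff; the schema, table and column names are made up): the completer test modules build a MetaData around a nested metadata dict, obtain a completer from it, and assert on the completions produced for a partial statement.

metadata = {
    "tables": {"public": {"users": ["id", "email"]}},
    "views": {},
    "functions": {"public": []},
    "datatypes": {"public": []},
    "foreignkeys": {},
}
testdata = MetaData(metadata)
completer = testdata.get_completer()

# With the cursor right after "SELECT ", the columns of the table named in
# the FROM clause should be among the offered completions.
texts = {c.text for c in get_result(completer, "SELECT  FROM users", len("SELECT "))}
assert {"id", "email"} <= texts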

@ -0,0 +1,137 @@
import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
token_start_pos,
extract_ctes,
extract_column_names as _extract_column_names,
)
def extract_column_names(sql):
p = parse(sql)[0]
return _extract_column_names(p)
def test_token_str_pos():
sql = "SELECT * FROM xxx"
p = parse(sql)[0]
idx = p.token_index(p.tokens[-1])
assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ")
sql = "SELECT * FROM \nxxx"
p = parse(sql)[0]
idx = p.token_index(p.tokens[-1])
assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n")
def test_single_column_name_extraction():
sql = "SELECT abc FROM xxx"
assert extract_column_names(sql) == ("abc",)
def test_aliased_single_column_name_extraction():
sql = "SELECT abc def FROM xxx"
assert extract_column_names(sql) == ("def",)
def test_aliased_expression_name_extraction():
sql = "SELECT 99 abc FROM xxx"
assert extract_column_names(sql) == ("abc",)
def test_multiple_column_name_extraction():
sql = "SELECT abc, def FROM xxx"
assert extract_column_names(sql) == ("abc", "def")
def test_missing_column_name_handled_gracefully():
sql = "SELECT abc, 99 FROM xxx"
assert extract_column_names(sql) == ("abc",)
sql = "SELECT abc, 99, def FROM xxx"
assert extract_column_names(sql) == ("abc", "def")
def test_aliased_multiple_column_name_extraction():
sql = "SELECT abc def, ghi jkl FROM xxx"
assert extract_column_names(sql) == ("def", "jkl")
def test_table_qualified_column_name_extraction():
sql = "SELECT abc.def, ghi.jkl FROM xxx"
assert extract_column_names(sql) == ("def", "jkl")
@pytest.mark.parametrize(
"sql",
[
"INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
"DELETE FROM foo WHERE x > y RETURNING x, y",
"UPDATE foo SET x = 9 RETURNING x, y",
],
)
def test_extract_column_names_from_returning_clause(sql):
assert extract_column_names(sql) == ("x", "y")
def test_simple_cte_extraction():
sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
start_pos = len("WITH a AS ")
stop_pos = len("WITH a AS (SELECT abc FROM xxx)")
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),)
assert remainder.strip() == "SELECT * FROM a"
def test_cte_extraction_around_comments():
sql = """--blah blah blah
WITH a AS (SELECT abc def FROM x)
SELECT * FROM a"""
start_pos = len(
"""--blah blah blah
WITH a AS """
)
stop_pos = len(
"""--blah blah blah
WITH a AS (SELECT abc def FROM x)"""
)
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),)
assert remainder.strip() == "SELECT * FROM a"
def test_multiple_cte_extraction():
sql = """WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)
SELECT * FROM a, b"""
start1 = len(
"""WITH
x AS """
)
stop1 = len(
"""WITH
x AS (SELECT abc, def FROM x)"""
)
start2 = len(
"""WITH
x AS (SELECT abc, def FROM x),
y AS """
)
stop2 = len(
"""WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)"""
)
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (
("x", ("abc", "def"), start1, stop1),
("y", ("ghi", "jkl"), start2, stop2),
)

@ -0,0 +1,19 @@
from pgcli.packages.parseutils.meta import FunctionMetadata
def test_function_metadata_eq():
f1 = FunctionMetadata(
"s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
)
f2 = FunctionMetadata(
"s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
)
f3 = FunctionMetadata(
"s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None
)
assert f1 == f2
assert f1 != f3
assert not (f1 != f2)
assert not (f1 == f3)
assert hash(f1) == hash(f2)
assert hash(f1) != hash(f3)

@ -0,0 +1,269 @@
import pytest
from pgcli.packages.parseutils.tables import extract_tables
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
def test_empty_string():
tables = extract_tables("")
assert tables == ()
def test_simple_select_single_table():
tables = extract_tables("select * from abc")
assert tables == ((None, "abc", None, False),)
@pytest.mark.parametrize(
"sql", ['select * from "abc"."def"', 'select * from abc."def"']
)
def test_simple_select_single_table_schema_qualified_quoted_table(sql):
tables = extract_tables(sql)
assert tables == (("abc", "def", '"def"', False),)
@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def'])
def test_simple_select_single_table_schema_qualified(sql):
tables = extract_tables(sql)
assert tables == (("abc", "def", None, False),)
def test_simple_select_single_table_double_quoted():
tables = extract_tables('select * from "Abc"')
assert tables == ((None, "Abc", None, False),)
def test_simple_select_multiple_tables():
tables = extract_tables("select * from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_simple_select_multiple_tables_double_quoted():
tables = extract_tables('select * from "Abc", "Def"')
assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
def test_simple_select_single_table_double_quoted_aliased():
tables = extract_tables('select * from "Abc" a')
assert tables == ((None, "Abc", "a", False),)
def test_simple_select_multiple_tables_double_quoted_aliased():
tables = extract_tables('select * from "Abc" a, "Def" d')
assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables("select * from abc.def, ghi.jkl")
assert set(tables) == set(
[("abc", "def", None, False), ("ghi", "jkl", None, False)]
)
def test_simple_select_with_cols_single_table():
tables = extract_tables("select a,b from abc")
assert tables == ((None, "abc", None, False),)
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables("select a,b from abc.def")
assert tables == (("abc", "def", None, False),)
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables("select a,b from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_simple_select_with_cols_multiple_qualified_tables():
tables = extract_tables("select a,b from abc.def, def.ghi")
assert set(tables) == set(
[("abc", "def", None, False), ("def", "ghi", None, False)]
)
def test_select_with_hanging_comma_single_table():
tables = extract_tables("select a, from abc")
assert tables == ((None, "abc", None, False),)
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables("select a, from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert set(tables) == set(
[(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
)
def test_simple_insert_single_table():
tables = extract_tables('insert into abc (id, name) values (1, "def")')
# sqlparse mistakenly assigns an alias to the table
# AND mistakenly identifies the field list as
# assert tables == ((None, 'abc', 'abc', False),)
assert tables == ((None, "abc", "abc", False),)
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == (("abc", "def", None, False),)
def test_simple_update_table_no_schema():
tables = extract_tables("update abc set id = 1")
assert tables == ((None, "abc", None, False),)
def test_simple_update_table_with_schema():
tables = extract_tables("update abc.def set id = 1")
assert tables == (("abc", "def", None, False),)
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
def test_join_table(join_type):
sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
tables = extract_tables(sql)
assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
def test_join_table_schema_qualified():
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
def test_incomplete_join_clause():
sql = """select a.x, b.y
from abc a join bcd b
on a.id = """
tables = extract_tables(sql)
assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False))
def test_join_as_table():
tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
assert tables == ((None, "my_table", "m", False),)
def test_multiple_joins():
sql = """select * from t1
inner join t2 ON
t1.id = t2.t1_id
inner join t3 ON
t2.id = t3."""
tables = extract_tables(sql)
assert tables == (
(None, "t1", None, False),
(None, "t2", None, False),
(None, "t3", None, False),
)
def test_subselect_tables():
sql = "SELECT * FROM (SELECT FROM abc"
tables = extract_tables(sql)
assert tables == ((None, "abc", None, False),)
@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"])
def test_extract_no_tables(text):
tables = extract_tables(text)
assert tables == tuple()
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
assert tables == ((None, "foo", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_schema_qualified_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
assert tables == (("foo", "bar", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_aliased_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
assert tables == ((None, "foo", "bar", True),)
def test_simple_table_and_function():
tables = extract_tables("SELECT * FROM foo JOIN bar()")
assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
def test_complex_table_and_function():
tables = extract_tables(
"""SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux"""
)
assert set(tables) == set(
[("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
)
def test_find_prev_keyword_using():
q = "select * from tbl1 inner join tbl2 using (col1, "
kw, q2 = find_prev_keyword(q)
assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using ("
@pytest.mark.parametrize(
"sql",
[
"select * from foo where bar",
"select * from foo where bar = 1 and baz or ",
"select * from foo where bar = 1 and baz between qux and ",
],
)
def test_find_prev_keyword_where(sql):
kw, stripped = find_prev_keyword(sql)
assert kw.value == "where" and stripped == "select * from foo where"
@pytest.mark.parametrize(
"sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "]
)
def test_find_prev_keyword_open_parens(sql):
kw, _ = find_prev_keyword(sql)
assert kw.value == "("
@pytest.mark.parametrize(
"sql",
[
"",
"$$ foo $$",
"$$ 'foo' $$",
'$$ "foo" $$',
"$$ $a$ $$",
"$a$ $$ $a$",
"foo bar $$ baz $$",
],
)
def test_is_open_quote__closed(sql):
assert not is_open_quote(sql)
@pytest.mark.parametrize(
"sql",
[
"$$",
";;;$$",
"foo $$ bar $$; foo $$",
"$$ foo $a$",
"foo 'bar baz",
"$a$ foo ",
'$$ "foo" ',
"$$ $a$ ",
"foo bar $$ baz",
],
)
def test_is_open_quote__open(sql):
assert is_open_quote(sql)

tests/pytest.ini Normal file (+2)

@ -0,0 +1,2 @@
[pytest]
addopts=--capture=sys --showlocals

@ -0,0 +1,97 @@
import time
import pytest
from mock import Mock, patch
@pytest.fixture
def refresher():
from pgcli.completion_refresher import CompletionRefresher
return CompletionRefresher()
def test_ctor(refresher):
"""
Refresher object should contain a few handlers
:param refresher:
:return:
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = [
"schemata",
"tables",
"views",
"types",
"databases",
"casing",
"functions",
]
assert expected_handlers == actual_handlers
def test_refresh_called_once(refresher):
"""
:param refresher:
:return:
"""
callbacks = Mock()
pgexecute = Mock()
special = Mock()
with patch.object(refresher, "_bg_refresh") as bg_refresh:
actual = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
assert actual[0][3] == "Auto-completion refresh started in the background."
bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None)
def test_refresh_called_twice(refresher):
"""
If refresh is called a second time, it should be restarted
:param refresher:
:return:
"""
callbacks = Mock()
pgexecute = Mock()
special = Mock()
def dummy_bg_refresh(*args):
time.sleep(3) # seconds
refresher._bg_refresh = dummy_bg_refresh
actual1 = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
assert actual1[0][3] == "Auto-completion refresh started in the background."
actual2 = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
assert actual2[0][3] == "Auto-completion refresh restarted."
def test_refresh_with_callbacks(refresher):
"""
Callbacks must be called
:param refresher:
"""
callbacks = [Mock()]
pgexecute_class = Mock()
pgexecute = Mock()
pgexecute.extra_args = {}
special = Mock()
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert callbacks[0].call_count == 1

tests/test_config.py Normal file (+30)

@ -0,0 +1,30 @@
import os
import stat
import pytest
from pgcli.config import ensure_dir_exists
def test_ensure_file_parent(tmpdir):
subdir = tmpdir.join("subdir")
rcfile = subdir.join("rcfile")
ensure_dir_exists(str(rcfile))
def test_ensure_existing_dir(tmpdir):
rcfile = str(tmpdir.mkdir("subdir").join("rcfile"))
# should just not raise
ensure_dir_exists(rcfile)
def test_ensure_other_create_error(tmpdir):
subdir = tmpdir.join("subdir")
rcfile = subdir.join("rcfile")
# trigger an oserror that isn't "directory already exists"
os.chmod(str(tmpdir), stat.S_IREAD)
with pytest.raises(OSError):
ensure_dir_exists(str(rcfile))

@ -0,0 +1,87 @@
import pytest
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
return pgcompleter.PGCompleter()
def test_ranking_ignores_identifier_quotes(completer):
"""When calculating result rank, identifier quotes should be ignored.
The result ranking algorithm ignores identifier quotes. Without this
correction, the match "user", which Postgres requires to be quoted
since it is also a reserved word, would incorrectly fall below the
match user_action because the literal quotation marks in "user"
alter the position of the match.
This test checks that the fuzzy ranking algorithm correctly ignores
quotation marks when computing match ranks.
"""
text = "user"
collection = ["user_action", '"user"']
matches = completer.find_matches(text, collection)
assert len(matches) == 2
def test_ranking_based_on_shortest_match(completer):
"""Fuzzy result rank should be based on shortest match.
Result ranking in fuzzy searching is partially based on the length
of matches: shorter matches are considered more relevant than
longer ones. When searching for the text 'user', the length
component of the match 'user_group' could be either 4 ('user') or
7 ('user_gr').
This test checks that the fuzzy ranking algorithm uses the shorter
match when calculating result rank.
"""
text = "user"
collection = ["api_user", "user_group"]
matches = completer.find_matches(text, collection)
assert matches[1].priority > matches[0].priority
@pytest.mark.parametrize(
"collection",
[["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]],
)
def test_should_break_ties_using_lexical_order(completer, collection):
"""Fuzzy result rank should use lexical order to break ties.
When fuzzy matching, if multiple matches have the same match length and
start position, present them in lexical (rather than arbitrary) order. For
example, if we have tables 'user', 'user_action', and 'user_group', a
search for the text 'user' should present these tables in this order.
The input collections to this test are out of order; each run checks that
the search text 'user' results in the input tables being reordered
lexically.
"""
text = "user"
matches = completer.find_matches(text, collection)
assert matches[1].priority > matches[0].priority
def test_matching_should_be_case_insensitive(completer):
"""Fuzzy matching should keep matches even if letter casing doesn't match.
This test checks that variations of the text which have different casing
are still matched.
"""
text = "foo"
collection = ["Foo", "FOO", "fOO"]
matches = completer.find_matches(text, collection)
assert len(matches) == 3
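Taken together, the docstrings above describe the ranking rules (identifier quotes ignored, shortest match wins, lexical tie-breaks, case-insensitive matching). A minimal interactive sketch of the same behaviour, using only the find_matches() call and the priority attribute these tests exercise; the candidate list is made up for illustration, and the printed order is what the docstrings lead one to expect rather than something asserted here:
from pgcli.pgcompleter import PGCompleter
# Sketch: rank a handful of candidates the way the tests above do.
completer = PGCompleter()
candidates = ['"user"', "user_action", "user_group", "api_user"]
matches = completer.find_matches("user", candidates)
# Higher priority means a better match; per the docstrings, '"user"' is not
# penalised for its quotes, and shorter matches rank above longer ones.
for m in sorted(matches, key=lambda m: m.priority, reverse=True):
    print(m.priority, m)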

383
tests/test_main.py Normal file
View file

@ -0,0 +1,383 @@
import os
import platform
import mock
import pytest
try:
import setproctitle
except ImportError:
setproctitle = None
from pgcli.main import (
obfuscate_process_password,
format_output,
PGCli,
OutputSettings,
COLOR_CODE_REGEX,
)
from pgcli.pgexecute import PGExecute
from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS
from utils import dbtest, run
from collections import namedtuple
@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows")
@pytest.mark.skipif(not setproctitle, reason="setproctitle not available")
def test_obfuscate_process_password():
original_title = setproctitle.getproctitle()
setproctitle.setproctitle("pgcli user=root password=secret host=localhost")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli user=root password=xxxx host=localhost"
assert title == expected
setproctitle.setproctitle("pgcli user=root password=top secret host=localhost")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli user=root password=xxxx host=localhost"
assert title == expected
setproctitle.setproctitle("pgcli user=root password=top secret")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli user=root password=xxxx"
assert title == expected
setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli postgres://root:xxxx@localhost/db"
assert title == expected
setproctitle.setproctitle(original_title)
def test_format_output():
settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g")
results = format_output(
"Title", [("abc", "def")], ["head1", "head2"], "test status", settings
)
expected = [
"Title",
"+---------+---------+",
"| head1 | head2 |",
"|---------+---------|",
"| abc | def |",
"+---------+---------+",
"test status",
]
assert list(results) == expected
@dbtest
def test_format_array_output(executor):
statement = """
SELECT
array[1, 2, 3]::bigint[] as bigint_array,
'{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
'{å,魚,текст}'::text[] as 配列
UNION ALL
SELECT '{}', NULL, array[NULL]
"""
results = run(executor, statement)
expected = [
"+----------------+------------------------+--------------+",
"| bigint_array | nested_numeric_array | 配列 |",
"|----------------+------------------------+--------------|",
"| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |",
"| {} | <null> | {<null>} |",
"+----------------+------------------------+--------------+",
"SELECT 2",
]
assert list(results) == expected
@dbtest
def test_format_array_output_expanded(executor):
statement = """
SELECT
array[1, 2, 3]::bigint[] as bigint_array,
'{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
'{å,魚,текст}'::text[] as 配列
UNION ALL
SELECT '{}', NULL, array[NULL]
"""
results = run(executor, statement, expanded=True)
expected = [
"-[ RECORD 1 ]-------------------------",
"bigint_array | {1,2,3}",
"nested_numeric_array | {{1,2},{3,4}}",
"配列 | {å,魚,текст}",
"-[ RECORD 2 ]-------------------------",
"bigint_array | {}",
"nested_numeric_array | <null>",
"配列 | {<null>}",
"SELECT 2",
]
assert "\n".join(results) == "\n".join(expected)
def test_format_output_auto_expand():
settings = OutputSettings(
table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100
)
table_results = format_output(
"Title", [("abc", "def")], ["head1", "head2"], "test status", settings
)
table = [
"Title",
"+---------+---------+",
"| head1 | head2 |",
"|---------+---------|",
"| abc | def |",
"+---------+---------+",
"test status",
]
assert list(table_results) == table
expanded_results = format_output(
"Title",
[("abc", "def")],
["head1", "head2"],
"test status",
settings._replace(max_width=1),
)
expanded = [
"Title",
"-[ RECORD 1 ]-------------------------",
"head1 | abc",
"head2 | def",
"test status",
]
assert "\n".join(expanded_results) == "\n".join(expanded)
termsize = namedtuple("termsize", ["rows", "columns"])
test_line = "-" * 10
test_data = [
(10, 10, "\n".join([test_line] * 7)),
(10, 10, "\n".join([test_line] * 6)),
(10, 10, "\n".join([test_line] * 5)),
(10, 10, "-" * 11),
(10, 10, "-" * 10),
(10, 10, "-" * 9),
]
# 4 lines are reserved at the bottom of the terminal for pgcli's prompt
use_pager_when_on = [True, True, False, True, False, False]
# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL
test_ids = [
"Output longer than terminal height",
"Output equal to terminal height",
"Output shorter than terminal height",
"Output longer than terminal width",
"Output equal to terminal width",
"Output shorter than terminal width",
]
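# A sanity-check sketch of the expectations above (not part of pgcli): assuming
# that "4 reserved lines" means output gets paged when its line count plus the
# 4 reserved rows reaches the terminal height, or when any single line is wider
# than the terminal, the hypothetical helper below reproduces use_pager_when_on.
def _expected_to_use_pager(term_height, term_width, text):
    # Hypothetical helper for this sketch only; not a pgcli API.
    lines = text.split("\n")
    too_tall = len(lines) + 4 >= term_height
    too_wide = any(len(line) > term_width for line in lines)
    return too_tall or too_wide
assert [
    _expected_to_use_pager(h, w, t) for h, w, t in test_data
] == use_pager_when_on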
@pytest.fixture
def pset_pager_mocks():
cli = PGCli()
cli.watch_command = None
with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
"pgcli.main.click.echo_via_pager"
) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
yield cli, mock_echo, mock_echo_via_pager, mock_app
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF):
cli.echo_via_pager(text)
mock_echo.assert_called()
mock_echo_via_pager.assert_not_called()
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS):
cli.echo_via_pager(text)
mock_echo.assert_not_called()
mock_echo_via_pager.assert_called()
pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)]
@pytest.mark.parametrize(
"term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids
)
def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT):
cli.echo_via_pager(text)
if use_pager:
mock_echo.assert_not_called()
mock_echo_via_pager.assert_called()
else:
mock_echo_via_pager.assert_not_called()
mock_echo.assert_called()
@pytest.mark.parametrize(
"text,expected_length",
[
(
"22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s",
78,
),
("=\u001b[m=", 2),
("-\u001b]23\u0007-", 2),
],
)
def test_color_pattern(text, expected_length, pset_pager_mocks):
cli = pset_pager_mocks[0]
assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length
@dbtest
def test_i_works(tmpdir, executor):
sqlfile = tmpdir.join("test.sql")
sqlfile.write("SELECT NOW()")
rcfile = str(tmpdir.join("rcfile"))
cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
statement = r"\i {0}".format(sqlfile)
run(executor, statement, pgspecial=cli.pgspecial)
def test_missing_rc_dir(tmpdir):
rcfile = str(tmpdir.join("subdir").join("rcfile"))
PGCli(pgclirc_file=rcfile)
assert os.path.exists(rcfile)
def test_quoted_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B")
mock_connect.assert_called_with(
database="testdb[", host="baz.com", user="bar^", passwd="]foo"
)
def test_pg_service_file(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
service_conf.write(
"""[myservice]
host=a_host
user=a_user
port=5433
password=much_secure
dbname=a_dbname
[my_other_service]
host=b_host
user=b_user
port=5435
dbname=b_dbname
"""
)
os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath
cli.connect_service("myservice", "another_user")
mock_connect.assert_called_with(
database="a_dbname",
host="a_host",
user="another_user",
port="5433",
passwd="much_secure",
)
with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
mock_pgexecute.return_value = None
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
os.environ["PGPASSWORD"] = "very_secure"
cli.connect_service("my_other_service", None)
mock_pgexecute.assert_called_with(
"b_dbname",
"b_user",
"very_secure",
"b_host",
"5435",
"",
application_name="pgcli",
)
del os.environ["PGPASSWORD"]
del os.environ["PGSERVICEFILE"]
def test_ssl_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
"postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?"
"sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
)
mock_connect.assert_called_with(
database="testdb[",
host="baz.com",
user="bar^",
passwd="]foo",
sslmode="verify-full",
sslcert="my.pem",
sslkey="my-key.pem",
sslrootcert="ca.pem",
)
def test_port_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb")
mock_connect.assert_called_with(
database="testdb", host="baz.com", user="bar", passwd="foo", port="2543"
)
def test_multihost_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
"postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb"
)
mock_connect.assert_called_with(
database="testdb",
host="baz1.com,baz2.com,baz3.com",
user="bar",
passwd="foo",
port="2543,2543,2543",
)
def test_application_name_db_uri(tmpdir):
with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
mock_pgexecute.return_value = None
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
mock_pgexecute.assert_called_with(
"bar", "bar", "", "baz.com", "", "", application_name="cow"
)

View file

@ -0,0 +1,133 @@
import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from utils import completions_to_set
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
return pgcompleter.PGCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_empty_string_completion(completer, complete_event):
text = ""
position = 0
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(map(Completion, completer.all_completions))
def test_select_keyword_completion(completer, complete_event):
text = "SEL"
position = len("SEL")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set([Completion(text="SELECT", start_position=-3)])
def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(
[
Completion(text="MATERIALIZED VIEW", start_position=-2),
Completion(text="MAX", start_position=-2),
Completion(text="MAXEXTENTS", start_position=-2),
Completion(text="MAKE_DATE", start_position=-2),
Completion(text="MAKE_TIME", start_position=-2),
Completion(text="MAKE_TIMESTAMPTZ", start_position=-2),
Completion(text="MAKE_INTERVAL", start_position=-2),
Completion(text="MASKLEN", start_position=-2),
Completion(text="MAKE_TIMESTAMP", start_position=-2),
]
)
def test_column_name_completion(completer, complete_event):
text = "SELECT FROM users"
position = len("SELECT ")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(map(Completion, completer.all_completions))
def test_alter_well_known_keywords_completion(completer, complete_event):
text = "ALTER "
position = len(text)
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position),
complete_event,
smart_completion=True,
)
)
assert result > completions_to_set(
[
Completion(text="DATABASE", display_meta="keyword"),
Completion(text="TABLE", display_meta="keyword"),
Completion(text="SYSTEM", display_meta="keyword"),
]
)
assert (
completions_to_set([Completion(text="CREATE", display_meta="keyword")])
not in result
)
def test_special_name_completion(completer, complete_event):
text = "\\"
position = len("\\")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
# Special commands will NOT be suggested during naive completion mode.
assert result == completions_to_set([])
def test_datatype_name_completion(completer, complete_event):
text = "SELECT price::IN"
position = len("SELECT price::IN")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position),
complete_event,
smart_completion=True,
)
)
assert result == completions_to_set(
[
Completion(text="INET", display_meta="datatype"),
Completion(text="INT", display_meta="datatype"),
Completion(text="INT2", display_meta="datatype"),
Completion(text="INT4", display_meta="datatype"),
Completion(text="INT8", display_meta="datatype"),
Completion(text="INTEGER", display_meta="datatype"),
Completion(text="INTERNAL", display_meta="datatype"),
Completion(text="INTERVAL", display_meta="datatype"),
]
)

542
tests/test_pgexecute.py Normal file
View file

@ -0,0 +1,542 @@
from textwrap import dedent
import psycopg2
import pytest
from mock import patch, MagicMock
from pgspecial.main import PGSpecial, NO_QUERY
from utils import run, dbtest, requires_json, requires_jsonb
from pgcli.main import PGCli
from pgcli.packages.parseutils.meta import FunctionMetadata
def function_meta_data(
func_name,
schema_name="public",
arg_names=None,
arg_types=None,
arg_modes=None,
return_type=None,
is_aggregate=False,
is_window=False,
is_set_returning=False,
is_extension=False,
arg_defaults=None,
):
return FunctionMetadata(
schema_name,
func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
)
@dbtest
def test_conn(executor):
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
assert run(executor, """select * from test""", join=True) == dedent(
"""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1"""
)
@dbtest
def test_copy(executor):
executor_copy = executor.copy()
run(executor_copy, """create table test(a text)""")
run(executor_copy, """insert into test values('abc')""")
assert run(executor_copy, """select * from test""", join=True) == dedent(
"""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1"""
)
@dbtest
def test_bools_are_treated_as_strings(executor):
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
assert run(executor, """select * from test""", join=True) == dedent(
"""\
+------+
| a |
|------|
| True |
+------+
SELECT 1"""
)
@dbtest
def test_expanded_slash_G(executor, pgspecial):
# Tests whether we reset the expanded output after a \G.
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
results = run(executor, """select * from test \G""", pgspecial=pgspecial)
assert pgspecial.expanded_output == False
@dbtest
def test_schemata_table_views_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
run(executor, "create view d as select 1 as e")
run(executor, "create schema schema1")
run(executor, "create table schema1.c (w text DEFAULT 'meow')")
run(executor, "create schema schema2")
# schemata
# don't enforce all members of the schemas since they may include postgres
# temporary schemas
assert set(executor.schemata()) >= set(
["public", "pg_catalog", "information_schema", "schema1", "schema2"]
)
assert executor.search_path() == ["pg_catalog", "public"]
# tables
assert set(executor.tables()) >= set(
[("public", "a"), ("public", "b"), ("schema1", "c")]
)
assert set(executor.table_columns()) >= set(
[
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
]
)
# views
assert set(executor.views()) >= set([("public", "d")])
assert set(executor.view_columns()) >= set(
[("public", "d", "e", "integer", False, None)]
)
@dbtest
def test_foreign_key_query(executor):
run(executor, "create schema schema1")
run(executor, "create schema schema2")
run(executor, "create table schema1.parent(parentid int PRIMARY KEY)")
run(
executor,
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
assert set(executor.foreignkeys()) >= set(
[("schema1", "parent", "parentid", "schema2", "child", "motherid")]
)
@dbtest
def test_functions_query(executor):
run(
executor,
"""create function func1() returns int
language sql as $$select 1$$""",
)
run(executor, "create schema schema1")
run(
executor,
"""create function schema1.func2() returns int
language sql as $$select 2$$""",
)
run(
executor,
"""create function func3()
returns table(x int, y int) language sql
as $$select 1, 2 from generate_series(1,5)$$;""",
)
run(
executor,
"""create function func4(x int) returns setof int language sql
as $$select generate_series(1,5)$$;""",
)
funcs = set(executor.functions())
assert funcs >= set(
[
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
]
)
@dbtest
def test_datatypes_query(executor):
run(executor, "create type foo AS (a int, b text)")
types = list(executor.datatypes())
assert types == [("public", "foo")]
@dbtest
def test_database_list(executor):
databases = executor.databases()
assert "_test_db" in databases
@dbtest
def test_invalid_syntax(executor, exception_formatter):
result = run(executor, "invalid syntax!", exception_formatter=exception_formatter)
assert 'syntax error at or near "invalid"' in result[0]
@dbtest
def test_invalid_column_name(executor, exception_formatter):
result = run(
executor, "select invalid command", exception_formatter=exception_formatter
)
assert 'column "invalid" does not exist' in result[0]
@pytest.fixture(params=[True, False])
def expanded(request):
return request.param
@dbtest
def test_unicode_support_in_output(executor, expanded):
run(executor, "create table unicodechars(t text)")
run(executor, "insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
assert "é" in run(
executor, "select * from unicodechars", join=True, expanded=expanded
)
@dbtest
def test_not_is_special(executor, pgspecial):
"""is_special is set to false for database queries."""
query = "select 1"
result = list(executor.run(query, pgspecial=pgspecial))
success, is_special = result[0][5:]
assert success == True
assert is_special == False
@dbtest
def test_execute_from_file_no_arg(executor, pgspecial):
"""\i without a filename returns an error."""
result = list(executor.run("\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert "missing required argument" in status
assert success == False
assert is_special == True
@dbtest
@patch("pgcli.main.os")
def test_execute_from_file_io_error(os, executor, pgspecial):
"""\i with an io_error returns an error."""
# Inject an IOError.
os.path.expanduser.side_effect = IOError("test")
# Check the result.
result = list(executor.run("\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == "test"
assert success == False
assert is_special == True
@dbtest
def test_multiple_queries_same_line(executor):
result = run(executor, "select 'foo'; select 'bar'")
assert len(result) == 12 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
assert "bar" in result[9]
@dbtest
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
result = run(executor, "select 'foo'; \d", pgspecial=pgspecial)
assert len(result) == 11 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
# This is a lame check. :(
assert "Schema" in result[7]
@dbtest
def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter):
result = run(
executor,
"select 'fooé'; invalid syntax é",
exception_formatter=exception_formatter,
)
assert "fooé" in result[3]
assert 'syntax error at or near "invalid"' in result[-1]
@pytest.fixture
def pgspecial():
return PGCli().pgspecial
@dbtest
def test_special_command_help(executor, pgspecial):
result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|")
assert "Command" in result[1]
assert "Description" in result[2]
@dbtest
def test_bytea_field_support_in_output(executor):
run(executor, "create table binarydata(c bytea)")
run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))")
assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True)
@dbtest
def test_unicode_support_in_unknown_type(executor):
assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True)
@dbtest
def test_unicode_support_in_enum_type(executor):
run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')")
run(executor, "CREATE TABLE person (name TEXT, current_mood mood)")
run(executor, "INSERT INTO person VALUES ('Moe', '日本語')")
assert "日本語" in run(executor, "SELECT * FROM person", join=True)
@requires_json
def test_json_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsontest(d json)")
run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""")
result = run(
executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded
)
assert '{"name": "Éowyn"}' in result
@requires_jsonb
def test_jsonb_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsonbtest(d jsonb)")
run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""")
result = run(
executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded
)
assert '{"name": "Éowyn"}' in result
@dbtest
def test_date_time_types(executor):
run(executor, "SET TIME ZONE UTC")
assert (
run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3]
== "| 00:00:00 |"
)
assert (
run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split(
"\n"
)[3]
== "| 00:00:00+14:59 |"
)
assert (
run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[
3
]
== "| 4713-01-01 BC |"
)
assert (
run(
executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True
).split("\n")[3]
== "| 4713-01-01 00:00:00 BC |"
)
assert (
run(
executor,
"SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))",
join=True,
).split("\n")[3]
== "| 4713-01-01 00:00:00+00 BC |"
)
assert (
run(
executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True
).split("\n")[3]
== "| -123456789 days, 12:23:56 |"
)
@dbtest
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
def test_large_numbers_render_directly(executor, value):
run(executor, "create table numbertest(a numeric)")
run(executor, "insert into numbertest (a) values ({0})".format(value))
assert value in run(executor, "select * from numbertest", join=True)
@dbtest
@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"])
@pytest.mark.parametrize("verbose", ["", "+"])
@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"])
def test_describe_special(executor, command, verbose, pattern, pgspecial):
# We don't have any tests for the output of any of the special commands,
# but we can at least make sure they run without error
sql = r"\{command}{verbose} {pattern}".format(**locals())
list(executor.run(sql, pgspecial=pgspecial))
@dbtest
@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"])
def test_raises_with_no_formatter(executor, sql):
with pytest.raises(psycopg2.ProgrammingError):
list(executor.run(sql))
@dbtest
def test_on_error_resume(executor, exception_formatter):
sql = "select 1; error; select 1;"
result = list(
executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)
)
assert len(result) == 3
@dbtest
def test_on_error_stop(executor, exception_formatter):
sql = "select 1; error; select 1;"
result = list(
executor.run(
sql, on_error_resume=False, exception_formatter=exception_formatter
)
)
assert len(result) == 2
# @dbtest
# def test_unicode_notices(executor):
# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;"
# result = list(executor.run(sql))
# assert result[0][0] == u'NOTICE: 有人更改\n'
@dbtest
def test_nonexistent_function_definition(executor):
with pytest.raises(RuntimeError):
result = executor.view_definition("there_is_no_such_function")
@dbtest
def test_function_definition(executor):
run(
executor,
"""
CREATE OR REPLACE FUNCTION public.the_number_three()
RETURNS int
LANGUAGE sql
AS $function$
select 3;
$function$
""",
)
result = executor.function_definition("the_number_three")
@dbtest
def test_view_definition(executor):
run(executor, "create table tbl1 (a text, b numeric)")
run(executor, "create view vw1 AS SELECT * FROM tbl1")
run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1")
result = executor.view_definition("vw1")
assert "FROM tbl1" in result
# import pytest; pytest.set_trace()
result = executor.view_definition("mvw1")
assert "MATERIALIZED VIEW" in result
@dbtest
def test_nonexistent_view_definition(executor):
with pytest.raises(RuntimeError):
result = executor.view_definition("there_is_no_such_view")
with pytest.raises(RuntimeError):
result = executor.view_definition("mvw1")
@dbtest
def test_short_host(executor):
with patch.object(executor, "host", "localhost"):
assert executor.short_host == "localhost"
with patch.object(executor, "host", "localhost.example.org"):
assert executor.short_host == "localhost"
with patch.object(
executor, "host", "localhost1.example.org,localhost2.example.org"
):
assert executor.short_host == "localhost1"
class BrokenConnection(object):
"""Mock a connection that failed."""
def cursor(self):
raise psycopg2.InterfaceError("I'm broken!")
@dbtest
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
pgspecial = PGSpecial()
pgspecial.register(
quit_handler,
"\\q",
"\\q",
"Quit pgcli.",
arg_type=NO_QUERY,
case_sensitive=True,
aliases=(":q",),
)
with patch.object(executor, "conn", BrokenConnection()):
# we should be able to quit the app, even without active connection
run(executor, "\\q", pgspecial=pgspecial)
quit_handler.assert_called_once()
# an exception should be raised when running a query without active connection
with pytest.raises(psycopg2.InterfaceError):
run(executor, "select 1", pgspecial=pgspecial)

78
tests/test_pgspecial.py Normal file
View file

@ -0,0 +1,78 @@
import pytest
from pgcli.packages.sqlcompletion import (
suggest_type,
Special,
Database,
Schema,
Table,
View,
Function,
Datatype,
)
def test_slash_suggests_special():
suggestions = suggest_type("\\", "\\")
assert set(suggestions) == set([Special()])
def test_slash_d_suggests_special():
suggestions = suggest_type("\\d", "\\d")
assert set(suggestions) == set([Special()])
def test_dn_suggests_schemata():
suggestions = suggest_type("\\dn ", "\\dn ")
assert suggestions == (Schema(),)
suggestions = suggest_type("\\dn xxx", "\\dn xxx")
assert suggestions == (Schema(),)
def test_d_suggests_tables_views_and_schemas():
suggestions = suggest_type("\d ", "\d ")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type("\d xxx", "\d xxx")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
def test_d_dot_suggests_schema_qualified_tables_or_views():
suggestions = suggest_type("\d myschema.", "\d myschema.")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
def test_df_suggests_schema_or_function():
suggestions = suggest_type("\\df xxx", "\\df xxx")
assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
assert suggestions == (Function(schema="myschema", usage="special"),)
def test_leading_whitespace_ok():
cmd = "\\dn "
whitespace = " "
suggestions = suggest_type(whitespace + cmd, whitespace + cmd)
assert suggestions == suggest_type(cmd, cmd)
def test_dT_suggests_schema_or_datatypes():
text = "\\dT "
suggestions = suggest_type(text, text)
assert set(suggestions) == set([Schema(), Datatype(schema=None)])
def test_schema_qualified_dT_suggests_datatypes():
text = "\\dT foo."
suggestions = suggest_type(text, text)
assert suggestions == (Datatype(schema="foo"),)
@pytest.mark.parametrize("command", ["\\c ", "\\connect "])
def test_c_suggests_databases(command):
suggestions = suggest_type(command, command)
assert suggestions == (Database(),)

38
tests/test_plan.wiki Normal file
View file

@ -0,0 +1,38 @@
= Gross Checks =
* [ ] Check connecting to a local database.
* [ ] Check connecting to a remote database.
* [ ] Check connecting to a database with a user/password.
* [ ] Check connecting to a non-existent database.
* [ ] Test changing the database.
== PGExecute ==
* [ ] Test successful execution given a cursor.
* [ ] Test unsuccessful execution with a syntax error.
* [ ] Test a series of executions with the same cursor without failure.
* [ ] Test a series of executions with the same cursor with failure.
* [ ] Test passing in a special command.
== Naive Autocompletion ==
* [ ] Input empty string, ask for completions - Everything.
* [ ] Input partial prefix, ask for completions - Starts with prefix.
* [ ] Input fully autocompleted string, ask for completions - Only full match
* [ ] Input non-existent prefix, ask for completions - nothing
* [ ] Input lowercase prefix - case insensitive completions
== Smart Autocompletion ==
* [ ] Input empty string and check if only keywords are returned.
* [ ] Input SELECT prefix and check if only columns are returned.
* [ ] Input SELECT blah - only keywords are returned.
* [ ] Input SELECT * FROM - Table names only
== PGSpecial ==
* [ ] Test \d
* [ ] Test \d tablename
* [ ] Test \d tablena*
* [ ] Test \d non-existent-tablename
* [ ] Test \d index
* [ ] Test \d sequence
* [ ] Test \d view
== Exceptionals ==
* [ ] Test the 'use' command to change db.
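The autocompletion and PGSpecial checks above map directly onto suggest_type calls; a minimal sketch, using only the call shape shown in the completion tests elsewhere in this commit (the exact suggestion tuples are what those tests assert, not something this plan guarantees):
from pgcli.packages.sqlcompletion import suggest_type
text = "SELECT * FROM "
# "Input SELECT * FROM - Table names only": suggest_type reports from-clause
# items and schemas for this text.
print(suggest_type(text, text))
special = "\\d "
# For the PGSpecial checks, "\d " suggests schemas, tables and views.
print(suggest_type(special, special))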

View file

@ -0,0 +1,20 @@
from pgcli.packages.prioritization import PrevalenceCounter
def test_prevalence_counter():
counter = PrevalenceCounter()
sql = """SELECT * FROM foo WHERE bar GROUP BY baz;
select * from foo;
SELECT * FROM foo WHERE bar GROUP
BY baz"""
counter.update(sql)
keywords = ["SELECT", "FROM", "GROUP BY"]
expected = [3, 3, 2]
kw_counts = [counter.keyword_count(x) for x in keywords]
assert kw_counts == expected
assert counter.keyword_count("NOSUCHKEYWORD") == 0
names = ["foo", "bar", "baz"]
name_counts = [counter.name_count(x) for x in names]
assert name_counts == [3, 2, 2]

View file

@ -0,0 +1,10 @@
import click
from pgcli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = "drop database foo;"
assert confirm_destructive_query(sql) is None

79
tests/test_rowlimit.py Normal file
View file

@ -0,0 +1,79 @@
import pytest
from mock import Mock
from pgcli.main import PGCli
# We need these fixtures because we need the PGCli object to be created
# after test collection, so it has config loaded from the temp directory
@pytest.fixture(scope="module")
def default_pgcli_obj():
return PGCli()
@pytest.fixture(scope="module")
def DEFAULT(default_pgcli_obj):
return default_pgcli_obj.row_limit
@pytest.fixture(scope="module")
def LIMIT(DEFAULT):
return DEFAULT + 1000
@pytest.fixture(scope="module")
def over_default(DEFAULT):
over_default_cursor = Mock()
over_default_cursor.configure_mock(rowcount=DEFAULT + 10)
return over_default_cursor
@pytest.fixture(scope="module")
def over_limit(LIMIT):
over_limit_cursor = Mock()
over_limit_cursor.configure_mock(rowcount=LIMIT + 10)
return over_limit_cursor
@pytest.fixture(scope="module")
def low_count():
low_count_cursor = Mock()
low_count_cursor.configure_mock(rowcount=1)
return low_count_cursor
def test_row_limit_with_LIMIT_clause(LIMIT, over_limit):
cli = PGCli(row_limit=LIMIT)
stmt = "SELECT * FROM students LIMIT 1000"
result = cli._should_limit_output(stmt, over_limit)
assert result is False
cli = PGCli(row_limit=0)
result = cli._should_limit_output(stmt, over_limit)
assert result is False
def test_row_limit_without_LIMIT_clause(LIMIT, over_limit):
cli = PGCli(row_limit=LIMIT)
stmt = "SELECT * FROM students"
result = cli._should_limit_output(stmt, over_limit)
assert result is True
cli = PGCli(row_limit=0)
result = cli._should_limit_output(stmt, over_limit)
assert result is False
def test_row_limit_on_non_select(over_limit):
cli = PGCli()
stmt = "UPDATE students SET name='Boby'"
result = cli._should_limit_output(stmt, over_limit)
assert result is False
cli = PGCli(row_limit=0)
result = cli._should_limit_output(stmt, over_limit)
assert result is False

View file

@ -0,0 +1,727 @@
import itertools
from metadata import (
MetaData,
alias,
name_join,
fk_join,
join,
schema,
table,
function,
wildcard_expansion,
column,
get_result,
result_set,
qual,
no_qual,
parametrize,
)
from utils import completions_to_set
metadata = {
"tables": {
"public": {
"users": ["id", "email", "first_name", "last_name"],
"orders": ["id", "ordered_date", "status", "datestamp"],
"select": ["id", "localtime", "ABC"],
},
"custom": {
"users": ["id", "phone_number"],
"Users": ["userid", "username"],
"products": ["id", "product_name", "price"],
"shipments": ["id", "address", "user_id"],
},
"Custom": {"projects": ["projectid", "name"]},
"blog": {
"entries": ["entryid", "entrytitle", "entrytext"],
"tags": ["tagid", "name"],
"entrytags": ["entryid", "tagid"],
"entacclog": ["entryid", "username", "datestamp"],
},
},
"functions": {
"public": [
["func1", [], [], [], "", False, False, False, False],
["func2", [], [], [], "", False, False, False, False],
],
"custom": [
["func3", [], [], [], "", False, False, False, False],
[
"set_returning_func",
["x"],
["integer"],
["o"],
"integer",
False,
False,
True,
False,
],
],
"Custom": [["func4", [], [], [], "", False, False, False, False]],
"blog": [
[
"extract_entry_symbols",
["_entryid", "symbol"],
["integer", "text"],
["i", "o"],
"",
False,
False,
True,
False,
],
[
"enter_entry",
["_title", "_text", "entryid"],
["text", "text", "integer"],
["i", "i", "o"],
"",
False,
False,
False,
False,
],
],
},
"datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]},
"foreignkeys": {
"custom": [("public", "users", "id", "custom", "shipments", "user_id")],
"blog": [
("blog", "entries", "entryid", "blog", "entacclog", "entryid"),
("blog", "entries", "entryid", "blog", "entrytags", "entryid"),
("blog", "tags", "tagid", "blog", "entrytags", "tagid"),
],
},
"defaults": {
"public": {
("orders", "id"): "nextval('orders_id_seq'::regclass)",
("orders", "datestamp"): "now()",
("orders", "status"): "'PENDING'::text",
}
},
}
testdata = MetaData(metadata)
cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')]
casing = (
"SELECT",
"Orders",
"User_Emails",
"CUSTOM",
"Func1",
"Entries",
"Tags",
"EntryTags",
"EntAccLog",
"EntryID",
"EntryTitle",
"EntryText",
)
completers = testdata.get_completers(casing)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize("table", ["users", '"users"'])
def test_suggested_column_names_from_shadowed_visible_table(completer, table):
result = get_result(completer, "SELECT FROM " + table, len("SELECT "))
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("users")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize(
"text",
[
"SELECT from custom.users",
"WITH users as (SELECT 1 AS foo) SELECT from custom.users",
],
)
def test_suggested_column_names_from_qualified_shadowed_table(completer, text):
result = get_result(completer, text, position=text.find(" ") + 1)
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("users", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"])
def test_suggested_column_names_from_cte(completer, text):
result = completions_to_set(get_result(completer, text, text.find(" ") + 1))
assert result == completions_to_set(
[column("foo")] + testdata.functions_and_keywords()
)
@parametrize("completer", completers(casing=False))
@parametrize(
"text",
[
"SELECT * FROM users JOIN custom.shipments ON ",
"""SELECT *
FROM public.users
JOIN custom.shipments ON """,
],
)
def test_suggested_join_conditions(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[
alias("users"),
alias("shipments"),
name_join("shipments.id = users.id"),
fk_join("shipments.user_id = users.id"),
]
)
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
@parametrize(
("query", "tbl"),
itertools.product(
(
"SELECT * FROM public.{0} RIGHT OUTER JOIN ",
"""SELECT *
FROM {0}
JOIN """,
),
("users", '"users"', "Users"),
),
)
def test_suggested_joins(completer, query, tbl):
result = get_result(completer, query.format(tbl))
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
+ [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))]
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_column_names_from_schema_qualifed_table(completer):
result = get_result(completer, "SELECT from custom.products", len("SELECT "))
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("products", "custom")
)
@parametrize(
"text",
[
"INSERT INTO orders(",
"INSERT INTO orders (",
"INSERT INTO public.orders(",
"INSERT INTO public.orders (",
],
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_columns_with_insert(completer, text):
assert completions_to_set(get_result(completer, text)) == completions_to_set(
testdata.columns("orders")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_column_names_in_function(completer):
result = get_result(
completer, "SELECT MAX( from custom.products", len("SELECT MAX(")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("products", "custom")
)
@parametrize("completer", completers(casing=False, aliasing=False))
@parametrize(
"text",
["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'],
)
@parametrize("use_leading_double_quote", [False, True])
def test_suggested_table_names_with_schema_dot(
completer, text, use_leading_double_quote
):
if use_leading_double_quote:
text += '"'
start_position = -1
else:
start_position = 0
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.from_clause_items("custom", start_position)
)
@parametrize("completer", completers(casing=False, aliasing=False))
@parametrize("text", ['SELECT * FROM "Custom".'])
@parametrize("use_leading_double_quote", [False, True])
def test_suggested_table_names_with_schema_dot2(
completer, text, use_leading_double_quote
):
if use_leading_double_quote:
text += '"'
start_position = -1
else:
start_position = 0
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.from_clause_items("Custom", start_position)
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_column_names_with_qualified_alias(completer):
result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p."))
assert completions_to_set(result) == completions_to_set(
testdata.columns("products", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_multiple_column_names(completer):
result = get_result(
completer, "SELECT id, from custom.products", len("SELECT id, ")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("products", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_multiple_column_names_with_alias(completer):
result = get_result(
completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns("products", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ",
"SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id",
],
)
def test_suggestions_after_on(completer, text):
position = len(
"SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON "
)
result = get_result(completer, text, position)
assert completions_to_set(result) == completions_to_set(
[
alias("x"),
alias("y"),
name_join("y.price = x.price"),
name_join("y.product_name = x.product_name"),
name_join("y.id = x.id"),
]
)
@parametrize("completer", completers())
def test_suggested_aliases_after_on_right_side(completer):
text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = "
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")])
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
def test_table_names_after_from(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_schema_qualified_function_name(completer):
text = "SELECT custom.func"
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[
function("func3()", -len("func")),
function("set_returning_func()", -len("func")),
]
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"SELECT 1::custom.",
"CREATE TABLE foo (bar custom.",
"CREATE FUNCTION foo (bar INT, baz custom.",
"ALTER TABLE foo ALTER COLUMN bar TYPE custom.",
],
)
def test_schema_qualified_type_name(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(testdata.types("custom"))
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggest_columns_from_aliased_set_returning_function(completer):
result = get_result(
completer, "select f. from custom.set_returning_func() f", len("select f.")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns("set_returning_func", "custom", "functions")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize(
"text",
[
"SELECT * FROM custom.set_returning_func()",
"SELECT * FROM Custom.set_returning_func()",
"SELECT * FROM Custom.Set_Returning_Func()",
],
)
def test_wildcard_column_expansion_with_function(completer, text):
position = len("SELECT *")
completions = get_result(completer, text, position)
col_list = "x"
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
def test_wildcard_column_expansion_with_alias_qualifier(completer):
text = "SELECT p.* FROM custom.products p"
position = len("SELECT p.*")
completions = get_result(completer, text, position)
col_list = "id, p.product_name, p.price"
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"""
SELECT count(1) FROM users;
CREATE FUNCTION foo(custom.products _products) returns custom.shipments
LANGUAGE SQL
AS $foo$
SELECT 1 FROM custom.shipments;
INSERT INTO public.orders(*) values(-1, now(), 'preliminary');
SELECT 2 FROM custom.users;
$foo$;
SELECT count(1) FROM custom.shipments;
""",
"INSERT INTO public.orders(*",
"INSERT INTO public.Orders(*",
"INSERT INTO public.orders (*",
"INSERT INTO public.Orders (*",
"INSERT INTO orders(*",
"INSERT INTO Orders(*",
"INSERT INTO orders (*",
"INSERT INTO Orders (*",
"INSERT INTO public.orders(*)",
"INSERT INTO public.Orders(*)",
"INSERT INTO public.orders (*)",
"INSERT INTO public.Orders (*)",
"INSERT INTO orders(*)",
"INSERT INTO Orders(*)",
"INSERT INTO orders (*)",
"INSERT INTO Orders (*)",
],
)
def test_wildcard_column_expansion_with_insert(completer, text):
position = text.index("*") + 1
completions = get_result(completer, text, position)
expected = [wildcard_expansion("ordered_date, status")]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
def test_wildcard_column_expansion_with_table_qualifier(completer):
text = 'SELECT "select".* FROM public."select"'
position = len('SELECT "select".*')
completions = get_result(completer, text, position)
col_list = 'id, "select"."localtime", "select"."ABC"'
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False, qualify=qual))
def test_wildcard_column_expansion_with_two_tables(completer):
text = 'SELECT * FROM public."select" JOIN custom.users ON true'
position = len("SELECT *")
completions = get_result(completer, text, position)
cols = (
'"select".id, "select"."localtime", "select"."ABC", '
"users.id, users.phone_number"
)
expected = [wildcard_expansion(cols)]
assert completions == expected
@parametrize("completer", completers(filtr=True, casing=False))
def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
position = len('SELECT "select".*')
completions = get_result(completer, text, position)
col_list = 'id, "select"."localtime", "select"."ABC"'
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"SELECT U. FROM custom.Users U",
"SELECT U. FROM custom.USERS U",
"SELECT U. FROM custom.users U",
'SELECT U. FROM "custom".Users U',
'SELECT U. FROM "custom".USERS U',
'SELECT U. FROM "custom".users U',
],
)
def test_suggest_columns_from_unquoted_table(completer, text):
position = len("SELECT U.")
result = get_result(completer, text, position)
assert completions_to_set(result) == completions_to_set(
testdata.columns("users", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U']
)
def test_suggest_columns_from_quoted_table(completer, text):
position = len("SELECT U.")
result = get_result(completer, text, position)
assert completions_to_set(result) == completions_to_set(
testdata.columns("Users", "custom")
)
texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "]
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
@parametrize("text", texts)
def test_schema_or_visible_table_completion(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
)
@parametrize("completer", completers(aliasing=True, casing=False, filtr=True))
@parametrize("text", texts)
def test_table_aliases(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.schemas()
+ [
table("users u"),
table("orders o" if text == "SELECT * FROM " else "orders o2"),
table('"select" s'),
function("func1() f"),
function("func2() f"),
]
)
@parametrize("completer", completers(aliasing=True, casing=True, filtr=True))
@parametrize("text", texts)
def test_aliases_with_casing(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
cased_schemas
+ [
table("users u"),
table("Orders O" if text == "SELECT * FROM " else "Orders O2"),
table('"select" s'),
function("Func1() F"),
function("func2() f"),
]
)
@parametrize("completer", completers(aliasing=False, casing=True, filtr=True))
@parametrize("text", texts)
def test_table_casing(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
cased_schemas
+ [
table("users"),
table("Orders"),
table('"select"'),
function("Func1()"),
function("func2()"),
]
)
@parametrize("completer", completers(aliasing=False, casing=True))
def test_alias_search_without_aliases2(completer):
text = "SELECT * FROM blog.et"
result = get_result(completer, text)
assert result[0] == table("EntryTags", -2)
@parametrize("completer", completers(aliasing=False, casing=True))
def test_alias_search_without_aliases1(completer):
text = "SELECT * FROM blog.e"
result = get_result(completer, text)
assert result[0] == table("Entries", -1)
@parametrize("completer", completers(aliasing=True, casing=True))
def test_alias_search_with_aliases2(completer):
text = "SELECT * FROM blog.et"
result = get_result(completer, text)
assert result[0] == table("EntryTags ET", -2)
@parametrize("completer", completers(aliasing=True, casing=True))
def test_alias_search_with_aliases1(completer):
text = "SELECT * FROM blog.e"
result = get_result(completer, text)
assert result[0] == table("Entries E", -1)
@parametrize("completer", completers(aliasing=True, casing=True))
def test_join_alias_search_with_aliases1(completer):
text = "SELECT * FROM blog.Entries E JOIN blog.e"
result = get_result(completer, text)
assert result[:2] == [
table("Entries E2", -1),
join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1),
]
@parametrize("completer", completers(aliasing=False, casing=True))
def test_join_alias_search_without_aliases1(completer):
text = "SELECT * FROM blog.Entries JOIN blog.e"
result = get_result(completer, text)
assert result[:2] == [
table("Entries", -1),
join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1),
]
@parametrize("completer", completers(aliasing=True, casing=True))
def test_join_alias_search_with_aliases2(completer):
text = "SELECT * FROM blog.Entries E JOIN blog.et"
result = get_result(completer, text)
assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2)
@parametrize("completer", completers(aliasing=False, casing=True))
def test_join_alias_search_without_aliases2(completer):
text = "SELECT * FROM blog.Entries JOIN blog.et"
result = get_result(completer, text)
assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2)
@parametrize("completer", completers())
def test_function_alias_search_without_aliases(completer):
text = "SELECT blog.ees"
result = get_result(completer, text)
first = result[0]
assert first.start_position == -3
assert first.text == "extract_entry_symbols()"
assert first.display_text == "extract_entry_symbols(_entryid)"
@parametrize("completer", completers())
def test_function_alias_search_with_aliases(completer):
text = "SELECT blog.ee"
result = get_result(completer, text)
first = result[0]
assert first.start_position == -2
assert first.text == "enter_entry(_title := , _text := )"
assert first.display_text == "enter_entry(_title, _text)"
@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual))
def test_column_alias_search(completer):
result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et"))
cols = ("EntryText", "EntryTitle", "EntryID")
assert result[:3] == [column(c, -2) for c in cols]
@parametrize("completer", completers(casing=True))
def test_column_alias_search_qualified(completer):
result = get_result(
completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei")
)
cols = ("EntryID", "EntryTitle")
assert result[:3] == [column(c, -2) for c in cols]
@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
def test_schema_object_order(completer):
result = get_result(completer, "SELECT * FROM u")
assert result[:3] == [
table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")
]
@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
def test_all_schema_objects(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
[table(x) for x in ("orders", '"select"', "custom.shipments")]
+ [function(x + "()") for x in ("func2",)]
)
@parametrize("completer", completers(filtr=False, aliasing=False, casing=True))
def test_all_schema_objects_with_casing(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
[table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")]
+ [function(x + "()") for x in ("func2",)]
)
@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
def test_all_schema_objects_with_aliases(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
[table(x) for x in ("orders o", '"select" s', "custom.shipments s")]
+ [function(x) for x in ("func2() f",)]
)
@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
def test_set_schema(completer):
text = "SET SCHEMA "
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")]
)

File diff suppressed because it is too large

993
tests/test_sqlcompletion.py Normal file
View file

@ -0,0 +1,993 @@
from pgcli.packages.sqlcompletion import (
suggest_type,
Special,
Database,
Schema,
Table,
Column,
View,
Keyword,
FromClauseItem,
Function,
Datatype,
Alias,
JoinCondition,
Join,
)
from pgcli.packages.parseutils.tables import TableReference
import pytest
def cols_etc(
table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None
):
"""Returns the expected select-clause suggestions for a single-table
select."""
return set(
[
Column(
table_refs=(TableReference(schema, table, alias, is_function),),
qualifiable=True,
),
Function(schema=parent),
Keyword(last_keyword),
]
)
def test_select_suggests_cols_with_visible_table_scope():
suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
assert set(suggestions) == cols_etc("tabl", last_keyword="SELECT")
def test_select_suggests_cols_with_qualified_table_scope():
suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT")
def test_cte_does_not_crash():
sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;"
for i in range(len(sql)):
suggestions = suggest_type(sql[: i + 1], sql[: i + 1])
@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE '])
def test_where_suggests_columns_functions_quoted_table(expression):
expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE")
suggestions = suggest_type(expression, expression)
assert expected == set(suggestions)
@pytest.mark.parametrize(
"expression",
[
"INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ",
"INSERT INTO OtherTabl SELECT * FROM tabl WHERE ",
"SELECT * FROM tabl WHERE ",
"SELECT * FROM tabl WHERE (",
"SELECT * FROM tabl WHERE foo = ",
"SELECT * FROM tabl WHERE bar OR ",
"SELECT * FROM tabl WHERE foo = 1 AND ",
"SELECT * FROM tabl WHERE (bar > 10 AND ",
"SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
"SELECT * FROM tabl WHERE 10 < ",
"SELECT * FROM tabl WHERE foo BETWEEN ",
"SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
],
)
def test_where_suggests_columns_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
@pytest.mark.parametrize(
"expression",
["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "],
)
def test_where_in_suggests_columns(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "])
def test_after_as(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set()
def test_where_equals_any_suggests_columns_or_keywords():
text = "SELECT * FROM tabl WHERE foo = ANY("
suggestions = suggest_type(text, text)
assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
def test_lparen_suggests_cols_and_funcs():
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
assert set(suggestion) == set(
[
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("("),
]
)
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type("SELECT ", "SELECT ")
assert set(suggestions) == set(
[
Column(table_refs=(), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
@pytest.mark.parametrize(
"expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "]
)
def test_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()])
@pytest.mark.parametrize("expression", ["SELECT * FROM "])
def test_suggest_tables_views_schemas_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ",
"SELECT * FROM foo JOIN bar USING (barid) JOIN ",
],
)
def test_suggest_after_join_with_two_tables(expression):
suggestions = suggest_type(expression, expression)
tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
assert set(suggestions) == set(
[FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()]
)
@pytest.mark.parametrize(
"expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"]
)
def test_suggest_after_join_with_one_table(expression):
suggestions = suggest_type(expression, expression)
tables = ((None, "foo", None, False),)
assert set(suggestions) == set(
[
FromClauseItem(schema=None, table_refs=tables),
Join(((None, "foo", None, False),), None),
Schema(),
]
)
@pytest.mark.parametrize(
"expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."]
)
def test_suggest_qualified_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
@pytest.mark.parametrize("expression", ["UPDATE sch."])
def test_suggest_qualified_aliasable_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM sch.",
'SELECT * FROM sch."',
'SELECT * FROM sch."foo',
'SELECT * FROM "sch".',
'SELECT * FROM "sch"."',
],
)
def test_suggest_qualified_tables_views_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([FromClauseItem(schema="sch")])
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
def test_suggest_qualified_tables_views_functions_and_joins(expression):
suggestions = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
assert set(suggestions) == set(
[FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")]
)
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
assert set(suggestions) == set([Table(schema=None), Schema()])
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
assert set(suggestions) == set([Table(schema="sch")])
@pytest.mark.parametrize(
"text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "]
)
def test_distinct_suggests_cols(text):
suggestions = suggest_type(text, text)
assert set(suggestions) == set(
[
Column(table_refs=(), local_tables=(), qualifiable=True),
Function(schema=None),
Keyword("DISTINCT"),
]
)
@pytest.mark.parametrize(
"text, text_before, last_keyword",
[
("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"),
(
"SELECT * FROM tbl x JOIN tbl1 y ORDER BY ",
"SELECT * FROM tbl x JOIN tbl1 y ORDER BY ",
"ORDER BY",
),
],
)
def test_distinct_and_order_by_suggestions_with_aliases(
text, text_before, last_keyword
):
suggestions = suggest_type(text, text_before)
assert set(suggestions) == set(
[
Column(
table_refs=(
TableReference(None, "tbl", "x", False),
TableReference(None, "tbl1", "y", False),
),
local_tables=(),
qualifiable=True,
),
Function(schema=None),
Keyword(last_keyword),
]
)
@pytest.mark.parametrize(
"text, text_before",
[
("SELECT DISTINCT x. FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."),
(
"SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.",
"SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.",
),
],
)
def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
suggestions = suggest_type(text, text_before)
assert set(suggestions) == set(
[
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
]
)
def test_function_arguments_with_alias_given():
suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
assert set(suggestions) == set(
[
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
]
)
def test_col_comma_suggests_cols():
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
assert set(suggestions) == set(
[
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()])
@pytest.mark.parametrize(
"text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"]
)
def test_insert_into_lparen_suggests_cols(text):
suggestions = suggest_type(text, "INSERT INTO abc (")
assert suggestions == (
Column(table_refs=((None, "abc", None, False),), context="insert"),
)
def test_insert_into_lparen_partial_text_suggests_cols():
suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
assert suggestions == (
Column(table_refs=((None, "abc", None, False),), context="insert"),
)
def test_insert_into_lparen_comma_suggests_cols():
suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
assert suggestions == (
Column(table_refs=((None, "abc", None, False),), context="insert"),
)
def test_partially_typed_col_name_suggests_col_names():
suggestions = suggest_type(
"SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n"
)
assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl", None, False),)),
Table(schema="tabl"),
View(schema="tabl"),
Function(schema="tabl"),
]
)
@pytest.mark.parametrize(
"sql",
[
"SELECT t1. FROM tabl1 t1",
"SELECT t1. FROM tabl1 t1, tabl2 t2",
'SELECT t1. FROM "tabl1" t1',
'SELECT t1. FROM "tabl1" t1, "tabl2" t2',
],
)
def test_dot_suggests_cols_of_an_alias(sql):
suggestions = suggest_type(sql, "SELECT t1.")
assert set(suggestions) == set(
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
]
)
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM tabl1 t1 WHERE t1.",
"SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.",
'SELECT * FROM "tabl1" t1 WHERE t1.',
'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.',
],
)
def test_dot_suggests_cols_of_an_alias_where(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set(
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
]
)
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type(
"SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl2", "t2", False),)),
Table(schema="t2"),
View(schema="t2"),
Function(schema="t2"),
]
)
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (",
"SELECT * FROM foo WHERE EXISTS (",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (",
],
)
def test_sub_select_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == (Keyword(),)
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM (S",
"SELECT * FROM foo WHERE EXISTS (S",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
],
)
def test_sub_select_partial_text_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
assert suggestion == (Keyword(),)
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
assert set(suggestions) == set(
[
Column(table_refs=((None, "foo", "f", False),)),
Table(schema="f"),
View(schema="f"),
Function(schema="f"),
]
)
@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert set(suggestion) == set([FromClauseItem(schema=None), Schema()])
@pytest.mark.parametrize(
"expression",
[
"SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
"SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
],
)
def test_sub_select_table_name_completion_with_outer_table(expression):
suggestion = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
assert set(suggestion) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
)
def test_sub_select_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
@pytest.mark.xfail
def test_sub_select_multiple_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, "
)
assert set(suggestions) == cols_etc("abc")
def test_sub_select_dot_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl", "t", False),)),
Table(schema="t"),
View(schema="t"),
Function(schema="t"),
]
)
@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
@pytest.mark.parametrize("tbl_alias", ("", "foo"))
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
suggestion = suggest_type(text, text)
tbls = tuple([(None, "abc", tbl_alias or None, False)])
assert set(suggestion) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)]
)
def test_left_join_with_comma():
text = "select * from foo f left join bar b,"
suggestions = suggest_type(text, text)
# tbls should also include (None, 'bar', 'b', False)
# but there's a bug with commas
tbls = tuple([(None, "foo", "f", False)])
assert set(suggestions) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
)
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
],
)
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "def", "d", False))
assert set(suggestions) == set(
[
Column(table_refs=((None, "abc", "a", False),)),
Table(schema="a"),
View(schema="a"),
Function(schema="a"),
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
]
)
@pytest.mark.parametrize(
"sql",
[
"SELECT * FROM abc a JOIN def d ON a.id = d.",
"SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
],
)
def test_join_alias_dot_suggests_cols2(sql):
suggestion = suggest_type(sql, sql)
assert set(suggestion) == set(
[
Column(table_refs=((None, "def", "d", False),)),
Table(schema="d"),
View(schema="d"),
Function(schema="d"),
]
)
@pytest.mark.parametrize(
"sql",
[
"select a.x, b.y from abc a join bcd b on ",
"""select a.x, b.y
from abc a
join bcd b on
""",
"""select a.x, b.y
from abc a
join bcd b
on """,
"select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
],
)
def test_on_suggests_aliases_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b")))
)
@pytest.mark.parametrize(
"sql",
[
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
"select abc.x, bcd.y from abc join bcd on ",
],
)
def test_on_suggests_tables_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
)
@pytest.mark.parametrize(
"sql",
[
"select a.x, b.y from abc a join bcd b on a.id = ",
"select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
],
)
def test_on_suggests_aliases_right_side(sql):
suggestions = suggest_type(sql, sql)
assert suggestions == (Alias(aliases=("a", "b")),)
@pytest.mark.parametrize(
"sql",
[
"select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
"select abc.x, bcd.y from abc join bcd on ",
],
)
def test_on_suggests_tables_and_join_conditions_right_side(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
)
@pytest.mark.parametrize(
"text",
(
"select * from abc inner join def using (",
"select * from abc inner join def using (col1, ",
"insert into hij select * from abc inner join def using (",
"""insert into hij(x, y, z)
select * from abc inner join def using (col1, """,
"""insert into hij (a,b,c)
select * from abc inner join def using (col1, """,
),
)
def test_join_using_suggests_common_columns(text):
tables = ((None, "abc", None, False), (None, "def", None, False))
assert set(suggest_type(text, text)) == set(
[Column(table_refs=tables, require_last_table=True)]
)
def test_suggest_columns_after_multiple_joins():
sql = """select * from t1
inner join t2 ON
t1.id = t2.t1_id
inner join t3 ON
t2.id = t3."""
suggestions = suggest_type(sql, sql)
assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions)
def test_2_statements_2nd_current():
suggestions = suggest_type(
"select * from a; select * from ", "select * from a; select * from "
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
suggestions = suggest_type(
"select * from a; select from b", "select * from a; select "
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "b", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
# Should work even if first statement is invalid
suggestions = suggest_type(
"select * from; select * from ", "select * from; select * from "
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
def test_2_statements_1st_current():
suggestions = suggest_type("select * from ; select * from b", "select * from ")
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
suggestions = suggest_type("select from a; select * from b", "select ")
assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
def test_3_statements_2nd_current():
suggestions = suggest_type(
"select * from a; select * from ; select * from c",
"select * from a; select * from ",
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
suggestions = suggest_type(
"select * from a; select from b; select * from c", "select * from a; select "
)
assert set(suggestions) == cols_etc("b", last_keyword="SELECT")
@pytest.mark.parametrize(
"text",
[
"""
CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
SELECT FROM foo;
SELECT 2 FROM bar;
$$ language sql;
""",
"""create function func2(int, varchar)
RETURNS text
language sql AS
$func$
SELECT 2 FROM bar;
SELECT FROM foo;
$func$
""",
"""
CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
SELECT 3 FROM foo;
SELECT 2 FROM bar;
$$ language sql;
create function func2(int, varchar)
RETURNS text
language sql AS
$func$
SELECT 2 FROM bar;
SELECT FROM foo;
$func$
""",
"""
SELECT * FROM baz;
CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
SELECT FROM foo;
SELECT 2 FROM bar;
$$ language sql;
create function func2(int, varchar)
RETURNS text
language sql AS
$func$
SELECT 3 FROM bar;
SELECT FROM foo;
$func$
SELECT * FROM qux;
""",
],
)
def test_statements_in_function_body(text):
suggestions = suggest_type(text, text[: text.find(" ") + 1])
assert set(suggestions) == set(
[
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
functions = [
"""
CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
SELECT 1 FROM foo;
SELECT 2 FROM bar;
$$ language sql;
""",
"""
create function func2(int, varchar)
RETURNS text
language sql AS
'
SELECT 2 FROM bar;
SELECT 1 FROM foo;
';
""",
]
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_after_function_body(text):
suggestions = suggest_type(text, text[: text.find("; ") + 1])
assert set(suggestions) == set([Keyword(), Special()])
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_before_function_body(text):
suggestions = suggest_type(text, "")
assert set(suggestions) == set([Keyword(), Special()])
def test_create_db_with_template():
suggestions = suggest_type(
"create database foo with template ", "create database foo with template "
)
assert set(suggestions) == set((Database(),))
@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
assert set(suggestions) == set([Keyword(), Special()])
def test_drop_schema_qualified_table_suggests_only_tables():
text = "DROP TABLE schema_name.table_name"
suggestions = suggest_type(text, text)
assert suggestions == (Table(schema="schema_name"),)
@pytest.mark.parametrize("text", (",", " ,", "sel ,"))
def test_handle_pre_completion_comma_gracefully(text):
suggestions = suggest_type(text, text)
assert iter(suggestions)
def test_drop_schema_suggests_schemas():
sql = "DROP SCHEMA "
assert suggest_type(sql, sql) == (Schema(),)
@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
def test_cast_operator_suggests_types(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
@pytest.mark.parametrize(
"text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
)
def test_cast_operator_suggests_schema_qualified_types(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema="bar"), Table(schema="bar")]
)
def test_alter_column_type_suggests_types():
q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
assert set(suggest_type(q, q)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
@pytest.mark.parametrize(
"text",
[
"CREATE TABLE foo (bar ",
"CREATE TABLE foo (bar DOU",
"CREATE TABLE foo (bar INT, baz ",
"CREATE TABLE foo (bar INT, baz TEXT, qux ",
"CREATE FUNCTION foo (bar ",
"CREATE FUNCTION foo (bar INT, baz ",
"SELECT * FROM foo() AS bar (baz ",
"SELECT * FROM foo() AS bar (baz INT, qux ",
        # make sure this doesn't trigger special completion
"CREATE TABLE foo (dt d",
],
)
def test_identifier_suggests_types_in_parentheses(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
@pytest.mark.parametrize(
"text",
[
"SELECT foo ",
"SELECT foo FROM bar ",
"SELECT foo AS bar ",
"SELECT foo bar ",
"SELECT * FROM foo AS bar ",
"SELECT * FROM foo bar ",
"SELECT foo FROM (SELECT bar ",
],
)
def test_alias_suggests_keywords(text):
suggestions = suggest_type(text, text)
assert suggestions == (Keyword(),)
def test_invalid_sql():
# issue 317
text = "selt *"
suggestions = suggest_type(text, text)
assert suggestions == (Keyword(),)
@pytest.mark.parametrize(
"text",
["SELECT * FROM foo where created > now() - ", "select * from foo where bar "],
)
def test_suggest_where_keyword(text):
# https://github.com/dbcli/mycli/issues/135
suggestions = suggest_type(text, text)
assert set(suggestions) == cols_etc("foo", last_keyword="WHERE")
@pytest.mark.parametrize(
"text, before, expected",
[
(
"\\ns abc SELECT ",
"SELECT ",
[
Column(table_refs=(), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
],
),
("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)),
(
"\\ns abc SELECT t1. FROM tabl1 t1",
"SELECT t1.",
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
],
),
],
)
def test_named_query_completion(text, before, expected):
suggestions = suggest_type(text, before)
assert set(expected) == set(suggestions)
def test_select_suggests_fields_from_function():
suggestions = suggest_type("SELECT FROM func()", "SELECT ")
assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT")
@pytest.mark.parametrize("sql", ["("])
def test_leading_parenthesis(sql):
# No assertion for now; just make sure it doesn't crash
suggest_type(sql, sql)
@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo'])
def test_ignore_leading_double_quotes(sql):
suggestions = suggest_type(sql, sql)
assert FromClauseItem(schema=None) in set(suggestions)
@pytest.mark.parametrize(
"sql",
[
"ALTER TABLE foo ALTER COLUMN ",
"ALTER TABLE foo ALTER COLUMN bar",
"ALTER TABLE foo DROP COLUMN ",
"ALTER TABLE foo DROP COLUMN bar",
],
)
def test_column_keyword_suggests_columns(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))])
def test_handle_unrecognized_kw_generously():
sql = "SELECT * FROM sessions WHERE session = 1 AND "
suggestions = suggest_type(sql, sql)
expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True)
assert expected in set(suggestions)
@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "])
def test_keyword_after_alter(sql):
assert Keyword("ALTER") in set(suggest_type(sql, sql))

95
tests/utils.py Normal file
View file

@ -0,0 +1,95 @@
import pytest
import psycopg2
import psycopg2.extras
from pgcli.main import format_output, OutputSettings
from pgcli.pgexecute import register_json_typecasters
from os import getenv
POSTGRES_USER = getenv("PGUSER", "postgres")
POSTGRES_HOST = getenv("PGHOST", "localhost")
POSTGRES_PORT = getenv("PGPORT", 5432)
POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
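# Connection details come from the standard libpq environment variables
# (PGUSER, PGHOST, PGPORT, PGPASSWORD); e.g. exporting PGPORT=5433 before
# running pytest points the suite at a non-default server (illustrative value).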
def db_connection(dbname=None):
conn = psycopg2.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
database=dbname,
)
conn.autocommit = True
return conn
# Probe for a reachable test server once at import time; tests that need a
# live database are skipped below when this probe fails.
try:
    conn = db_connection()
    CAN_CONNECT_TO_DB = True
    SERVER_VERSION = conn.server_version
    json_types = register_json_typecasters(conn, lambda x: x)
    JSON_AVAILABLE = "json" in json_types
    JSONB_AVAILABLE = "jsonb" in json_types
except Exception:
    CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
    SERVER_VERSION = 0
dbtest = pytest.mark.skipif(
not CAN_CONNECT_TO_DB,
reason="Need a postgres instance at localhost accessible by user 'postgres'",
)
requires_json = pytest.mark.skipif(
not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
)
requires_jsonb = pytest.mark.skipif(
not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
)
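# Typical usage (sketch): decorating a test with one of these markers skips it
# automatically when the prerequisite is missing, e.g.
#
#     @dbtest
#     def test_something(executor):
#         run(executor, "SELECT 1")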
def create_db(dbname):
    # Create the requested database, ignoring the error if it already exists.
    with db_connection().cursor() as cur:
        try:
            cur.execute("CREATE DATABASE {0}".format(dbname))
        except Exception:
            pass
def drop_tables(conn):
with conn.cursor() as cur:
cur.execute(
"""
DROP SCHEMA public CASCADE;
CREATE SCHEMA public;
DROP SCHEMA IF EXISTS schema1 CASCADE;
DROP SCHEMA IF EXISTS schema2 CASCADE"""
)
def run(
executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
):
" Return string output for the sql to be run "
results = executor.run(sql, pgspecial, exception_formatter)
formatted = []
settings = OutputSettings(
table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
)
for title, rows, headers, status, sql, success, is_special in results:
formatted.extend(format_output(title, rows, headers, status, settings))
if join:
formatted = "\n".join(formatted)
return formatted
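# A typical call (sketch): run(executor, "SELECT 1", join=True) executes the
# statement through PGExecute and returns the psql-formatted result as a single
# string, which integration tests can compare against expected output.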
def completions_to_set(completions):
return set(
(completion.display_text, completion.display_meta_text)
for completion in completions
)
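# Comparing (display_text, display_meta_text) pairs as a set lets tests assert
# on which completions are offered without depending on their ranking order.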