
Merging upstream version 4.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 20:05:21 +01:00
parent af10454b21
commit 7c65fc707e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
42 changed files with 955 additions and 184 deletions


@@ -1,6 +1,9 @@
 name: pgcli
 on:
+  push:
+    branches:
+      - main
   pull_request:
     paths-ignore:
       - '**.rst'
@@ -11,7 +14,7 @@ jobs:
     strategy:
       matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
     services:
       postgres:
@@ -28,10 +31,10 @@ jobs:
           --health-retries 5
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
@@ -64,6 +67,10 @@ jobs:
           psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help'
+      - name: Install beta version of pendulum
+        run: pip install pendulum==3.0.0b1
+        if: matrix.python-version == '3.12'
       - name: Install requirements
         run: |
           pip install -U pip setuptools
@@ -72,7 +79,7 @@ jobs:
           pip install keyrings.alt>=3.1
       - name: Run unit tests
-        run: coverage run --source pgcli -m py.test
+        run: coverage run --source pgcli -m pytest
       - name: Run integration tests
         env:
@@ -86,7 +93,7 @@ jobs:
       - name: Run Black
         run: black --check .
-        if: matrix.python-version == '3.7'
+        if: matrix.python-version == '3.8'
       - name: Coverage
         run: |

.github/workflows/codeql.yml vendored Normal file

@@ -0,0 +1,41 @@
name: "CodeQL"
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
schedule:
- cron: "29 13 * * 1"
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ python ]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
queries: +security-and-quality
- name: Autobuild
uses: github/codeql-action/autobuild@v2
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{ matrix.language }}"


@@ -1,5 +1,5 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 23.3.0
     hooks:
       - id: black


@@ -123,6 +123,13 @@ Contributors:
 * Daniel Kukula (dkuku)
 * Kian-Meng Ang (kianmeng)
 * Liu Zhao (astroshot)
+* Rigo Neri (rigoneri)
+* Anna Glasgall (annathyst)
+* Andy Schoenberger (andyscho)
+* Damien Baty (dbaty)
+* blag
+* Rob Berry (rob-b)
+* Sharon Yogev (sharonyogev)
 Creator:
 --------


@@ -165,8 +165,9 @@ in the ``tests`` directory. An example::
 First, install the requirements for testing:
 ::
-    $ pip install -r requirements-dev.txt
+    $ pip install -U pip setuptools
+    $ pip install --no-cache-dir ".[sshtunnel]"
+    $ pip install -r requirements-dev.txt
 Ensure that the database user has permissions to create and drop test databases
 by checking your ``pg_hba.conf`` file. The default user should be ``postgres``


@@ -157,8 +157,9 @@ get this running in a development setup.
 https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst
-Please feel free to reach out to me if you need help.
-My email: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_
+Please feel free to reach out to us if you need help.
+* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_
+* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong <http://twitter.com/irinatruong>`_
 Detailed Installation Instructions:
 -----------------------------------
@@ -351,8 +352,7 @@ choice:
 In [3]: my_result = _
-Pgcli only runs on Python3.7+ since 4.0.0, if you use an old version of Python,
-you should use install ``pgcli <= 4.0.0``.
+Pgcli dropped support for Python<3.8 as of 4.0.0. If you need it, install ``pgcli <= 4.0.0``.
 Thanks:
 -------
@@ -372,8 +372,8 @@ interface to Postgres database.
 Thanks to all the beta testers and contributors for your time and patience. :)
-.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg
-   :target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli
+.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main
+   :target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml
 .. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
    :target: https://codecov.io/gh/dbcli/pgcli


@@ -1,3 +1,52 @@
==================
4.0.1 (2023-11-30)
==================

Internal:
---------
* Allow stable version of pendulum.

==================
4.0.0 (2023-11-27)
==================

Features:
---------
* Ask for confirmation when quitting the cli while a transaction is ongoing.
* New `destructive_statements_require_transaction` config option to refuse to execute a
  destructive SQL statement if outside a transaction. This option is off by default.
* Changed the `destructive_warning` config to be a list of commands that are considered
  destructive. This allows you to be warned on `create`, `grant`, or `insert` queries.
* Destructive warnings will now include the alias dsn connection string name if provided (-D option).
* pgcli.magic will now work with connection URLs that use TLS client certificates for authentication.
* New config option to retry queries on operational errors like connections being lost.
  Also prevents getting stuck in a retry loop.
* Config option to not restart the connection when cancelling a `destructive_warning` query. By default,
  it will now not restart.
* Config option to always run with a single connection.
* Add comment explaining default LESS environment variable behavior and change example pager setting.
* Added `\echo` & `\qecho` special commands ([issue 1335](https://github.com/dbcli/pgcli/issues/1335)).

Bug fixes:
----------
* Fix `\ev` not producing a correctly quoted "schema"."view".
* Fix 'invalid connection option "dsn"' ([issue 1373](https://github.com/dbcli/pgcli/issues/1373)).
* Fix explain mode when used with `expand`, `auto_expand`, or `--explain-vertical-output` ([issue 1393](https://github.com/dbcli/pgcli/issues/1393)).
* Fix sql-insert format emitting NULL as 'None' ([issue 1408](https://github.com/dbcli/pgcli/issues/1408)).
* Improve check for prompt-toolkit 3.0.6 ([issue 1416](https://github.com/dbcli/pgcli/issues/1416)).
* Allow specifying an `alias_map_file` in the config that will use
  predetermined table aliases instead of generating aliases programmatically on
  the fly.
* Fixed SQL error when there is a comment on the first line ([issue 1403](https://github.com/dbcli/pgcli/issues/1403)).
* Fix wrong usage of prompt instead of confirm when confirming execution of a destructive query.

Internal:
---------
* Drop support for Python 3.7 and add 3.12.

 3.5.0 (2022/09/15):
 ===================


@@ -1 +1 @@
-__version__ = "3.5.0"
+__version__ = "4.0.1"


@@ -26,7 +26,9 @@ def keyring_initialize(keyring_enabled, *, logger):
     try:
         keyring = importlib.import_module("keyring")
-    except Exception as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
+    except (
+        ModuleNotFoundError
+    ) as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
         logger.warning("import keyring failed: %r.", e)
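The narrowed except clause changes what failures are tolerated: only a genuinely absent module is swallowed, while other import-time errors (for example a broken module body) now propagate. A minimal sketch of the same optional-import pattern, with an illustrative helper name not taken from pgcli:

```python
import importlib

def load_optional(module_name, warnings):
    # Only a truly missing module is tolerated; any other
    # import-time error propagates to the caller.
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        warnings.append("import %s failed: %r" % (module_name, e))
        return None

warnings = []
mod = load_optional("definitely_not_installed_xyz", warnings)  # -> None, one warning
```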


@@ -6,7 +6,6 @@ from .pgcompleter import PGCompleter
 class CompletionRefresher:
     refreshers = OrderedDict()
     def __init__(self):
@@ -39,7 +38,7 @@ class CompletionRefresher:
             args=(executor, special, callbacks, history, settings),
             name="completion_refresh",
         )
-        self._completer_thread.setDaemon(True)
+        self._completer_thread.daemon = True
         self._completer_thread.start()
         return [
             (None, None, None, "Auto-completion refresh started in the background.")
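This swap tracks the standard library: `Thread.setDaemon()` was deprecated in Python 3.10 in favor of assigning the `daemon` attribute directly. A minimal sketch of the modern form (the helper name is illustrative, not pgcli's):

```python
import threading

def make_background_worker(target):
    # Equivalent to the deprecated t.setDaemon(True): a daemon thread
    # does not keep the interpreter alive at shutdown.
    t = threading.Thread(target=target, name="completion_refresh")
    t.daemon = True
    return t

t = make_background_worker(lambda: None)
t.start()
t.join()
```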


@@ -10,7 +10,8 @@ class ExplainOutputFormatter:
         self.max_width = max_width
     def format_output(self, cur, headers, **output_kwargs):
-        (data,) = cur.fetchone()
+        # explain query results should always contain 1 row each
+        [(data,)] = list(cur)
         explain_list = json.loads(data)
         visualizer = Visualizer(self.max_width)
         for explain in explain_list:
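The new `[(data,)] = list(cur)` both consumes the cursor and asserts its shape: exactly one row with one column, otherwise Python raises `ValueError`, whereas `fetchone()` would silently ignore extra rows. A sketch with a plain list standing in for the database cursor:

```python
# A list of 1-tuples stands in for a DB cursor: iterating a cursor
# yields row tuples, so list(cur) materializes every remaining row.
cursor_like = [('[{"Plan": {"Node Type": "Seq Scan"}}]',)]

# Strict unpack: exactly one single-column row, else ValueError.
[(data,)] = list(cursor_like)

# A second row makes the unpack fail loudly instead of being dropped.
try:
    [(data2,)] = [("a",), ("b",)]
    shape_error = False
except ValueError:
    shape_error = True
```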


@@ -43,7 +43,7 @@ def pgcli_line_magic(line):
     u = conn.session.engine.url
     _logger.debug("New pgcli: %r", str(u))
-    pgcli.connect(u.database, u.host, u.username, u.port, u.password)
+    pgcli.connect_uri(str(u._replace(drivername="postgres")))
     conn._pgcli = pgcli
     # For convenience, print the connection alias


@@ -63,15 +63,14 @@ from .config import (
 )
 from .key_bindings import pgcli_bindings
 from .packages.formatter.sqlformatter import register_new_formatter
-from .packages.prompt_utils import confirm_destructive_query
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages.parseutils import is_destructive
+from .packages.parseutils import parse_destructive_warning
 from .__init__ import __version__
 click.disable_unicode_literals_warning = True
-try:
-    from urlparse import urlparse, unquote, parse_qs
-except ImportError:
-    from urllib.parse import urlparse, unquote, parse_qs
+from urllib.parse import urlparse
 from getpass import getuser
@@ -201,6 +200,9 @@ class PGCli:
         self.multiline_mode = c["main"].get("multi_line_mode", "psql")
         self.vi_mode = c["main"].as_bool("vi")
         self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
+        self.auto_retry_closed_connection = c["main"].as_bool(
+            "auto_retry_closed_connection"
+        )
         self.expanded_output = c["main"].as_bool("expand")
         self.pgspecial.timing_enabled = c["main"].as_bool("timing")
         if row_limit is not None:
@@ -224,11 +226,16 @@ class PGCli:
         self.syntax_style = c["main"]["syntax_style"]
         self.cli_style = c["colors"]
         self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
-        self.destructive_warning = warn or c["main"]["destructive_warning"]
-        # also handle boolean format of destructive warning
-        self.destructive_warning = {"true": "all", "false": "off"}.get(
-            self.destructive_warning.lower(), self.destructive_warning
-        )
+        self.destructive_warning = parse_destructive_warning(
+            warn or c["main"].as_list("destructive_warning")
+        )
+        self.destructive_warning_restarts_connection = c["main"].as_bool(
+            "destructive_warning_restarts_connection"
+        )
+        self.destructive_statements_require_transaction = c["main"].as_bool(
+            "destructive_statements_require_transaction"
+        )
         self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
         self.null_string = c["main"].get("null_string", "<null>")
         self.prompt_format = (
@@ -258,6 +265,9 @@ class PGCli:
         # Initialize completer
         smart_completion = c["main"].as_bool("smart_completion")
         keyword_casing = c["main"]["keyword_casing"]
+        single_connection = single_connection or c["main"].as_bool(
+            "always_use_single_connection"
+        )
         self.settings = {
             "casing_file": get_casing_file(c),
             "generate_casing_file": c["main"].as_bool("generate_casing_file"),
@@ -269,6 +279,7 @@ class PGCli:
             "single_connection": single_connection,
             "less_chatty": less_chatty,
             "keyword_casing": keyword_casing,
+            "alias_map_file": c["main"]["alias_map_file"] or None,
         }
         completer = PGCompleter(
@@ -292,7 +303,6 @@ class PGCli:
         raise PgCliQuitError
     def register_special_commands(self):
         self.pgspecial.register(
             self.change_db,
             "\\c",
@@ -354,6 +364,23 @@ class PGCli:
             "Change the table format used to output results",
         )
+        self.pgspecial.register(
+            self.echo,
+            "\\echo",
+            "\\echo [string]",
+            "Echo a string to stdout",
+        )
+        self.pgspecial.register(
+            self.echo,
+            "\\qecho",
+            "\\qecho [string]",
+            "Echo a string to the query output channel.",
+        )
+    def echo(self, pattern, **_):
+        return [(None, None, None, pattern)]
     def change_table_format(self, pattern, **_):
         try:
             if pattern not in TabularOutputFormatter().supported_formats:
@@ -423,12 +450,20 @@ class PGCli:
         except OSError as e:
             return [(None, None, None, str(e), "", False, True)]
-        if (
-            self.destructive_warning != "off"
-            and confirm_destructive_query(query, self.destructive_warning) is False
-        ):
-            message = "Wise choice. Command execution stopped."
-            return [(None, None, None, message)]
+        if self.destructive_warning:
+            if (
+                self.destructive_statements_require_transaction
+                and not self.pgexecute.valid_transaction()
+                and is_destructive(query, self.destructive_warning)
+            ):
+                message = "Destructive statements must be run within a transaction. Command execution stopped."
+                return [(None, None, None, message)]
+            destroy = confirm_destructive_query(
+                query, self.destructive_warning, self.dsn_alias
+            )
+            if destroy is False:
+                message = "Wise choice. Command execution stopped."
+                return [(None, None, None, message)]
         on_error_resume = self.on_error == "RESUME"
         return self.pgexecute.run(
@@ -456,7 +491,6 @@ class PGCli:
         return [(None, None, None, message, "", True, True)]
     def initialize_logging(self):
         log_file = self.config["main"]["log_file"]
         if log_file == "default":
             log_file = config_location() + "log"
@@ -687,34 +721,52 @@ class PGCli:
         editor_command = special.editor_command(text)
         return text
-    def execute_command(self, text):
+    def execute_command(self, text, handle_closed_connection=True):
         logger = self.logger
         query = MetaQuery(query=text, successful=False)
         try:
-            if self.destructive_warning != "off":
-                destroy = confirm = confirm_destructive_query(
-                    text, self.destructive_warning
+            if self.destructive_warning:
+                if (
+                    self.destructive_statements_require_transaction
+                    and not self.pgexecute.valid_transaction()
+                    and is_destructive(text, self.destructive_warning)
+                ):
+                    click.secho(
+                        "Destructive statements must be run within a transaction."
+                    )
+                    raise KeyboardInterrupt
+                destroy = confirm_destructive_query(
+                    text, self.destructive_warning, self.dsn_alias
                 )
                 if destroy is False:
                     click.secho("Wise choice!")
                     raise KeyboardInterrupt
                 elif destroy:
                     click.secho("Your call!")
             output, query = self._evaluate_command(text)
         except KeyboardInterrupt:
-            # Restart connection to the database
-            self.pgexecute.connect()
-            logger.debug("cancelled query, sql: %r", text)
-            click.secho("cancelled query", err=True, fg="red")
+            if self.destructive_warning_restarts_connection:
+                # Restart connection to the database
+                self.pgexecute.connect()
+                logger.debug("cancelled query and restarted connection, sql: %r", text)
+                click.secho(
+                    "cancelled query and restarted connection", err=True, fg="red"
+                )
+            else:
+                logger.debug("cancelled query, sql: %r", text)
+                click.secho("cancelled query", err=True, fg="red")
         except NotImplementedError:
             click.secho("Not Yet Implemented.", fg="yellow")
         except OperationalError as e:
             logger.error("sql: %r, error: %r", text, e)
             logger.error("traceback: %r", traceback.format_exc())
-            self._handle_server_closed_connection(text)
-        except (PgCliQuitError, EOFError) as e:
+            click.secho(str(e), err=True, fg="red")
+            if handle_closed_connection:
+                self._handle_server_closed_connection(text)
+        except (PgCliQuitError, EOFError):
             raise
         except Exception as e:
             logger.error("sql: %r, error: %r", text, e)
@@ -722,7 +774,9 @@ class PGCli:
             click.secho(str(e), err=True, fg="red")
         else:
             try:
-                if self.output_file and not text.startswith(("\\o ", "\\? ")):
+                if self.output_file and not text.startswith(
+                    ("\\o ", "\\? ", "\\echo ")
+                ):
                     try:
                         with open(self.output_file, "a", encoding="utf-8") as f:
                             click.echo(text, file=f)
@@ -766,6 +820,34 @@ class PGCli:
         logger.debug("Search path: %r", self.completer.search_path)
         return query
+    def _check_ongoing_transaction_and_allow_quitting(self):
+        """Return whether we can really quit, possibly by asking the
+        user to confirm so if there is an ongoing transaction.
+        """
+        if not self.pgexecute.valid_transaction():
+            return True
+        while 1:
+            try:
+                choice = click.prompt(
+                    "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.",
+                    default="a",
+                )
+            except click.Abort:
+                # Print newline if user aborts with `^C`, otherwise
+                # pgcli's prompt will be printed on the same line
+                # (just after the confirmation prompt).
+                click.echo(None, err=False)
+                choice = "a"
+            choice = choice.lower()
+            if choice == "a":
+                return False  # do not quit
+            if choice == "c":
+                query = self.execute_command("commit")
+                return query.successful  # quit only if query is successful
+            if choice == "r":
+                query = self.execute_command("rollback")
+                return query.successful  # quit only if query is successful
     def run_cli(self):
         logger = self.logger
@@ -788,6 +870,10 @@ class PGCli:
                 text = self.prompt_app.prompt()
             except KeyboardInterrupt:
                 continue
+            except EOFError:
+                if not self._check_ongoing_transaction_and_allow_quitting():
+                    continue
+                raise
             try:
                 text = self.handle_editor_command(text)
@@ -797,7 +883,12 @@ class PGCli:
                 click.secho(str(e), err=True, fg="red")
                 continue
-            self.handle_watch_command(text)
+            try:
+                self.handle_watch_command(text)
+            except PgCliQuitError:
+                if not self._check_ongoing_transaction_and_allow_quitting():
+                    continue
+                raise
             self.now = dt.datetime.today()
@@ -1036,10 +1127,17 @@ class PGCli:
             click.secho("Reconnecting...", fg="green")
             self.pgexecute.connect()
             click.secho("Reconnected!", fg="green")
-            self.execute_command(text)
         except OperationalError as e:
             click.secho("Reconnect Failed", fg="red")
             click.secho(str(e), err=True, fg="red")
+        else:
+            retry = self.auto_retry_closed_connection or confirm(
+                "Run the query from before reconnecting?"
+            )
+            if retry:
+                click.secho("Running query...", fg="green")
+                # Don't get stuck in a retry loop
+                self.execute_command(text, handle_closed_connection=False)
     def refresh_completions(self, history=None, persist_priorities="all"):
         """Refresh outdated completions
@@ -1266,7 +1364,6 @@ class PGCli:
 @click.option(
     "--warn",
     default=None,
-    type=click.Choice(["all", "moderate", "off"]),
     help="Warn before running a destructive query.",
 )
 @click.option(
@@ -1575,7 +1672,8 @@ def format_output(title, cur, headers, status, settings, explain_mode=False):
     first_line = next(formatted)
     formatted = itertools.chain([first_line], formatted)
     if (
-        not expanded
+        not explain_mode
+        and not expanded
         and max_width
         and len(strip_ansi(first_line)) > max_width
         and headers


@@ -14,10 +14,13 @@ preprocessors = ()
 def escape_for_sql_statement(value):
+    if value is None:
+        return "NULL"
     if isinstance(value, bytes):
         return f"X'{value.hex()}'"
-    else:
-        return "'{}'".format(value)
+    return "'{}'".format(value)
@@ -29,7 +32,7 @@ def adapter(data, headers, table_format=None, **kwargs):
         else:
             table_name = table[1]
     else:
-        table_name = '"DUAL"'
+        table_name = "DUAL"
     if table_format == "sql-insert":
         h = '", "'.join(headers)
         yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h)
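The `None` branch is the fix for issue 1408: without it, `"'{}'".format(None)` renders the Python string `'None'` into generated INSERT statements instead of SQL `NULL`. A minimal sketch reproducing the patched helper and its effect on a row:

```python
# Sketch of the patched helper: None now renders as SQL NULL
# instead of the quoted string 'None'.
def escape_for_sql_statement(value):
    if value is None:
        return "NULL"
    if isinstance(value, bytes):
        return f"X'{value.hex()}'"
    return "'{}'".format(value)

row = ("alice", None, b"\x01")
values = ", ".join(escape_for_sql_statement(v) for v in row)
# -> 'alice', NULL, X'01'
```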


@@ -1,6 +1,17 @@
 import sqlparse
+BASE_KEYWORDS = [
+    "drop",
+    "shutdown",
+    "delete",
+    "truncate",
+    "alter",
+    "unconditional_update",
+]
+ALL_KEYWORDS = BASE_KEYWORDS + ["update"]
 def query_starts_with(formatted_sql, prefixes):
     """Check if the query starts with any item from *prefixes*."""
     prefixes = [prefix.lower() for prefix in prefixes]
@@ -13,22 +24,35 @@ def query_is_unconditional_update(formatted_sql):
     return bool(tokens) and tokens[0] == "update" and "where" not in tokens
-def query_is_simple_update(formatted_sql):
-    """Check if the query starts with UPDATE."""
-    tokens = formatted_sql.split()
-    return bool(tokens) and tokens[0] == "update"
-def is_destructive(queries, warning_level="all"):
+def is_destructive(queries, keywords):
     """Returns if any of the queries in *queries* is destructive."""
-    keywords = ("drop", "shutdown", "delete", "truncate", "alter")
     for query in sqlparse.split(queries):
         if query:
             formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+            if "unconditional_update" in keywords and query_is_unconditional_update(
+                formatted_sql
+            ):
+                return True
             if query_starts_with(formatted_sql, keywords):
                 return True
-            if query_is_unconditional_update(formatted_sql):
-                return True
-            if warning_level == "all" and query_is_simple_update(formatted_sql):
-                return True
     return False
+def parse_destructive_warning(warning_level):
+    """Converts a deprecated destructive warning option to a list of command keywords."""
+    if not warning_level:
+        return []
+    if not isinstance(warning_level, list):
+        if "," in warning_level:
+            return warning_level.split(",")
+        warning_level = [warning_level]
+    return {
+        "true": ALL_KEYWORDS,
+        "false": [],
+        "all": ALL_KEYWORDS,
+        "moderate": BASE_KEYWORDS,
+        "off": [],
+        "": [],
+    }.get(warning_level[0], warning_level)
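`parse_destructive_warning` is the backward-compatibility shim: the legacy scalar settings (`all`, `moderate`, `off`, booleans) map onto the new keyword lists, comma-separated strings are split, and anything else passes through as a custom list. A self-contained sketch with the keyword lists inlined (so it runs without pgcli installed):

```python
# Inlined sketch of parse_destructive_warning from the diff above.
BASE_KEYWORDS = ["drop", "shutdown", "delete", "truncate", "alter", "unconditional_update"]
ALL_KEYWORDS = BASE_KEYWORDS + ["update"]

def parse_destructive_warning(warning_level):
    if not warning_level:
        return []
    if not isinstance(warning_level, list):
        if "," in warning_level:
            return warning_level.split(",")
        warning_level = [warning_level]
    return {
        "true": ALL_KEYWORDS,
        "false": [],
        "all": ALL_KEYWORDS,
        "moderate": BASE_KEYWORDS,
        "off": [],
        "": [],
    }.get(warning_level[0], warning_level)

legacy_all = parse_destructive_warning("all")        # legacy scalar -> full list
legacy_off = parse_destructive_warning("off")        # legacy scalar -> no warnings
custom = parse_destructive_warning("drop,truncate")  # comma string -> custom list
```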


@@ -3,7 +3,7 @@ import click
 from .parseutils import is_destructive
-def confirm_destructive_query(queries, warning_level):
+def confirm_destructive_query(queries, keywords, alias):
     """Check if the query is destructive and prompts the user to confirm.
     Returns:
@@ -12,11 +12,13 @@ def confirm_destructive_query(queries, keywords, alias):
     * False if the query is destructive and the user doesn't want to proceed.
     """
-    prompt_text = (
-        "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
-    )
-    if is_destructive(queries, warning_level) and sys.stdin.isatty():
-        return prompt(prompt_text, type=bool)
+    info = "You're about to run a destructive command"
+    if alias:
+        info += f" in {click.style(alias, fg='red')}"
+    prompt_text = f"{info}.\nDo you want to proceed?"
+    if is_destructive(queries, keywords) and sys.stdin.isatty():
+        return confirm(prompt_text)
 def confirm(*args, **kwargs):
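The new signature threads the dsn alias into the warning so you can see *which* connection a destructive statement is about to hit, rendered in red via `click.style`. A dependency-free sketch of just the prompt-text construction, with a tiny stand-in for the click styling call:

```python
# style_red stands in for click.style(alias, fg="red") so this
# sketch runs without click installed.
def style_red(text):
    return f"\x1b[31m{text}\x1b[0m"

def build_prompt_text(alias):
    info = "You're about to run a destructive command"
    if alias:
        info += f" in {style_red(alias)}"
    return f"{info}.\nDo you want to proceed?"

with_alias = build_prompt_text("prod_dsn")  # alias highlighted in red
without_alias = build_prompt_text("")       # plain warning, no alias
```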


@@ -290,7 +290,6 @@ def suggest_special(text):
 def suggest_based_on_last_token(token, stmt):
     if isinstance(token, str):
         token_v = token.lower()
     elif isinstance(token, Comparison):
@@ -399,7 +398,6 @@ def suggest_based_on_last_token(token, stmt):
     elif (token_v.endswith("join") and token.is_keyword) or (
         token_v in ("copy", "from", "update", "into", "describe", "truncate")
     ):
         schema = stmt.get_identifier_schema()
         tables = extract_tables(stmt.text_before_cursor)
         is_join = token_v.endswith("join") and token.is_keyword
@@ -436,7 +434,6 @@ def suggest_based_on_last_token(token, stmt):
     try:
         prev = stmt.get_previous_token(token).value.lower()
         if prev in ("drop", "alter", "create", "create or replace"):
             # Suggest functions from either the currently-selected schema or the
             # public schema if no schema has been specified
             suggest = []


@@ -9,6 +9,10 @@ smart_completion = True
 # visible.)
 wider_completion_menu = False
+# Do not create new connections for refreshing completions; Equivalent to
+# always running with the --single-connection flag.
+always_use_single_connection = False
 # Multi-line mode allows breaking up the sql statements into multiple lines. If
 # this is set to True, then the end of the statements must have a semi-colon.
 # If this is set to False then sql statements can't be split into multiple
@@ -22,14 +26,22 @@ multi_line = False
 # a command.
 multi_line_mode = psql
-# Destructive warning mode will alert you before executing a sql statement
+# Destructive warning will alert you before executing a sql statement
 # that may cause harm to the database such as "drop table", "drop database",
 # "shutdown", "delete", or "update".
-# Possible values:
-# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE
-# "moderate" - skip warning on UPDATE statements, except for unconditional updates
-# "off" - skip all warnings
-destructive_warning = all
+# You can pass a list of destructive commands or leave it empty if you want to skip all warnings.
+# "unconditional_update" will warn you of update statements that don't have a where clause
+destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update
+
+# Destructive warning can restart the connection if this is enabled and the
+# user declines. This means that any current uncommitted transaction can be
+# aborted if the user doesn't want to proceed with a destructive_warning
+# statement.
+destructive_warning_restarts_connection = False
+
+# When this option is on (and if `destructive_warning` is not empty),
+# destructive statements are not executed when outside of a transaction.
+destructive_statements_require_transaction = False
 # Enables expand mode, which is similar to `\x` in psql.
 expand = False
@@ -37,9 +49,21 @@ expand = False
 # Enables auto expand mode, which is similar to `\x auto` in psql.
 auto_expand = False
+# Auto-retry queries on connection failures and other operational errors. If
+# False, will prompt to rerun the failed query instead of auto-retrying.
+auto_retry_closed_connection = True
 # If set to True, table suggestions will include a table alias
 generate_aliases = False
+# Path to a json file that specifies specific table aliases to use when generate_aliases is set to True
+# the format for this file should be:
+# {
+#     "some_table_name": "desired_alias",
+#     "some_other_table_name": "another_alias"
+# }
+alias_map_file =
 # log_file location.
 # In Unix/Linux: ~/.config/pgcli/log
 # In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log
@ -83,9 +107,10 @@ qualify_columns = if_more_than_one_table
# When no schema is entered, only suggest objects in search_path # When no schema is entered, only suggest objects in search_path
search_path_filter = False search_path_filter = False
# Default pager. # Default pager. See https://www.pgcli.com/pager for more information on settings.
# By default 'PAGER' environment variable is used # By default 'PAGER' environment variable is used. If the pager is less, and the 'LESS'
# pager = less -SRXF # environment variable is not set, then LESS='-SRXF' will be automatically set.
# pager = less
# Timing of sql statements and table rendering. # Timing of sql statements and table rendering.
timing = True timing = True
@ -140,7 +165,7 @@ less_chatty = False
# \i - Postgres PID # \i - Postgres PID
# \# - "@" sign if logged in as superuser, '>' in other case # \# - "@" sign if logged in as superuser, '>' in other case
# \n - Newline # \n - Newline
# \dsn_alias - name of dsn alias if -D option is used (empty otherwise) # \dsn_alias - name of dsn connection string alias if -D option is used (empty otherwise)
# \x1b[...m - insert ANSI escape sequence # \x1b[...m - insert ANSI escape sequence
# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>' # eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>'
prompt = '\u@\h:\d> ' prompt = '\u@\h:\d> '
@ -198,7 +223,8 @@ output.null = "#808080"
# Named queries are queries you can execute by name. # Named queries are queries you can execute by name.
[named queries] [named queries]
# DSN to call by -D option # Here's where you can provide a list of connection string aliases.
# You can use it by passing the -D option. `pgcli -D example_dsn`
[alias_dsn] [alias_dsn]
# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname] # example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname]
@ -1,3 +1,4 @@
import json
import logging import logging
import re import re
from itertools import count, repeat, chain from itertools import count, repeat, chain
@ -61,18 +62,38 @@ arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
def generate_alias(tbl): def generate_alias(tbl, alias_map=None):
"""Generate a table alias, consisting of all upper-case letters in """Generate a table alias, consisting of all upper-case letters in
the table name, or, if there are no upper-case letters, the first letter + the table name, or, if there are no upper-case letters, the first letter +
all letters preceded by _ all letters preceded by _
param tbl - unescaped name of the table to alias param tbl - unescaped name of the table to alias
""" """
if alias_map and tbl in alias_map:
return alias_map[tbl]
return "".join( return "".join(
[l for l in tbl if l.isupper()] [l for l in tbl if l.isupper()]
or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
) )
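A condensed, standalone copy of the aliasing logic (the new `alias_map` short-circuit plus the original fallback) makes the rules easy to check:

```python
def generate_alias(tbl, alias_map=None):
    # An explicit mapping wins; otherwise use the upper-case letters,
    # falling back to the first letter plus every letter that follows
    # an underscore.
    if alias_map and tbl in alias_map:
        return alias_map[tbl]
    return "".join(
        [l for l in tbl if l.isupper()]
        or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
    )

print(generate_alias("SomE_Table"))              # SET
print(generate_alias("some_tab_le"))             # stl
print(generate_alias("users", {"users": "u1"}))  # u1
```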
class InvalidMapFile(ValueError):
pass
def load_alias_map_file(path):
try:
with open(path) as fo:
alias_map = json.load(fo)
except FileNotFoundError as err:
raise InvalidMapFile(
f"Cannot read alias_map_file - {err.filename} does not exist"
)
except json.JSONDecodeError:
raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json")
else:
return alias_map
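A quick standalone check of the loader's contract — the function body is a copy of the one above; the temp-file round trip is only for illustration:

```python
import json
import os
import tempfile


class InvalidMapFile(ValueError):
    pass


def load_alias_map_file(path):
    # Missing file and malformed JSON both surface as InvalidMapFile,
    # so the caller can show a single friendly error message.
    try:
        with open(path) as fo:
            return json.load(fo)
    except FileNotFoundError as err:
        raise InvalidMapFile(
            f"Cannot read alias_map_file - {err.filename} does not exist"
        )
    except json.JSONDecodeError:
        raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json")


with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"customer_orders": "co"}, f)
alias_map = load_alias_map_file(f.name)
os.unlink(f.name)
```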
class PGCompleter(Completer): class PGCompleter(Completer):
# keywords_tree: A dict mapping keywords to well known following keywords. # keywords_tree: A dict mapping keywords to well known following keywords.
# e.g. 'CREATE': ['TABLE', 'USER', ...], # e.g. 'CREATE': ['TABLE', 'USER', ...],
@ -100,6 +121,11 @@ class PGCompleter(Completer):
self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
self.search_path_filter = settings.get("search_path_filter") self.search_path_filter = settings.get("search_path_filter")
self.generate_aliases = settings.get("generate_aliases") self.generate_aliases = settings.get("generate_aliases")
alias_map_file = settings.get("alias_map_file")
if alias_map_file is not None:
self.alias_map = load_alias_map_file(alias_map_file)
else:
self.alias_map = None
self.casing_file = settings.get("casing_file") self.casing_file = settings.get("casing_file")
self.insert_col_skip_patterns = [ self.insert_col_skip_patterns = [
re.compile(pattern) re.compile(pattern)
@ -157,7 +183,6 @@ class PGCompleter(Completer):
self.all_completions.update(additional_keywords) self.all_completions.update(additional_keywords)
def extend_schemata(self, schemata): def extend_schemata(self, schemata):
# schemata is a list of schema names # schemata is a list of schema names
schemata = self.escaped_names(schemata) schemata = self.escaped_names(schemata)
metadata = self.dbmetadata["tables"] metadata = self.dbmetadata["tables"]
@ -226,7 +251,6 @@ class PGCompleter(Completer):
self.all_completions.add(colname) self.all_completions.add(colname)
def extend_functions(self, func_data): def extend_functions(self, func_data):
# func_data is a list of function metadata namedtuples # func_data is a list of function metadata namedtuples
# dbmetadata['schema_name']['functions']['function_name'] should return # dbmetadata['schema_name']['functions']['function_name'] should return
@ -260,7 +284,6 @@ class PGCompleter(Completer):
} }
def extend_foreignkeys(self, fk_data): def extend_foreignkeys(self, fk_data):
# fk_data is a list of ForeignKey namedtuples, with fields # fk_data is a list of ForeignKey namedtuples, with fields
# parentschema, childschema, parenttable, childtable, # parentschema, childschema, parenttable, childtable,
# parentcolumns, childcolumns # parentcolumns, childcolumns
@ -283,7 +306,6 @@ class PGCompleter(Completer):
parcolmeta.foreignkeys.append(fk) parcolmeta.foreignkeys.append(fk)
def extend_datatypes(self, type_data): def extend_datatypes(self, type_data):
# dbmetadata['datatypes'][schema_name][type_name] should store type # dbmetadata['datatypes'][schema_name][type_name] should store type
# metadata, such as composite type field names. Currently, we're not # metadata, such as composite type field names. Currently, we're not
# storing any metadata beyond typename, so just store None # storing any metadata beyond typename, so just store None
@ -697,7 +719,6 @@ class PGCompleter(Completer):
return self.find_matches(word_before_cursor, conds, meta="join") return self.find_matches(word_before_cursor, conds, meta="join")
def get_function_matches(self, suggestion, word_before_cursor, alias=False): def get_function_matches(self, suggestion, word_before_cursor, alias=False):
if suggestion.usage == "from": if suggestion.usage == "from":
# Only suggest functions allowed in FROM clause # Only suggest functions allowed in FROM clause
@ -1,7 +1,7 @@
import logging import logging
import traceback import traceback
from collections import namedtuple from collections import namedtuple
import re
import pgspecial as special import pgspecial as special
import psycopg import psycopg
import psycopg.sql import psycopg.sql
@ -17,6 +17,27 @@ ViewDef = namedtuple(
) )
# We added this function to strip leading comments because sqlparse
# didn't handle them well. It won't be needed once sqlparse parses
# this situation better.
def remove_beginning_comments(command):
# Regular expression pattern to match comments
pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)"
# Find and remove all comments from the beginning
cleaned_command = command
comments = []
match = re.match(pattern, cleaned_command, re.DOTALL)
while match:
comments.append(match.group())
cleaned_command = cleaned_command[len(match.group()) :].lstrip()
match = re.match(pattern, cleaned_command, re.DOTALL)
return [cleaned_command, comments]
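The stripping loop can be exercised standalone; this sketch repeats the pattern and loop from the function above:

```python
import re

# Matches a single leading comment: either a /* ... */ block or a -- line,
# terminated by a newline or end of string.
pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)"


def remove_beginning_comments(command):
    cleaned, comments = command, []
    match = re.match(pattern, cleaned, re.DOTALL)
    while match:
        comments.append(match.group())
        cleaned = cleaned[len(match.group()):].lstrip()
        match = re.match(pattern, cleaned, re.DOTALL)
    return [cleaned, comments]


cleaned, comments = remove_beginning_comments("-- a note\n/* block */\nselect 1;")
# cleaned == "select 1;", comments == ["-- a note\n", "/* block */\n"]
```

The comments are collected rather than discarded so the caller can re-prepend them before logging.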
def register_typecasters(connection): def register_typecasters(connection):
"""Casts date and timestamp values to string, resolves issues with out-of-range """Casts date and timestamp values to string, resolves issues with out-of-range
dates (e.g. BC) which psycopg can't handle""" dates (e.g. BC) which psycopg can't handle"""
@ -76,7 +97,6 @@ class ProtocolSafeCursor(psycopg.Cursor):
class PGExecute: class PGExecute:
# The boolean argument to the current_schemas function indicates whether # The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog # implicit schemas, e.g. pg_catalog
search_path_query = """ search_path_query = """
@ -180,7 +200,6 @@ class PGExecute:
dsn=None, dsn=None,
**kwargs, **kwargs,
): ):
conn_params = self._conn_params.copy() conn_params = self._conn_params.copy()
new_params = { new_params = {
@ -203,7 +222,11 @@ class PGExecute:
conn_params.update({k: v for k, v in new_params.items() if v}) conn_params.update({k: v for k, v in new_params.items() if v})
conn_info = make_conninfo(**conn_params) if "dsn" in conn_params:
other_params = {k: v for k, v in conn_params.items() if k != "dsn"}
conn_info = make_conninfo(conn_params["dsn"], **other_params)
else:
conn_info = make_conninfo(**conn_params)
conn = psycopg.connect(conn_info) conn = psycopg.connect(conn_info)
conn.cursor_factory = ProtocolSafeCursor conn.cursor_factory = ProtocolSafeCursor
@ -309,21 +332,20 @@ class PGExecute:
# sql parse doesn't split on a comment first + special # sql parse doesn't split on a comment first + special
# so we're going to do it # so we're going to do it
sqltemp = [] removed_comments = []
sqlarr = [] sqlarr = []
cleaned_command = ""
if statement.startswith("--"): # could skip if statement doesn't match ^-- or ^/*
sqltemp = statement.split("\n") cleaned_command, removed_comments = remove_beginning_comments(statement)
sqlarr.append(sqltemp[0])
for i in sqlparse.split(sqltemp[1]): sqlarr = sqlparse.split(cleaned_command)
sqlarr.append(i)
elif statement.startswith("/*"): # now re-add the beginning comments if there are any, so that they show up in
sqltemp = statement.split("*/") # log files etc when running these commands
sqltemp[0] = sqltemp[0] + "*/"
for i in sqlparse.split(sqltemp[1]): if len(removed_comments) > 0:
sqlarr.append(i) sqlarr = removed_comments + sqlarr
else:
sqlarr = sqlparse.split(statement)
# run each sql query # run each sql query
for sql in sqlarr: for sql in sqlarr:
@ -470,7 +492,7 @@ class PGExecute:
return ( return (
psycopg.sql.SQL(template) psycopg.sql.SQL(template)
.format( .format(
name=psycopg.sql.Identifier(f"{result.nspname}.{result.relname}"), name=psycopg.sql.Identifier(result.nspname, result.relname),
stmt=psycopg.sql.SQL(result.viewdef), stmt=psycopg.sql.SQL(result.viewdef),
) )

@ -1,18 +1,14 @@
from pkg_resources import packaging
import prompt_toolkit
from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.key_binding.vi_state import InputMode
from prompt_toolkit.application import get_app from prompt_toolkit.application import get_app
parse_version = packaging.version.parse
vi_modes = { vi_modes = {
InputMode.INSERT: "I", InputMode.INSERT: "I",
InputMode.NAVIGATION: "N", InputMode.NAVIGATION: "N",
InputMode.REPLACE: "R", InputMode.REPLACE: "R",
InputMode.INSERT_MULTIPLE: "M", InputMode.INSERT_MULTIPLE: "M",
} }
if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"): # REPLACE_SINGLE is available in prompt_toolkit >= 3.0.6
if "REPLACE_SINGLE" in {e.name for e in InputMode}:
vi_modes[InputMode.REPLACE_SINGLE] = "R" vi_modes[InputMode.REPLACE_SINGLE] = "R"
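The replacement checks the enum's member names instead of comparing library version strings. The same duck-typed detection works against a stand-in enum (member names and values here are illustrative, not prompt_toolkit's real ones):

```python
from enum import Enum


class InputMode(Enum):
    # Stand-in for prompt_toolkit's InputMode on an older release,
    # i.e. without REPLACE_SINGLE.
    INSERT = "vi-insert"
    NAVIGATION = "vi-navigation"
    REPLACE = "vi-replace"


vi_modes = {
    InputMode.INSERT: "I",
    InputMode.NAVIGATION: "N",
    InputMode.REPLACE: "R",
}

# Feature-detect the member rather than parse a version: if the name is
# absent the mapping is simply left alone, no version parsing needed.
if "REPLACE_SINGLE" in {e.name for e in InputMode}:
    vi_modes[InputMode.REPLACE_SINGLE] = "R"
```

This also drops the `pkg_resources` import, which is deprecated in recent setuptools.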
@ -146,7 +146,7 @@ class Visualizer:
elif self.explain.get("Max Rows") < plan["Actual Rows"]: elif self.explain.get("Max Rows") < plan["Actual Rows"]:
self.explain["Max Rows"] = plan["Actual Rows"] self.explain["Max Rows"] = plan["Actual Rows"]
if not self.explain.get("MaxCost"): if not self.explain.get("Max Cost"):
self.explain["Max Cost"] = plan["Actual Cost"] self.explain["Max Cost"] = plan["Actual Cost"]
elif self.explain.get("Max Cost") < plan["Actual Cost"]: elif self.explain.get("Max Cost") < plan["Actual Cost"]:
self.explain["Max Cost"] = plan["Actual Cost"] self.explain["Max Cost"] = plan["Actual Cost"]
@ -171,7 +171,7 @@ class Visualizer:
return self.warning_format("%.2f ms" % value) return self.warning_format("%.2f ms" % value)
elif value < 60000: elif value < 60000:
return self.critical_format( return self.critical_format(
"%.2f s" % (value / 2000.0), "%.2f s" % (value / 1000.0),
) )
else: else:
return self.critical_format( return self.critical_format(
@ -1,6 +1,6 @@
[tool.black] [tool.black]
line-length = 88 line-length = 88
target-version = ['py36'] target-version = ['py38']
include = '\.pyi?$' include = '\.pyi?$'
exclude = ''' exclude = '''
/( /(
@ -19,4 +19,3 @@ exclude = '''
| tests/data | tests/data
)/ )/
''' '''
@ -57,7 +57,7 @@ def version(version_file):
def commit_for_release(version_file, ver): def commit_for_release(version_file, ver):
run_step("git", "reset") run_step("git", "reset")
run_step("git", "add", version_file) run_step("git", "add", "-u")
run_step("git", "commit", "--message", "Releasing version {}".format(ver)) run_step("git", "commit", "--message", "Releasing version {}".format(ver))
@ -1,7 +1,7 @@
pytest>=2.7.0 pytest>=2.7.0
tox>=1.9.2 tox>=1.9.2
behave>=1.2.4 behave>=1.2.4
black>=22.3.0 black>=23.3.0
pexpect==3.3; platform_system != "Windows" pexpect==3.3; platform_system != "Windows"
pre-commit>=1.16.0 pre-commit>=1.16.0
coverage>=5.0.4 coverage>=5.0.4
@ -10,4 +10,4 @@ docutils>=0.13.1
autopep8>=1.3.3 autopep8>=1.3.3
twine>=1.11.0 twine>=1.11.0
wheel>=0.33.6 wheel>=0.33.6
sshtunnel>=0.4.0 sshtunnel>=0.4.0
@ -51,7 +51,7 @@ setup(
"keyring": ["keyring >= 12.2.0"], "keyring": ["keyring >= 12.2.0"],
"sshtunnel": ["sshtunnel >= 0.4.0"], "sshtunnel": ["sshtunnel >= 0.4.0"],
}, },
python_requires=">=3.7", python_requires=">=3.8",
entry_points=""" entry_points="""
[console_scripts] [console_scripts]
pgcli=pgcli.main:cli pgcli=pgcli.main:cli
@ -62,10 +62,11 @@ setup(
"Operating System :: Unix", "Operating System :: Unix",
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: SQL", "Programming Language :: SQL",
"Topic :: Database", "Topic :: Database",
"Topic :: Database :: Front-Ends", "Topic :: Database :: Front-Ends",
@ -23,6 +23,30 @@ Feature: run the cli,
When we send "ctrl + d" When we send "ctrl + d"
then dbcli exits then dbcli exits
Scenario: confirm exit when a transaction is ongoing
When we begin transaction
and we try to send "ctrl + d"
then we see ongoing transaction message
when we send "c"
then dbcli exits
Scenario: cancel exit when a transaction is ongoing
When we begin transaction
and we try to send "ctrl + d"
then we see ongoing transaction message
when we send "a"
then we see dbcli prompt
when we rollback transaction
when we send "ctrl + d"
then dbcli exits
Scenario: interrupt current query via "ctrl + c"
When we send sleep query
and we send "ctrl + c"
then we see cancelled query warning
when we check for any non-idle sleep queries
then we don't see any non-idle sleep queries
Scenario: list databases Scenario: list databases
When we list databases When we list databases
then we see list of databases then we see list of databases
@ -5,7 +5,7 @@ Feature: manipulate databases:
When we create database When we create database
then we see database created then we see database created
when we drop database when we drop database
then we confirm the destructive warning then we respond to the destructive warning: y
then we see database dropped then we see database dropped
when we connect to dbserver when we connect to dbserver
then we see database connected then we see database connected
@ -8,15 +8,38 @@ Feature: manipulate tables:
then we see table created then we see table created
when we insert into table when we insert into table
then we see record inserted then we see record inserted
when we select from table
then we see data selected: initial
when we update table when we update table
then we see record updated then we see record updated
when we select from table when we select from table
then we see data selected then we see data selected: updated
when we delete from table when we delete from table
then we confirm the destructive warning then we respond to the destructive warning: y
then we see record deleted then we see record deleted
when we drop table when we drop table
then we confirm the destructive warning then we respond to the destructive warning: y
then we see table dropped then we see table dropped
when we connect to dbserver when we connect to dbserver
then we see database connected then we see database connected
Scenario: transaction handling, with cancelling on a destructive warning.
When we connect to test database
then we see database connected
when we create table
then we see table created
when we begin transaction
then we see transaction began
when we insert into table
then we see record inserted
when we delete from table
then we respond to the destructive warning: n
when we select from table
then we see data selected: initial
when we rollback transaction
then we see transaction rolled back
when we select from table
then we see select output without data
when we drop table
then we respond to the destructive warning: y
then we see table dropped
@ -164,10 +164,24 @@ def before_step(context, _):
context.atprompt = False context.atprompt = False
def is_known_problem(scenario):
"""TODO: why is this not working in 3.12?"""
if sys.version_info >= (3, 12):
return scenario.name in (
'interrupt current query via "ctrl + c"',
"run the cli with --username",
"run the cli with --user",
"run the cli with --port",
)
return False
def before_scenario(context, scenario): def before_scenario(context, scenario):
if scenario.name == "list databases": if scenario.name == "list databases":
# not using the cli for that # not using the cli for that
return return
if is_known_problem(scenario):
scenario.skip()
currentdb = None currentdb = None
if "pgbouncer" in scenario.feature.tags: if "pgbouncer" in scenario.feature.tags:
if context.pgbouncer_available: if context.pgbouncer_available:
@ -7,7 +7,7 @@ Feature: expanded mode:
and we select from table and we select from table
then we see expanded data selected then we see expanded data selected
when we drop table when we drop table
then we confirm the destructive warning then we respond to the destructive warning: y
then we see table dropped then we see table dropped
Scenario: expanded off Scenario: expanded off
@ -16,7 +16,7 @@ Feature: expanded mode:
and we select from table and we select from table
then we see nonexpanded data selected then we see nonexpanded data selected
when we drop table when we drop table
then we confirm the destructive warning then we respond to the destructive warning: y
then we see table dropped then we see table dropped
Scenario: expanded auto Scenario: expanded auto
@ -25,5 +25,5 @@ Feature: expanded mode:
and we select from table and we select from table
then we see auto data selected then we see auto data selected
when we drop table when we drop table
then we confirm the destructive warning then we respond to the destructive warning: y
then we see table dropped then we see table dropped
@ -64,13 +64,83 @@ def step_ctrl_d(context):
""" """
Send Ctrl + D to hopefully exit. Send Ctrl + D to hopefully exit.
""" """
step_try_to_ctrl_d(context)
context.cli.expect(pexpect.EOF, timeout=5)
context.exit_sent = True
@when('we try to send "ctrl + d"')
def step_try_to_ctrl_d(context):
"""
Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is
ongoing).
"""
# turn off pager before exiting # turn off pager before exiting
context.cli.sendcontrol("c") context.cli.sendcontrol("c")
context.cli.sendline(r"\pset pager off") context.cli.sendline(r"\pset pager off")
wrappers.wait_prompt(context) wrappers.wait_prompt(context)
context.cli.sendcontrol("d") context.cli.sendcontrol("d")
context.cli.expect(pexpect.EOF, timeout=5)
context.exit_sent = True
@when('we send "ctrl + c"')
def step_ctrl_c(context):
"""Send Ctrl + c to hopefully interrupt."""
context.cli.sendcontrol("c")
@then("we see cancelled query warning")
def step_see_cancelled_query_warning(context):
"""
Make sure we receive the warning that the current query was cancelled.
"""
wrappers.expect_exact(context, "cancelled query", timeout=2)
@then("we see ongoing transaction message")
def step_see_ongoing_transaction_error(context):
"""
Make sure we receive the warning that a transaction is ongoing.
"""
context.cli.expect("A transaction is ongoing.", timeout=2)
@when("we send sleep query")
def step_send_sleep_15_seconds(context):
"""
Send query to sleep for 15 seconds.
"""
context.cli.sendline("select pg_sleep(15)")
@when("we check for any non-idle sleep queries")
def step_check_for_active_sleep_queries(context):
"""
Send query to check for any non-idle pg_sleep queries.
"""
context.cli.sendline(
"select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';"
)
@then("we don't see any non-idle sleep queries")
def step_no_active_sleep_queries(context):
"""Confirm that any pg_sleep queries are either idle or not active."""
wrappers.expect_exact(
context,
context.conf["pager_boundary"]
+ "\r"
+ dedent(
"""
+-------+\r
| state |\r
|-------|\r
+-------+\r
SELECT 0\r
"""
)
+ context.conf["pager_boundary"],
timeout=5,
)
@when(r'we send "\?" command') @when(r'we send "\?" command')
@ -131,18 +201,31 @@ def step_see_found(context):
) )
@then("we confirm the destructive warning") @then("we respond to the destructive warning: {response}")
def step_confirm_destructive_command(context): def step_respond_to_destructive_command(context, response):
"""Confirm destructive command.""" """Respond to destructive command."""
wrappers.expect_exact( wrappers.expect_exact(
context, context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:",
timeout=2, timeout=2,
) )
context.cli.sendline("y") context.cli.sendline(response.strip())
@then("we send password") @then("we send password")
def step_send_password(context): def step_send_password(context):
wrappers.expect_exact(context, "Password for", timeout=5) wrappers.expect_exact(context, "Password for", timeout=5)
context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER")
@when('we send "{text}"')
def step_send_text(context, text):
context.cli.sendline(text)
# Try to detect whether we are exiting. If so, set `exit_sent`
# so that `after_scenario` correctly cleans up.
try:
context.cli.expect(pexpect.EOF, timeout=0.2)
except pexpect.TIMEOUT:
pass
else:
context.exit_sent = True
@ -9,6 +9,10 @@ from textwrap import dedent
import wrappers import wrappers
INITIAL_DATA = "xxx"
UPDATED_DATA = "yyy"
@when("we create table") @when("we create table")
def step_create_table(context): def step_create_table(context):
""" """
@ -22,7 +26,7 @@ def step_insert_into_table(context):
""" """
Send insert into table. Send insert into table.
""" """
context.cli.sendline("""insert into a(x) values('xxx');""") context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""")
@when("we update table") @when("we update table")
@ -30,7 +34,9 @@ def step_update_table(context):
""" """
Send insert into table. Send insert into table.
""" """
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") context.cli.sendline(
f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';"""
)
@when("we select from table") @when("we select from table")
@ -46,7 +52,7 @@ def step_delete_from_table(context):
""" """
Send delete from table. Send delete from table.
""" """
context.cli.sendline("""delete from a where x = 'yyy';""") context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""")
@when("we drop table") @when("we drop table")
@ -57,6 +63,30 @@ def step_drop_table(context):
context.cli.sendline("drop table a;") context.cli.sendline("drop table a;")
@when("we alter the table")
def step_alter_table(context):
"""
Alter the table by adding a column.
"""
context.cli.sendline("""alter table a add column y varchar;""")
@when("we begin transaction")
def step_begin_transaction(context):
"""
Begin transaction
"""
context.cli.sendline("begin;")
@when("we rollback transaction")
def step_rollback_transaction(context):
"""
Rollback transaction
"""
context.cli.sendline("rollback;")
@then("we see table created") @then("we see table created")
def step_see_table_created(context): def step_see_table_created(context):
""" """
@ -81,21 +111,42 @@ def step_see_record_updated(context):
wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2)
@then("we see data selected") @then("we see data selected: {data}")
def step_see_data_selected(context): def step_see_data_selected(context, data):
""" """
Wait to see select output. Wait to see select output with initial or updated data.
"""
x = UPDATED_DATA if data == "updated" else INITIAL_DATA
wrappers.expect_pager(
context,
dedent(
f"""\
+-----+\r
| x |\r
|-----|\r
| {x} |\r
+-----+\r
SELECT 1\r
"""
),
timeout=1,
)
@then("we see select output without data")
def step_see_no_data_selected(context):
"""
Wait to see select output without data.
""" """
wrappers.expect_pager( wrappers.expect_pager(
context, context,
dedent( dedent(
"""\ """\
+-----+\r +---+\r
| x |\r | x |\r
|-----|\r |---|\r
| yyy |\r +---+\r
+-----+\r SELECT 0\r
SELECT 1\r
""" """
), ),
timeout=1, timeout=1,
@ -116,3 +167,19 @@ def step_see_table_dropped(context):
Wait to see drop output. Wait to see drop output.
""" """
wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2)
@then("we see transaction began")
def step_see_transaction_began(context):
"""
Wait to see transaction began.
"""
wrappers.expect_pager(context, "BEGIN\r\n", timeout=2)
@then("we see transaction rolled back")
def step_see_transaction_rolled_back(context):
"""
Wait to see transaction rollback.
"""
wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2)
@ -16,7 +16,7 @@ def step_prepare_data(context):
context.cli.sendline("drop table if exists a;") context.cli.sendline("drop table if exists a;")
wrappers.expect_exact( wrappers.expect_exact(
context, context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:",
timeout=2, timeout=2,
) )
context.cli.sendline("y") context.cli.sendline("y")
@ -3,10 +3,7 @@ import pexpect
from pgcli.main import COLOR_CODE_REGEX from pgcli.main import COLOR_CODE_REGEX
import textwrap import textwrap
try: from io import StringIO
from StringIO import StringIO
except ImportError:
from io import StringIO
def expect_exact(context, expected, timeout): def expect_exact(context, expected, timeout):
@ -34,7 +34,7 @@ def test_output_sql_insert():
"Jackson", "Jackson",
"jackson_test@gmail.com", "jackson_test@gmail.com",
"132454789", "132454789",
"", None,
"2022-09-09 19:44:32.712343+08", "2022-09-09 19:44:32.712343+08",
"2022-09-09 19:44:32.712343+08", "2022-09-09 19:44:32.712343+08",
] ]
@ -58,7 +58,7 @@ def test_output_sql_insert():
output_list = [l for l in output] output_list = [l for l in output]
expected = [ expected = [
'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES',
" ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', '', " " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, "
+ "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')",
";", ";",
] ]

import pytest import pytest
from pgcli.packages.parseutils import is_destructive from pgcli.packages.parseutils import (
is_destructive,
parse_destructive_warning,
BASE_KEYWORDS,
ALL_KEYWORDS,
)
from pgcli.packages.parseutils.tables import extract_tables from pgcli.packages.parseutils.tables import extract_tables
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
@ -263,18 +268,43 @@ def test_is_open_quote__open(sql):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sql", "warning_level", "expected"), ("sql", "keywords", "expected"),
[ [
("update abc set x = 1", "all", True), ("update abc set x = 1", ALL_KEYWORDS, True),
("update abc set x = 1 where y = 2", "all", True), ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True),
("update abc set x = 1", "moderate", True), ("update abc set x = 1", BASE_KEYWORDS, True),
("update abc set x = 1 where y = 2", "moderate", False), ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False),
("select x, y, z from abc", "all", False), ("select x, y, z from abc", ALL_KEYWORDS, False),
("drop abc", "all", True), ("drop abc", ALL_KEYWORDS, True),
("alter abc", "all", True), ("alter abc", ALL_KEYWORDS, True),
("delete abc", "all", True), ("delete abc", ALL_KEYWORDS, True),
("truncate abc", "all", True), ("truncate abc", ALL_KEYWORDS, True),
("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False),
("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False),
("insert into abc values (1, 2, 3)", ["insert"], True),
("insert into abc values (1, 2, 3)", ["insert"], True),
], ],
) )
def test_is_destructive(sql, warning_level, expected): def test_is_destructive(sql, keywords, expected):
assert is_destructive(sql, warning_level=warning_level) == expected assert is_destructive(sql, keywords) == expected
@pytest.mark.parametrize(
("warning_level", "expected"),
[
("true", ALL_KEYWORDS),
("false", []),
("all", ALL_KEYWORDS),
("moderate", BASE_KEYWORDS),
("off", []),
("", []),
(None, []),
(ALL_KEYWORDS, ALL_KEYWORDS),
(BASE_KEYWORDS, BASE_KEYWORDS),
("insert", ["insert"]),
("drop,alter,delete", ["drop", "alter", "delete"]),
(["drop", "alter", "delete"], ["drop", "alter", "delete"]),
],
)
def test_parse_destructive_warning(warning_level, expected):
assert parse_destructive_warning(warning_level) == expected
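The two helpers exercised by these parameter tables can be sketched as follows. This is an assumed reconstruction: the keyword lists mirror the config default, the legacy string values ("all", "moderate", "off", booleans) map onto those lists, and `is_destructive` is simplified to first-word matching where the real code walks sqlparse tokens:

```python
ALL_KEYWORDS = [
    "drop", "shutdown", "delete", "truncate",
    "alter", "update", "unconditional_update",
]
# "moderate" historically skipped plain UPDATEs but kept the rest.
BASE_KEYWORDS = [k for k in ALL_KEYWORDS if k != "update"]


def parse_destructive_warning(value):
    # Normalize the config value (None, bool-ish string, legacy level,
    # comma-separated string, or list) into a list of keywords.
    if not value:
        return []
    if isinstance(value, str):
        value = value.split(",")
    parsed = []
    for kw in value:
        kw = kw.strip()
        if kw in ("true", "all"):
            return ALL_KEYWORDS
        if kw == "moderate":
            return BASE_KEYWORDS
        if kw in ("false", "off"):
            return []
        parsed.append(kw)
    return parsed


def is_destructive(sql, keywords):
    # Simplified: inspect the statement's first word, plus a naive
    # WHERE check for the unconditional-update case.
    first = sql.strip().split()[0].lower()
    if first in keywords:
        return True
    return (
        "unconditional_update" in keywords
        and first == "update"
        and " where " not in sql.lower()
    )
```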
@ -216,7 +216,6 @@ def pset_pager_mocks():
with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
"pgcli.main.click.echo_via_pager" "pgcli.main.click.echo_via_pager"
) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
yield cli, mock_echo, mock_echo_via_pager, mock_app yield cli, mock_echo, mock_echo_via_pager, mock_app
@ -297,6 +296,22 @@ def test_i_works(tmpdir, executor):
run(executor, statement, pgspecial=cli.pgspecial) run(executor, statement, pgspecial=cli.pgspecial)
@dbtest
def test_echo_works(executor):
cli = PGCli(pgexecute=executor)
statement = r"\echo asdf"
result = run(executor, statement, pgspecial=cli.pgspecial)
assert result == ["asdf"]
@dbtest
def test_qecho_works(executor):
cli = PGCli(pgexecute=executor)
statement = r"\qecho asdf"
result = run(executor, statement, pgspecial=cli.pgspecial)
assert result == ["asdf"]
@dbtest
def test_watch_works(executor):
cli = PGCli(pgexecute=executor)
@@ -371,7 +386,6 @@ def test_quoted_db_uri(tmpdir):
def test_pg_service_file(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:

tests/test_pgcompleter.py (new file)

@@ -0,0 +1,76 @@
import pytest
from pgcli import pgcompleter
def test_load_alias_map_file_missing_file():
with pytest.raises(
pgcompleter.InvalidMapFile,
match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$",
):
pgcompleter.load_alias_map_file("/path/to/non-existent/file.json")
def test_load_alias_map_file_invalid_json(tmp_path):
fpath = tmp_path / "foo.json"
fpath.write_text("this is not valid json")
with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"):
pgcompleter.load_alias_map_file(str(fpath))
@pytest.mark.parametrize(
"table_name, alias",
[
("SomE_Table", "SET"),
("SOmeTabLe", "SOTL"),
("someTable", "T"),
],
)
def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias):
assert pgcompleter.generate_alias(table_name) == alias
@pytest.mark.parametrize(
"table_name, alias",
[
("some_tab_le", "stl"),
("s_ome_table", "sot"),
("sometable", "s"),
],
)
def test_generate_alias_uses_first_char_and_every_preceded_by_underscore(
table_name, alias
):
assert pgcompleter.generate_alias(table_name) == alias
@pytest.mark.parametrize(
"table_name, alias_map, alias",
[
("some_table", {"some_table": "my_alias"}, "my_alias"),
],
)
def test_generate_alias_can_use_alias_map(table_name, alias_map, alias):
assert pgcompleter.generate_alias(table_name, alias_map) == alias
@pytest.mark.parametrize(
"table_name, alias_map, alias",
[
("SomeTable", {"SomeTable": "my_alias"}, "my_alias"),
],
)
def test_generate_alias_prefers_alias_over_upper_case_name(
table_name, alias_map, alias
):
assert pgcompleter.generate_alias(table_name, alias_map) == alias
@pytest.mark.parametrize(
"table_name, alias",
[
("Some_tablE", "SE"),
("SomeTab_le", "ST"),
],
)
def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias):
assert pgcompleter.generate_alias(table_name) == alias
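Taken together, the parametrized tests in this new file describe a clear preference order for `generate_alias`: an explicit `alias_map` entry wins, then the name's upper-case letters, then the first character plus each character following an underscore. A sketch that satisfies all of the cases above (a reconstruction from the tests, not pgcli's actual code):

```python
def generate_alias(tbl, alias_map=None):
    """Derive a short alias for a table name.

    Preference order: explicit alias_map entry, then the name's upper-case
    letters, then the first character plus each character after an underscore.
    """
    if alias_map and tbl in alias_map:
        return alias_map[tbl]
    upper_letters = "".join(c for c in tbl if c.isupper())
    if upper_letters:
        return upper_letters
    # first char, plus every char directly preceded by an underscore
    return "".join(
        [tbl[0]] + [c for prev, c in zip(tbl, tbl[1:]) if prev == "_" and c != "_"]
    )
```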


@@ -304,9 +304,7 @@ def test_execute_from_commented_file_that_executes_another_file(
@dbtest
def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
# just some base cases that should work also
statement = "--comment\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
@@ -317,23 +315,43 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
assert result != None
assert result[1].find("now") >= 0
# https://github.com/dbcli/pgcli/issues/1362
statement = "--comment\n\\h"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = "--comment1\n--comment2\n\\h"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = "/*comment*/\n\h;"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = """/*comment1
comment2*/
\h"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = """/*comment1
comment2*/
/*comment 3
comment4*/
\\h"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = " /*comment*/\n\h;"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
@@ -352,6 +370,126 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = """\\h /*comment4 */"""
result = run(executor, statement, pgspecial=pgspecial)
print(result)
assert result != None
assert result[0].find("No help") >= 0
# TODO: we probably don't want to do this but sqlparse is not parsing things well
# we really want it to find help but right now, sqlparse isn't dropping the /*comment*/
# style comments after command
statement = """/*comment1*/
\h
/*comment4 */"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[0].find("No help") >= 0
# TODO: same for this one
statement = """/*comment1
comment3
comment2*/
\\h
/*comment4
comment5
comment6*/"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[0].find("No help") >= 0
@dbtest
def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
# https://github.com/dbcli/pgcli/issues/1403
# just some base cases that should work also
statement = "--comment\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("now") >= 0
statement = "/*comment*/\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[1].find("now") >= 0
# this simulates the original error (1403) without having to add/drop tables
# since it was just an error on reading input files and not the actual
# command itself
# test that the statement works
statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# test the statement with a \n in the middle
statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# test the statement with a newline in the middle
statement = """VALUES (1, 'one'),
(2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# now add a single comment line
statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# doing without special char \n
statement = """--comment
VALUES (1,'one'),
(2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# two comment lines
statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# doing without special char \n
statement = """--comment
--comment2
VALUES (1,'one'), (2, 'two'), (3, 'three');
"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# multiline comment + newline in middle of the statement
statement = """/*comment
comment2
comment3*/
VALUES (1,'one'),
(2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
# multiline comment + newline in middle of the statement
# + comments after the statement
statement = """/*comment
comment2
comment3*/
VALUES (1,'one'),
(2, 'two'), (3, 'three');
--comment4
--comment5"""
result = run(executor, statement, pgspecial=pgspecial)
assert result != None
assert result[5].find("three") >= 0
@dbtest
def test_multiple_queries_same_line(executor):
@@ -558,6 +696,7 @@ def test_view_definition(executor):
run(executor, "create view vw1 AS SELECT * FROM tbl1")
run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1")
result = executor.view_definition("vw1")
assert 'VIEW "public"."vw1" AS' in result
assert "FROM tbl1" in result
# import pytest; pytest.set_trace()
result = executor.view_definition("mvw1")


@@ -7,4 +7,11 @@ def test_confirm_destructive_query_notty():
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = "drop database foo;"
assert confirm_destructive_query(sql, [], None) is None
def test_confirm_destructive_query_with_alias():
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = "drop database foo;"
assert confirm_destructive_query(sql, ["drop"], "test") is None


@@ -1,10 +1,11 @@
[tox]
envlist = py38, py39, py310, py311, py312

[testenv]
deps = pytest>=2.7.0,<=3.0.7
mock>=1.0.1
behave>=1.2.4
pexpect==3.3
sshtunnel>=0.4.0
commands = py.test
behave tests/features
passenv = PGHOST