1
0
Fork 0

Merging upstream version 4.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 20:05:49 +01:00
parent bd17f43dd7
commit 73dcfce521
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
14 changed files with 456 additions and 36 deletions

View file

@ -67,10 +67,6 @@ jobs:
psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help' psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help'
- name: Install beta version of pendulum
run: pip install pendulum==3.0.0b1
if: matrix.python-version == '3.12'
- name: Install requirements - name: Install requirements
run: | run: |
pip install -U pip setuptools pip install -U pip setuptools
@ -89,7 +85,7 @@ jobs:
run: behave tests/features --no-capture run: behave tests/features --no-capture
- name: Check changelog for ReST compliance - name: Check changelog for ReST compliance
run: rst2html.py --halt=warning changelog.rst >/dev/null run: docutils --halt=warning changelog.rst >/dev/null
- name: Run Black - name: Run Black
run: black --check . run: black --check .

View file

@ -130,6 +130,10 @@ Contributors:
* blag * blag
* Rob Berry (rob-b) * Rob Berry (rob-b)
* Sharon Yogev (sharonyogev) * Sharon Yogev (sharonyogev)
* Hollis Wu (holi0317)
* Antonio Aguilar (crazybolillo)
* Andrew M. MacFie (amacfie)
* saucoide
Creator: Creator:
-------- --------

View file

@ -1,5 +1,26 @@
4.1.0 (2024-03-09)
================== ==================
4.0.1 (2023-11-30)
Features:
---------
* Support `PGAPPNAME` as an environment variable and `--application-name` as a command line argument.
* Add `verbose_errors` config and `\v` special command which enable the
displaying of all Postgres error fields received.
* Show Postgres notifications.
* Support sqlparse 0.5.x
* Add `--log-file [filename]` cli argument and `\log-file [filename]` special commands to
log to an external file in addition to the normal output
Bug fixes:
----------
* Fix display of "short host" in prompt (with `\h`) for IPv4 addresses ([issue 964](https://github.com/dbcli/pgcli/issues/964)).
* Fix backwards display of NOTICEs from a Function ([issue 1443](https://github.com/dbcli/pgcli/issues/1443))
* Fix psycopg errors when installing on Windows. ([issue 1413](https://https://github.com/dbcli/pgcli/issues/1413))
* Use a home-made function to display query duration instead of relying on a third-party library (the general behaviour does not change), which fixes the installation of `pgcli` on 32-bit architectures ([issue 1451](https://github.com/dbcli/pgcli/issues/1451))
==================
4.0.1 (2023-10-30)
================== ==================
Internal: Internal:
@ -7,7 +28,7 @@ Internal:
* Allow stable version of pendulum. * Allow stable version of pendulum.
================== ==================
4.0.0 (2023-11-27) 4.0.0 (2023-10-27)
================== ==================
Features: Features:

View file

@ -1 +1 @@
__version__ = "4.0.1" __version__ = "4.1.0"

View file

@ -11,9 +11,9 @@ import logging
import threading import threading
import shutil import shutil
import functools import functools
import pendulum
import datetime as dt import datetime as dt
import itertools import itertools
import pathlib
import platform import platform
from time import time, sleep from time import time, sleep
from typing import Optional from typing import Optional
@ -74,8 +74,9 @@ from urllib.parse import urlparse
from getpass import getuser from getpass import getuser
from psycopg import OperationalError, InterfaceError from psycopg import OperationalError, InterfaceError, Notify
from psycopg.conninfo import make_conninfo, conninfo_to_dict from psycopg.conninfo import make_conninfo, conninfo_to_dict
from psycopg.errors import Diagnostic
from collections import namedtuple from collections import namedtuple
@ -129,6 +130,15 @@ class PgCliQuitError(Exception):
pass pass
def notify_callback(notify: Notify):
click.secho(
'Notification received on channel "{}" (PID {}):\n{}'.format(
notify.channel, notify.pid, notify.payload
),
fg="green",
)
class PGCli: class PGCli:
default_prompt = "\\u@\\h:\\d> " default_prompt = "\\u@\\h:\\d> "
max_len_prompt = 30 max_len_prompt = 30
@ -165,6 +175,7 @@ class PGCli:
pgexecute=None, pgexecute=None,
pgclirc_file=None, pgclirc_file=None,
row_limit=None, row_limit=None,
application_name="pgcli",
single_connection=False, single_connection=False,
less_chatty=None, less_chatty=None,
prompt=None, prompt=None,
@ -172,6 +183,7 @@ class PGCli:
auto_vertical_output=False, auto_vertical_output=False,
warn=None, warn=None,
ssh_tunnel_url: Optional[str] = None, ssh_tunnel_url: Optional[str] = None,
log_file: Optional[str] = None,
): ):
self.force_passwd_prompt = force_passwd_prompt self.force_passwd_prompt = force_passwd_prompt
self.never_passwd_prompt = never_passwd_prompt self.never_passwd_prompt = never_passwd_prompt
@ -210,6 +222,8 @@ class PGCli:
else: else:
self.row_limit = c["main"].as_int("row_limit") self.row_limit = c["main"].as_int("row_limit")
self.application_name = application_name
# if not specified, set to DEFAULT_MAX_FIELD_WIDTH # if not specified, set to DEFAULT_MAX_FIELD_WIDTH
# if specified but empty, set to None to disable truncation # if specified but empty, set to None to disable truncation
# ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0 # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0
@ -237,6 +251,9 @@ class PGCli:
) )
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
self.verbose_errors = "verbose_errors" in c["main"] and c["main"].as_bool(
"verbose_errors"
)
self.null_string = c["main"].get("null_string", "<null>") self.null_string = c["main"].get("null_string", "<null>")
self.prompt_format = ( self.prompt_format = (
prompt prompt
@ -295,6 +312,11 @@ class PGCli:
self.ssh_tunnel_url = ssh_tunnel_url self.ssh_tunnel_url = ssh_tunnel_url
self.ssh_tunnel = None self.ssh_tunnel = None
if log_file:
with open(log_file, "a+"):
pass # ensure writeable
self.log_file = log_file
# formatter setup # formatter setup
self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
register_new_formatter(self.formatter) register_new_formatter(self.formatter)
@ -354,6 +376,12 @@ class PGCli:
"\\o [filename]", "\\o [filename]",
"Send all query results to file.", "Send all query results to file.",
) )
self.pgspecial.register(
self.write_to_logfile,
"\\log-file",
"\\log-file [filename]",
"Log all query results to a logfile, in addition to the normal output destination.",
)
self.pgspecial.register( self.pgspecial.register(
self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
) )
@ -378,6 +406,26 @@ class PGCli:
"Echo a string to the query output channel.", "Echo a string to the query output channel.",
) )
self.pgspecial.register(
self.toggle_verbose_errors,
"\\v",
"\\v [on|off]",
"Toggle verbose errors.",
)
def toggle_verbose_errors(self, pattern, **_):
flag = pattern.strip()
if flag == "on":
self.verbose_errors = True
elif flag == "off":
self.verbose_errors = False
else:
self.verbose_errors = not self.verbose_errors
message = "Verbose errors " + "on." if self.verbose_errors else "off."
return [(None, None, None, message)]
def echo(self, pattern, **_): def echo(self, pattern, **_):
return [(None, None, None, pattern)] return [(None, None, None, pattern)]
@ -473,6 +521,26 @@ class PGCli:
explain_mode=self.explain_mode, explain_mode=self.explain_mode,
) )
def write_to_logfile(self, pattern, **_):
if not pattern:
self.log_file = None
message = "Logfile capture disabled"
return [(None, None, None, message, "", True, True)]
log_file = pathlib.Path(pattern).expanduser().absolute()
try:
with open(log_file, "a+"):
pass # ensure writeable
except OSError as e:
self.log_file = None
message = str(e) + "\nLogfile capture disabled"
return [(None, None, None, message, "", False, True)]
self.log_file = str(log_file)
message = 'Writing to file "%s"' % self.log_file
return [(None, None, None, message, "", True, True)]
def write_to_file(self, pattern, **_): def write_to_file(self, pattern, **_):
if not pattern: if not pattern:
self.output_file = None self.output_file = None
@ -568,7 +636,7 @@ class PGCli:
if not database: if not database:
database = user database = user
kwargs.setdefault("application_name", "pgcli") kwargs.setdefault("application_name", self.application_name)
# If password prompt is not forced but no password is provided, try # If password prompt is not forced but no password is provided, try
# getting it from environment variable. # getting it from environment variable.
@ -658,7 +726,16 @@ class PGCli:
# prompt for a password (no -w flag), prompt for a passwd and try again. # prompt for a password (no -w flag), prompt for a passwd and try again.
try: try:
try: try:
pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) pgexecute = PGExecute(
database,
user,
passwd,
host,
port,
dsn,
notify_callback,
**kwargs,
)
except (OperationalError, InterfaceError) as e: except (OperationalError, InterfaceError) as e:
if should_ask_for_password(e): if should_ask_for_password(e):
passwd = click.prompt( passwd = click.prompt(
@ -668,7 +745,14 @@ class PGCli:
type=str, type=str,
) )
pgexecute = PGExecute( pgexecute = PGExecute(
database, user, passwd, host, port, dsn, **kwargs database,
user,
passwd,
host,
port,
dsn,
notify_callback,
**kwargs,
) )
else: else:
raise e raise e
@ -775,7 +859,7 @@ class PGCli:
else: else:
try: try:
if self.output_file and not text.startswith( if self.output_file and not text.startswith(
("\\o ", "\\? ", "\\echo ") ("\\o ", "\\log-file", "\\? ", "\\echo ")
): ):
try: try:
with open(self.output_file, "a", encoding="utf-8") as f: with open(self.output_file, "a", encoding="utf-8") as f:
@ -787,6 +871,23 @@ class PGCli:
else: else:
if output: if output:
self.echo_via_pager("\n".join(output)) self.echo_via_pager("\n".join(output))
# Log to file in addition to normal output
if (
self.log_file
and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo "))
and not text.strip() == ""
):
try:
with open(self.log_file, "a", encoding="utf-8") as f:
click.echo(
dt.datetime.now().isoformat(), file=f
) # timestamp log
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
except OSError as e:
click.secho(str(e), err=True, fg="red")
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
@ -797,9 +898,9 @@ class PGCli:
"Time: %0.03fs (%s), executed in: %0.03fs (%s)" "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
% ( % (
query.total_time, query.total_time,
pendulum.Duration(seconds=query.total_time).in_words(), duration_in_words(query.total_time),
query.execution_time, query.execution_time,
pendulum.Duration(seconds=query.execution_time).in_words(), duration_in_words(query.execution_time),
) )
) )
else: else:
@ -1053,7 +1154,7 @@ class PGCli:
res = self.pgexecute.run( res = self.pgexecute.run(
text, text,
self.pgspecial, self.pgspecial,
exception_formatter, lambda x: exception_formatter(x, self.verbose_errors),
on_error_resume, on_error_resume,
explain_mode=self.explain_mode, explain_mode=self.explain_mode,
) )
@ -1337,6 +1438,12 @@ class PGCli:
type=click.INT, type=click.INT,
help="Set threshold for row limit prompt. Use 0 to disable prompt.", help="Set threshold for row limit prompt. Use 0 to disable prompt.",
) )
@click.option(
"--application-name",
default="pgcli",
envvar="PGAPPNAME",
help="Application name for the connection.",
)
@click.option( @click.option(
"--less-chatty", "--less-chatty",
"less_chatty", "less_chatty",
@ -1371,6 +1478,11 @@ class PGCli:
default=None, default=None,
help="Open an SSH tunnel to the given address and connect to the database from it.", help="Open an SSH tunnel to the given address and connect to the database from it.",
) )
@click.option(
"--log-file",
default=None,
help="Write all queries & output into a file, in addition to the normal output destination.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli( def cli(
@ -1387,6 +1499,7 @@ def cli(
pgclirc, pgclirc,
dsn, dsn,
row_limit, row_limit,
application_name,
less_chatty, less_chatty,
prompt, prompt,
prompt_dsn, prompt_dsn,
@ -1395,6 +1508,7 @@ def cli(
list_dsn, list_dsn,
warn, warn,
ssh_tunnel: str, ssh_tunnel: str,
log_file: str,
): ):
if version: if version:
print("Version:", __version__) print("Version:", __version__)
@ -1445,6 +1559,7 @@ def cli(
never_prompt, never_prompt,
pgclirc_file=pgclirc, pgclirc_file=pgclirc,
row_limit=row_limit, row_limit=row_limit,
application_name=application_name,
single_connection=single_connection, single_connection=single_connection,
less_chatty=less_chatty, less_chatty=less_chatty,
prompt=prompt, prompt=prompt,
@ -1452,6 +1567,7 @@ def cli(
auto_vertical_output=auto_vertical_output, auto_vertical_output=auto_vertical_output,
warn=warn, warn=warn,
ssh_tunnel_url=ssh_tunnel, ssh_tunnel_url=ssh_tunnel,
log_file=log_file,
) )
# Choose which ever one has a valid value. # Choose which ever one has a valid value.
@ -1583,8 +1699,71 @@ def is_select(status):
return status.split(None, 1)[0].lower() == "select" return status.split(None, 1)[0].lower() == "select"
def exception_formatter(e): def diagnostic_output(diagnostic: Diagnostic) -> str:
return click.style(str(e), fg="red") fields = []
if diagnostic.severity is not None:
fields.append("Severity: " + diagnostic.severity)
if diagnostic.severity_nonlocalized is not None:
fields.append("Severity (non-localized): " + diagnostic.severity_nonlocalized)
if diagnostic.sqlstate is not None:
fields.append("SQLSTATE code: " + diagnostic.sqlstate)
if diagnostic.message_primary is not None:
fields.append("Message: " + diagnostic.message_primary)
if diagnostic.message_detail is not None:
fields.append("Detail: " + diagnostic.message_detail)
if diagnostic.message_hint is not None:
fields.append("Hint: " + diagnostic.message_hint)
if diagnostic.statement_position is not None:
fields.append("Position: " + diagnostic.statement_position)
if diagnostic.internal_position is not None:
fields.append("Internal position: " + diagnostic.internal_position)
if diagnostic.internal_query is not None:
fields.append("Internal query: " + diagnostic.internal_query)
if diagnostic.context is not None:
fields.append("Where: " + diagnostic.context)
if diagnostic.schema_name is not None:
fields.append("Schema name: " + diagnostic.schema_name)
if diagnostic.table_name is not None:
fields.append("Table name: " + diagnostic.table_name)
if diagnostic.column_name is not None:
fields.append("Column name: " + diagnostic.column_name)
if diagnostic.datatype_name is not None:
fields.append("Data type name: " + diagnostic.datatype_name)
if diagnostic.constraint_name is not None:
fields.append("Constraint name: " + diagnostic.constraint_name)
if diagnostic.source_file is not None:
fields.append("File: " + diagnostic.source_file)
if diagnostic.source_line is not None:
fields.append("Line: " + diagnostic.source_line)
if diagnostic.source_function is not None:
fields.append("Routine: " + diagnostic.source_function)
return "\n".join(fields)
def exception_formatter(e, verbose_errors: bool = False):
s = str(e)
if verbose_errors:
s += "\n" + diagnostic_output(e.diag)
return click.style(s, fg="red")
def format_output(title, cur, headers, status, settings, explain_mode=False): def format_output(title, cur, headers, status, settings, explain_mode=False):
@ -1724,5 +1903,28 @@ def parse_service_info(service):
return service_conf, service_file return service_conf, service_file
def duration_in_words(duration_in_seconds: float) -> str:
if not duration_in_seconds:
return "0 seconds"
components = []
hours, remainder = divmod(duration_in_seconds, 3600)
if hours > 1:
components.append(f"{hours} hours")
elif hours == 1:
components.append("1 hour")
minutes, seconds = divmod(remainder, 60)
if minutes > 1:
components.append(f"{minutes} minutes")
elif minutes == 1:
components.append("1 minute")
if seconds >= 2:
components.append(f"{int(seconds)} seconds")
elif seconds >= 1:
components.append("1 second")
elif seconds:
components.append(f"{round(seconds, 3)} second")
return " ".join(components)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

View file

@ -33,10 +33,11 @@ multi_line_mode = psql
# "unconditional_update" will warn you of update statements that don't have a where clause # "unconditional_update" will warn you of update statements that don't have a where clause
destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update
# Destructive warning can restart the connection if this is enabled and the # When `destructive_warning` is on and the user declines to proceed with a
# user declines. This means that any current uncommitted transaction can be # destructive statement, the current transaction (if any) is left untouched,
# aborted if the user doesn't want to proceed with a destructive_warning # by default. When setting `destructive_warning_restarts_connection` to
# statement. # "True", the connection to the server is restarted. In that case, the
# transaction (if any) is rolled back.
destructive_warning_restarts_connection = False destructive_warning_restarts_connection = False
# When this option is on (and if `destructive_warning` is not empty), # When this option is on (and if `destructive_warning` is not empty),
@ -155,6 +156,11 @@ max_field_width = 500
# Skip intro on startup and goodbye on exit # Skip intro on startup and goodbye on exit
less_chatty = False less_chatty = False
# Show all Postgres error fields (as listed in
# https://www.postgresql.org/docs/current/protocol-error-fields.html).
# Can be toggled with \v.
verbose_errors = False
# Postgres prompt # Postgres prompt
# \t - Current date and time # \t - Current date and time
# \u - Username # \u - Username

View file

@ -1,3 +1,4 @@
import ipaddress
import logging import logging
import traceback import traceback
from collections import namedtuple from collections import namedtuple
@ -166,6 +167,7 @@ class PGExecute:
host=None, host=None,
port=None, port=None,
dsn=None, dsn=None,
notify_callback=None,
**kwargs, **kwargs,
): ):
self._conn_params = {} self._conn_params = {}
@ -178,6 +180,7 @@ class PGExecute:
self.port = None self.port = None
self.server_version = None self.server_version = None
self.extra_args = None self.extra_args = None
self.notify_callback = notify_callback
self.connect(database, user, password, host, port, dsn, **kwargs) self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None self.reset_expanded = None
@ -236,6 +239,9 @@ class PGExecute:
self.conn = conn self.conn = conn
self.conn.autocommit = True self.conn.autocommit = True
if self.notify_callback is not None:
self.conn.add_notify_handler(self.notify_callback)
# When we connect using a DSN, we don't really know what db, # When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it. # user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664. # Note: moved this after setting autocommit because of #664.
@ -273,6 +279,11 @@ class PGExecute:
@property @property
def short_host(self): def short_host(self):
try:
ipaddress.ip_address(self.host)
return self.host
except ValueError:
pass
if "," in self.host: if "," in self.host:
host, _, _ = self.host.partition(",") host, _, _ = self.host.partition(",")
else: else:
@ -431,7 +442,11 @@ class PGExecute:
def handle_notices(n): def handle_notices(n):
nonlocal title nonlocal title
title = f"{n.message_primary}\n{n.message_detail}\n{title}" title = f"{title}"
if n.message_primary is not None:
title = f"{title}\n{n.message_primary}"
if n.message_detail is not None:
title = f"{title}\n{n.message_detail}"
self.conn.add_notice_handler(handle_notices) self.conn.add_notice_handler(handle_notices)

View file

@ -12,10 +12,10 @@ install_requirements = [
# We still need to use pt-2 unless pt-3 released on Fedora32 # We still need to use pt-2 unless pt-3 released on Fedora32
# see: https://github.com/dbcli/pgcli/pull/1197 # see: https://github.com/dbcli/pgcli/pull/1197
"prompt_toolkit>=2.0.6,<4.0.0", "prompt_toolkit>=2.0.6,<4.0.0",
"psycopg >= 3.0.14", "psycopg >= 3.0.14; sys_platform != 'win32'",
"sqlparse >=0.3.0,<0.5", "psycopg-binary >= 3.0.14; sys_platform == 'win32'",
"sqlparse >=0.3.0,<0.6",
"configobj >= 5.0.6", "configobj >= 5.0.6",
"pendulum>=2.1.0",
"cli_helpers[styles] >= 2.2.1", "cli_helpers[styles] >= 2.2.1",
] ]
@ -27,11 +27,6 @@ install_requirements = [
if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"): if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"):
install_requirements.append("setproctitle >= 1.1.9") install_requirements.append("setproctitle >= 1.1.9")
# Windows will require the binary psycopg to run pgcli
if platform.system() == "Windows":
install_requirements.append("psycopg-binary >= 3.0.14")
setup( setup(
name="pgcli", name="pgcli",
author="Pgcli Core Team", author="Pgcli Core Team",

View file

@ -9,6 +9,7 @@ from utils import (
db_connection, db_connection,
drop_tables, drop_tables,
) )
import pgcli.main
import pgcli.pgexecute import pgcli.pgexecute
@ -37,6 +38,7 @@ def executor(connection):
password=POSTGRES_PASSWORD, password=POSTGRES_PASSWORD,
port=POSTGRES_PORT, port=POSTGRES_PORT,
dsn=None, dsn=None,
notify_callback=pgcli.main.notify_callback,
) )

View file

@ -3,6 +3,7 @@ Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file. This string is used to call the step in "*.feature" file.
""" """
import pexpect import pexpect
from behave import when, then from behave import when, then

View file

@ -0,0 +1,17 @@
from unittest.mock import patch
from click.testing import CliRunner
from pgcli.main import cli
from pgcli.pgexecute import PGExecute
def test_application_name_in_env():
runner = CliRunner()
app_name = "wonderful_app"
with patch.object(PGExecute, "__init__") as mock_pgxecute:
runner.invoke(
cli, ["127.0.0.1:5432/hello", "user"], env={"PGAPPNAME": app_name}
)
kwargs = mock_pgxecute.call_args.kwargs
assert kwargs.get("application_name") == app_name

View file

@ -1,5 +1,8 @@
import os import os
import platform import platform
import re
import tempfile
import datetime
from unittest import mock from unittest import mock
import pytest import pytest
@ -11,7 +14,9 @@ except ImportError:
from pgcli.main import ( from pgcli.main import (
obfuscate_process_password, obfuscate_process_password,
duration_in_words,
format_output, format_output,
notify_callback,
PGCli, PGCli,
OutputSettings, OutputSettings,
COLOR_CODE_REGEX, COLOR_CODE_REGEX,
@ -296,6 +301,24 @@ def test_i_works(tmpdir, executor):
run(executor, statement, pgspecial=cli.pgspecial) run(executor, statement, pgspecial=cli.pgspecial)
@dbtest
def test_toggle_verbose_errors(executor):
cli = PGCli(pgexecute=executor)
cli._evaluate_command("\\v on")
assert cli.verbose_errors
output, _ = cli._evaluate_command("SELECT 1/0")
assert "SQLSTATE" in output[0]
cli._evaluate_command("\\v off")
assert not cli.verbose_errors
output, _ = cli._evaluate_command("SELECT 1/0")
assert "SQLSTATE" not in output[0]
cli._evaluate_command("\\v")
assert cli.verbose_errors
@dbtest @dbtest
def test_echo_works(executor): def test_echo_works(executor):
cli = PGCli(pgexecute=executor) cli = PGCli(pgexecute=executor)
@ -312,6 +335,34 @@ def test_qecho_works(executor):
assert result == ["asdf"] assert result == ["asdf"]
@dbtest
def test_logfile_works(executor):
with tempfile.TemporaryDirectory() as tmpdir:
log_file = f"{tmpdir}/tempfile.log"
cli = PGCli(pgexecute=executor, log_file=log_file)
statement = r"\qecho hello!"
cli.execute_command(statement)
with open(log_file, "r") as f:
log_contents = f.readlines()
assert datetime.datetime.fromisoformat(log_contents[0].strip())
assert log_contents[1].strip() == r"\qecho hello!"
assert log_contents[2].strip() == "hello!"
@dbtest
def test_logfile_unwriteable_file(executor):
cli = PGCli(pgexecute=executor)
statement = r"\log-file forbidden.log"
with mock.patch("builtins.open") as mock_open:
mock_open.side_effect = PermissionError(
"[Errno 13] Permission denied: 'forbidden.log'"
)
result = run(executor, statement, pgspecial=cli.pgspecial)
assert result == [
"[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled"
]
@dbtest @dbtest
def test_watch_works(executor): def test_watch_works(executor):
cli = PGCli(pgexecute=executor) cli = PGCli(pgexecute=executor)
@ -431,6 +482,7 @@ def test_pg_service_file(tmpdir):
"b_host", "b_host",
"5435", "5435",
"", "",
notify_callback,
application_name="pgcli", application_name="pgcli",
) )
del os.environ["PGPASSWORD"] del os.environ["PGPASSWORD"]
@ -486,5 +538,50 @@ def test_application_name_db_uri(tmpdir):
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar@baz.com/?application_name=cow") cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
mock_pgexecute.assert_called_with( mock_pgexecute.assert_called_with(
"bar", "bar", "", "baz.com", "", "", application_name="cow" "bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow"
) )
@pytest.mark.parametrize(
"duration_in_seconds,words",
[
(0, "0 seconds"),
(0.0009, "0.001 second"),
(0.0005, "0.001 second"),
(0.0004, "0.0 second"), # not perfect, but will do
(0.2, "0.2 second"),
(1, "1 second"),
(1.4, "1 second"),
(2, "2 seconds"),
(3.4, "3 seconds"),
(60, "1 minute"),
(61, "1 minute 1 second"),
(123, "2 minutes 3 seconds"),
(3600, "1 hour"),
(7235, "2 hours 35 seconds"),
(9005, "2 hours 30 minutes 5 seconds"),
(86401, "24 hours 1 second"),
],
)
def test_duration_in_words(duration_in_seconds, words):
assert duration_in_words(duration_in_seconds) == words
@dbtest
def test_notifications(executor):
run(executor, "listen chan1")
with mock.patch("pgcli.main.click.secho") as mock_secho:
run(executor, "notify chan1, 'testing1'")
mock_secho.assert_called()
arg = mock_secho.call_args_list[0].args[0]
assert re.match(
r'Notification received on channel "chan1" \(PID \d+\):\ntesting1',
arg,
)
run(executor, "unlisten chan1")
with mock.patch("pgcli.main.click.secho") as mock_secho:
run(executor, "notify chan1, 'testing2'")
mock_secho.assert_not_called()

View file

@ -1,3 +1,4 @@
import re
from textwrap import dedent from textwrap import dedent
import psycopg import psycopg
@ -6,7 +7,7 @@ from unittest.mock import patch, MagicMock
from pgspecial.main import PGSpecial, NO_QUERY from pgspecial.main import PGSpecial, NO_QUERY
from utils import run, dbtest, requires_json, requires_jsonb from utils import run, dbtest, requires_json, requires_jsonb
from pgcli.main import PGCli from pgcli.main import PGCli, exception_formatter as main_exception_formatter
from pgcli.packages.parseutils.meta import FunctionMetadata from pgcli.packages.parseutils.meta import FunctionMetadata
@ -219,8 +220,33 @@ def test_database_list(executor):
@dbtest @dbtest
def test_invalid_syntax(executor, exception_formatter): def test_invalid_syntax(executor, exception_formatter):
result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) result = run(
executor,
"invalid syntax!",
exception_formatter=lambda x: main_exception_formatter(x, verbose_errors=False),
)
assert 'syntax error at or near "invalid"' in result[0] assert 'syntax error at or near "invalid"' in result[0]
assert "SQLSTATE" not in result[0]
@dbtest
def test_invalid_syntax_verbose(executor):
result = run(
executor,
"invalid syntax!",
exception_formatter=lambda x: main_exception_formatter(x, verbose_errors=True),
)
fields = r"""
Severity: ERROR
Severity \(non-localized\): ERROR
SQLSTATE code: 42601
Message: syntax error at or near "invalid"
Position: 1
File: scan\.l
Line: \d+
Routine: scanner_yyerror
""".strip()
assert re.search(fields, result[0])
@dbtest @dbtest
@ -690,6 +716,38 @@ def test_function_definition(executor):
result = executor.function_definition("the_number_three") result = executor.function_definition("the_number_three")
@dbtest
def test_function_notice_order(executor):
run(
executor,
"""
CREATE OR REPLACE FUNCTION demo_order() RETURNS VOID AS
$$
BEGIN
RAISE NOTICE 'first';
RAISE NOTICE 'second';
RAISE NOTICE 'third';
RAISE NOTICE 'fourth';
RAISE NOTICE 'fifth';
RAISE NOTICE 'sixth';
END;
$$
LANGUAGE plpgsql;
""",
)
executor.function_definition("demo_order")
result = run(executor, "select demo_order()")
assert "first\nsecond\nthird\nfourth\nfifth\nsixth" in result[0]
assert "+------------+" in result[1]
assert "| demo_order |" in result[2]
assert "|------------|" in result[3]
assert "| |" in result[4]
assert "+------------+" in result[5]
assert "SELECT 1" in result[6]
@dbtest @dbtest
def test_view_definition(executor): def test_view_definition(executor):
run(executor, "create table tbl1 (a text, b numeric)") run(executor, "create table tbl1 (a text, b numeric)")
@ -721,6 +779,10 @@ def test_short_host(executor):
executor, "host", "localhost1.example.org,localhost2.example.org" executor, "host", "localhost1.example.org,localhost2.example.org"
): ):
assert executor.short_host == "localhost1" assert executor.short_host == "localhost1"
with patch.object(executor, "host", "ec2-11-222-333-444.compute-1.amazonaws.com"):
assert executor.short_host == "ec2-11-222-333-444"
with patch.object(executor, "host", "1.2.3.4"):
assert executor.short_host == "1.2.3.4"
class VirtualCursor: class VirtualCursor:

View file

@ -6,7 +6,7 @@ from configobj import ConfigObj
from click.testing import CliRunner from click.testing import CliRunner
from sshtunnel import SSHTunnelForwarder from sshtunnel import SSHTunnelForwarder
from pgcli.main import cli, PGCli from pgcli.main import cli, notify_callback, PGCli
from pgcli.pgexecute import PGExecute from pgcli.pgexecute import PGExecute
@ -61,6 +61,7 @@ def test_ssh_tunnel(
"127.0.0.1", "127.0.0.1",
pgcli.ssh_tunnel.local_bind_ports[0], pgcli.ssh_tunnel.local_bind_ports[0],
"", "",
notify_callback,
) )
mock_ssh_tunnel_forwarder.reset_mock() mock_ssh_tunnel_forwarder.reset_mock()
mock_pgexecute.reset_mock() mock_pgexecute.reset_mock()
@ -96,6 +97,7 @@ def test_ssh_tunnel(
"127.0.0.1", "127.0.0.1",
pgcli.ssh_tunnel.local_bind_ports[0], pgcli.ssh_tunnel.local_bind_ports[0],
"", "",
notify_callback,
) )
mock_ssh_tunnel_forwarder.reset_mock() mock_ssh_tunnel_forwarder.reset_mock()
mock_pgexecute.reset_mock() mock_pgexecute.reset_mock()