1
0
Fork 0

Merging upstream version 3.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 19:58:08 +01:00
parent a868bb3d29
commit 39b7cc8559
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
50 changed files with 952 additions and 634 deletions

66
.github/workflows/ci.yml vendored Normal file
View file

@ -0,0 +1,66 @@
name: pgcli
on:
pull_request:
paths-ignore:
- '**.rst'
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
services:
postgres:
image: postgres:9.6
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install requirements
run: |
pip install -U pip setuptools
pip install --no-cache-dir .
pip install -r requirements-dev.txt
pip install keyrings.alt>=3.1
- name: Run unit tests
run: coverage run --source pgcli -m py.test
- name: Run integration tests
env:
PGUSER: postgres
PGPASSWORD: postgres
run: behave tests/features --no-capture
- name: Check changelog for ReST compliance
run: rst2html.py --halt=warning changelog.rst >/dev/null
- name: Run Black
run: pip install black && black --check .
if: matrix.python-version == '3.6'
- name: Coverage
run: |
coverage combine
coverage report
codecov

View file

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: stable
rev: 21.5b0
hooks:
- id: black
language_version: python3.7

View file

@ -1,51 +0,0 @@
dist: xenial
sudo: required
language: python
python:
- "3.6"
- "3.7"
- "3.8"
- "3.9-dev"
before_install:
- which python
- which pip
- pip install -U setuptools
install:
- pip install --no-cache-dir .
- pip install -r requirements-dev.txt
- pip install keyrings.alt>=3.1
script:
- set -e
- coverage run --source pgcli -m py.test
- cd tests
- behave --no-capture
- cd ..
# check for changelog ReST compliance
- rst2html.py --halt=warning changelog.rst >/dev/null
# check for black code compliance, 3.6 only
- if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi
- set +e
after_success:
- coverage combine
- codecov
notifications:
webhooks:
urls:
- YOUR_WEBHOOK_URL
on_success: change # options: [always|never|change] default: always
on_failure: always # options: [always|never|change] default: always
on_start: false # default: false
services:
- postgresql
addons:
postgresql: "9.6"

View file

@ -114,6 +114,10 @@ Contributors:
* Tom Caruso (tomplex)
* Jan Brun Rasmussen (janbrunrasmussen)
* Kevin Marsh (kevinmarsh)
* Eero Ruohola (ruohola)
* Miroslav Šedivý (eumiro)
* Eric R Young (ERYoung11)
* Paweł Sacawa (psacawa)
Creator:
--------

View file

@ -170,7 +170,7 @@ Troubleshooting the integration tests
- Make sure postgres instance on localhost is running
- Check your ``pg_hba.conf`` file to verify local connections are enabled
- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
- Contact us on `gitter <https://gitter.im/dbcli/pgcli/>`_ or `file an issue <https://github.com/dbcli/pgcli/issues/new>`_.
- `File an issue <https://github.com/dbcli/pgcli/issues/new>`_.
Coding Style
------------

View file

@ -1,7 +1,7 @@
A REPL for Postgres
-------------------
|Build Status| |CodeCov| |PyPI| |Landscape| |Gitter|
|Build Status| |CodeCov| |PyPI| |Landscape|
This is a postgres client that does auto-completion and syntax highlighting.
@ -62,32 +62,32 @@ For more details:
Usage: pgcli [OPTIONS] [DBNAME] [USERNAME]
Options:
-h, --host TEXT Host address of the postgres database.
-p, --port INTEGER Port number at which the postgres instance is
listening.
-U, --username TEXT Username to connect to the postgres database.
-u, --user TEXT Username to connect to the postgres database.
-W, --password Force password prompt.
-w, --no-password Never prompt for password.
--single-connection Do not use a separate connection for completions.
-v, --version Version of pgcli.
-d, --dbname TEXT database name to connect to.
--pgclirc PATH Location of pgclirc file.
-D, --dsn TEXT Use DSN configured into the [alias_dsn] section of
pgclirc file.
--list-dsn list of DSN configured into the [alias_dsn] section
of pgclirc file.
--row-limit INTEGER Set threshold for row limit prompt. Use 0 to disable
prompt.
--less-chatty Skip intro on startup and goodbye on exit.
--prompt TEXT Prompt format (Default: "\u@\h:\d> ").
--prompt-dsn TEXT Prompt format for connections using DSN aliases
(Default: "\u@\h:\d> ").
-l, --list list available databases, then exit.
--auto-vertical-output Automatically switch to vertical output mode if the
result is wider than the terminal width.
--warn / --no-warn Warn before running a destructive query.
--help Show this message and exit.
-h, --host TEXT Host address of the postgres database.
-p, --port INTEGER Port number at which the postgres instance is
listening.
-U, --username TEXT Username to connect to the postgres database.
-u, --user TEXT Username to connect to the postgres database.
-W, --password Force password prompt.
-w, --no-password Never prompt for password.
--single-connection Do not use a separate connection for completions.
-v, --version Version of pgcli.
-d, --dbname TEXT database name to connect to.
--pgclirc FILE Location of pgclirc file.
-D, --dsn TEXT Use DSN configured into the [alias_dsn] section
of pgclirc file.
--list-dsn list of DSN configured into the [alias_dsn]
section of pgclirc file.
--row-limit INTEGER Set threshold for row limit prompt. Use 0 to
disable prompt.
--less-chatty Skip intro on startup and goodbye on exit.
--prompt TEXT Prompt format (Default: "\u@\h:\d> ").
--prompt-dsn TEXT Prompt format for connections using DSN aliases
(Default: "\u@\h:\d> ").
-l, --list list available databases, then exit.
--auto-vertical-output Automatically switch to vertical output mode if
the result is wider than the terminal width.
--warn [all|moderate|off] Warn before running a destructive query.
--help Show this message and exit.
``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``).
@ -352,8 +352,8 @@ interface to Postgres database.
Thanks to all the beta testers and contributors for your time and patience. :)
.. |Build Status| image:: https://api.travis-ci.org/dbcli/pgcli.svg?branch=master
:target: https://travis-ci.org/dbcli/pgcli
.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg
:target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli
.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
:target: https://codecov.io/gh/dbcli/pgcli
@ -366,7 +366,3 @@ Thanks to all the beta testers and contributors for your time and patience. :)
.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
:target: https://pypi.python.org/pypi/pgcli/
:alt: Latest Version
.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg
:target: https://gitter.im/dbcli/pgcli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
:alt: Gitter Chat

81
Vagrantfile vendored
View file

@ -1,5 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
#
#
Vagrant.configure(2) do |config|
@ -9,20 +11,23 @@ Vagrant.configure(2) do |config|
pgcli_description = "Postgres CLI with autocompletion and syntax highlighting"
config.vm.define "debian" do |debian|
debian.vm.box = "chef/debian-7.8"
debian.vm.box = "bento/debian-10.8"
debian.vm.provision "shell", inline: <<-SHELL
echo "-> Building DEB on `lsb_release -s`"
echo "-> Building DEB on `lsb_release -d`"
sudo apt-get update
sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems
sudo easy_install pip
sudo pip install virtualenv virtualenv-tools
sudo apt install -y python3-pip
sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3
sudo apt-get install -y ruby-dev
sudo apt-get install -y git
sudo apt-get install -y rpm librpmbuild8
sudo gem install fpm
echo "-> Cleaning up old workspace"
rm -rf build
sudo rm -rf build
mkdir -p build/usr/share
virtualenv build/usr/share/pgcli
build/usr/share/pgcli/bin/pip install -U pip distribute
build/usr/share/pgcli/bin/pip uninstall -y distribute
build/usr/share/pgcli/bin/pip install /pgcli
echo "-> Cleaning Virtualenv"
@ -45,24 +50,59 @@ Vagrant.configure(2) do |config|
--url https://github.com/dbcli/pgcli \
--description "#{pgcli_description}" \
--license 'BSD'
SHELL
end
# This is considerably more messy than the debian section. I had to go off-standard to update
# some packages to get this to work.
config.vm.define "centos" do |centos|
centos.vm.box = "chef/centos-7.0"
centos.vm.box = "bento/centos-7.9"
centos.vm.box_version = "202012.21.0"
centos.vm.provision "shell", inline: <<-SHELL
#!/bin/bash
echo "-> Building RPM on `lsb_release -s`"
sudo yum install -y rpm-build gcc ruby-devel postgresql-devel python-devel rubygems
sudo easy_install pip
sudo pip install virtualenv virtualenv-tools
sudo gem install fpm
echo "-> Building RPM on `hostnamectl | grep "Operating System"`"
export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin
echo "PATH -> " $PATH
#####
### get base updates
sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel
######
### install FPM, which we need to install to get an up-to-date version of ruby, which we need for git
echo "-> Get FPM installed"
# import the necessary GPG keys
gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
# install RVM
sudo curl -sSL https://get.rvm.io | sudo bash -s stable
sudo usermod -aG rvm vagrant
sudo usermod -aG rvm root
sudo /usr/local/rvm/bin/rvm alias create default 2.6.3
source /etc/profile.d/rvm.sh
# install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git.
sudo yum install -y ruby-devel
sudo /usr/local/rvm/bin/rvm install 2.6.3
#
# yes,this gives an error about generating doc but we don't need the doc.
/usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm
######
sudo pip3 install virtualenv virtualenv-tools3
echo "-> Cleaning up old workspace"
rm -rf build
mkdir -p build/usr/share
virtualenv build/usr/share/pgcli
build/usr/share/pgcli/bin/pip install -U pip distribute
build/usr/share/pgcli/bin/pip uninstall -y distribute
build/usr/share/pgcli/bin/pip install /pgcli
echo "-> Cleaning Virtualenv"
@ -74,9 +114,9 @@ Vagrant.configure(2) do |config|
find build -iname '*.pyc' -delete
find build -iname '*.pyo' -delete
cd /home/vagrant
echo "-> Creating PgCLI RPM"
echo $PATH
sudo /usr/local/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
/usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
-a all \
-d postgresql-devel \
-d python-devel \
@ -86,8 +126,13 @@ Vagrant.configure(2) do |config|
--url https://github.com/dbcli/pgcli \
--description "#{pgcli_description}" \
--license 'BSD'
SHELL
SHELL
end
end

View file

@ -1,3 +1,37 @@
TBD
=====
Features:
---------
Bug fixes:
----------
3.2.0
=====
Release date: 2021/08/23
Features:
---------
* Consider `update` queries destructive and issue a warning. Change
`destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239)
* Skip initial comment in .pg_session even if it doesn't start with '#'
* Include functions from schemas in search_path. (`Amjith Ramanujam`_)
Bug fixes:
----------
* Fix issue where `syntax_style` config value would not have any effect. (#1212)
* Fix crash because of not found `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6
* Fix comments being lost in config when saving a named query. (#1240)
* Fix IPython magic for ipython-sql >= 0.4.0
* Fix pager not being used when output format is set to csv. (#1238)
* Add function literals random, generate_series, generate_subscripts
* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly
* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
3.1.0
=====

View file

@ -1 +1 @@
__version__ = "3.1.0"
__version__ = "3.2.0"

View file

@ -3,10 +3,9 @@ import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
from .pgexecute import PGExecute
class CompletionRefresher(object):
class CompletionRefresher:
refreshers = OrderedDict()
@ -27,6 +26,10 @@ class CompletionRefresher(object):
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
if executor.is_virtual_database():
# do nothing
return [(None, None, None, "Auto-completion refresh can't be started.")]
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
@ -141,7 +144,7 @@ def refresh_casing(completer, executor):
with open(casing_file, "w") as f:
f.write(casing_prefs)
if os.path.isfile(casing_file):
with open(casing_file, "r") as f:
with open(casing_file) as f:
completer.extend_casing([line.strip() for line in f])

View file

@ -3,6 +3,8 @@ import shutil
import os
import platform
from os.path import expanduser, exists, dirname
import re
from typing import TextIO
from configobj import ConfigObj
@ -16,11 +18,15 @@ def config_location():
def load_config(usr_cfg, def_cfg=None):
cfg = ConfigObj()
cfg.merge(ConfigObj(def_cfg, interpolation=False))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
# avoid config merges when possible. For writing, we need an umerged config instance.
# see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171
if def_cfg:
cfg = ConfigObj()
cfg.merge(ConfigObj(def_cfg, interpolation=False))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
else:
cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")
cfg.filename = expanduser(usr_cfg)
return cfg
@ -44,12 +50,16 @@ def upgrade_config(config, def_config):
cfg.write()
def get_config_filename(pgclirc_file=None):
return pgclirc_file or "%sconfig" % config_location()
def get_config(pgclirc_file=None):
from pgcli import __file__ as package_root
package_root = os.path.dirname(package_root)
pgclirc_file = pgclirc_file or "%sconfig" % config_location()
pgclirc_file = get_config_filename(pgclirc_file)
default_config = os.path.join(package_root, "pgclirc")
write_default_config(default_config, pgclirc_file)
@ -62,3 +72,28 @@ def get_casing_file(config):
if casing_file == "default":
casing_file = config_location() + "casing"
return casing_file
def skip_initial_comment(f_stream: TextIO) -> int:
"""
Initial comment in ~/.pg_service.conf is not always marked with '#'
which crashes the parser. This function takes a file object and
"rewinds" it to the beginning of the first section,
from where on it can be parsed safely
:return: number of skipped lines
"""
section_regex = r"\s*\["
pos = f_stream.tell()
lines_skipped = 0
while True:
line = f_stream.readline()
if line == "":
break
if re.match(section_regex, line) is not None:
f_stream.seek(pos)
break
else:
pos += len(line)
lines_skipped += 1
return lines_skipped

View file

@ -25,7 +25,11 @@ def pgcli_line_magic(line):
if hasattr(sql.connection.Connection, "get"):
conn = sql.connection.Connection.get(parsed["connection"])
else:
conn = sql.connection.Connection.set(parsed["connection"])
try:
conn = sql.connection.Connection.set(parsed["connection"])
# a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql
except TypeError:
conn = sql.connection.Connection.set(parsed["connection"], False)
try:
# A corresponding pgcli object already exists
@ -43,7 +47,7 @@ def pgcli_line_magic(line):
conn._pgcli = pgcli
# For convenience, print the connection alias
print("Connected: {}".format(conn.name))
print(f"Connected: {conn.name}")
try:
pgcli.run_cli()

View file

@ -2,8 +2,9 @@ import platform
import warnings
from os.path import expanduser
from configobj import ConfigObj
from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
from .config import skip_initial_comment
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
@ -20,12 +21,12 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
from codecs import open
keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
from cli_helpers.utils import strip_ansi
import click
try:
@ -62,6 +63,7 @@ from .config import (
config_location,
ensure_dir_exists,
get_config,
get_config_filename,
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
@ -122,7 +124,7 @@ class PgCliQuitError(Exception):
pass
class PGCli(object):
class PGCli:
default_prompt = "\\u@\\h:\\d> "
max_len_prompt = 30
@ -175,7 +177,11 @@ class PGCli(object):
# Load config.
c = self.config = get_config(pgclirc_file)
NamedQueries.instance = NamedQueries.from_config(self.config)
# at this point, config should be written to pgclirc_file if it did not exist. Read it.
self.config_writer = load_config(get_config_filename(pgclirc_file))
# make sure to use self.config_writer, not self.config
NamedQueries.instance = NamedQueries.from_config(self.config_writer)
self.logger = logging.getLogger(__name__)
self.initialize_logging()
@ -201,8 +207,11 @@ class PGCli(object):
self.syntax_style = c["main"]["syntax_style"]
self.cli_style = c["colors"]
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
c_dest_warning = c["main"].as_bool("destructive_warning")
self.destructive_warning = c_dest_warning if warn is None else warn
self.destructive_warning = warn or c["main"]["destructive_warning"]
# also handle boolean format of destructive warning
self.destructive_warning = {"true": "all", "false": "off"}.get(
self.destructive_warning.lower(), self.destructive_warning
)
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
self.null_string = c["main"].get("null_string", "<null>")
self.prompt_format = (
@ -325,11 +334,11 @@ class PGCli(object):
if pattern not in TabularOutputFormatter().supported_formats:
raise ValueError()
self.table_format = pattern
yield (None, None, None, "Changed table format to {}".format(pattern))
yield (None, None, None, f"Changed table format to {pattern}")
except ValueError:
msg = "Table format {} not recognized. Allowed formats:".format(pattern)
msg = f"Table format {pattern} not recognized. Allowed formats:"
for table_type in TabularOutputFormatter().supported_formats:
msg += "\n\t{}".format(table_type)
msg += f"\n\t{table_type}"
msg += "\nCurrently set to: %s" % self.table_format
yield (None, None, None, msg)
@ -386,10 +395,13 @@ class PGCli(object):
try:
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
query = f.read()
except IOError as e:
except OSError as e:
return [(None, None, None, str(e), "", False, True)]
if self.destructive_warning and confirm_destructive_query(query) is False:
if (
self.destructive_warning != "off"
and confirm_destructive_query(query, self.destructive_warning) is False
):
message = "Wise choice. Command execution stopped."
return [(None, None, None, message)]
@ -407,7 +419,7 @@ class PGCli(object):
if not os.path.isfile(filename):
try:
open(filename, "w").close()
except IOError as e:
except OSError as e:
self.output_file = None
message = str(e) + "\nFile output disabled"
return [(None, None, None, message, "", False, True)]
@ -479,7 +491,7 @@ class PGCli(object):
service_config, file = parse_service_info(service)
if service_config is None:
click.secho(
"service '%s' was not found in %s" % (service, file), err=True, fg="red"
f"service '{service}' was not found in {file}", err=True, fg="red"
)
exit(1)
self.connect(
@ -515,7 +527,7 @@ class PGCli(object):
passwd = os.environ.get("PGPASSWORD", "")
# Find password from store
key = "%s@%s" % (user, host)
key = f"{user}@{host}"
keyring_error_message = dedent(
"""\
{}
@ -644,8 +656,10 @@ class PGCli(object):
query = MetaQuery(query=text, successful=False)
try:
if self.destructive_warning:
destroy = confirm = confirm_destructive_query(text)
if self.destructive_warning != "off":
destroy = confirm = confirm_destructive_query(
text, self.destructive_warning
)
if destroy is False:
click.secho("Wise choice!")
raise KeyboardInterrupt
@ -677,7 +691,7 @@ class PGCli(object):
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
except IOError as e:
except OSError as e:
click.secho(str(e), err=True, fg="red")
else:
if output:
@ -729,7 +743,6 @@ class PGCli(object):
if not self.less_chatty:
print("Server: PostgreSQL", self.pgexecute.server_version)
print("Version:", __version__)
print("Chat: https://gitter.im/dbcli/pgcli")
print("Home: http://pgcli.com")
try:
@ -753,11 +766,7 @@ class PGCli(object):
while self.watch_command:
try:
query = self.execute_command(self.watch_command)
click.echo(
"Waiting for {0} seconds before repeating".format(
timing
)
)
click.echo(f"Waiting for {timing} seconds before repeating")
sleep(timing)
except KeyboardInterrupt:
self.watch_command = None
@ -979,16 +988,13 @@ class PGCli(object):
callback = functools.partial(
self._on_completions_refreshed, persist_priorities=persist_priorities
)
self.completion_refresher.refresh(
return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
callback,
history=history,
settings=self.settings,
)
return [
(None, None, None, "Auto-completion refresh started in the background.")
]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)
@ -1049,7 +1055,7 @@ class PGCli(object):
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
)
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
string = string.replace("\\n", "\n")
return string
@ -1075,9 +1081,10 @@ class PGCli(object):
def echo_via_pager(self, text, color=None):
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
click.echo(text, color=color)
elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
click.echo_via_pager(text, color)
elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT:
elif (
self.pgspecial.pager_config == PAGER_LONG_OUTPUT
and self.table_format != "csv"
):
lines = text.split("\n")
# The last 4 lines are reserved for the pgcli menu and padding
@ -1192,7 +1199,10 @@ class PGCli(object):
help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
"--warn/--no-warn", default=None, help="Warn before running a destructive query."
"--warn",
default=None,
type=click.Choice(["all", "moderate", "off"]),
help="Warn before running a destructive query.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
@ -1384,7 +1394,7 @@ def is_mutating(status):
if not status:
return False
mutating = set(["insert", "update", "delete"])
mutating = {"insert", "update", "delete"}
return status.split(None, 1)[0].lower() in mutating
@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings):
formatted = iter(formatted.splitlines())
first_line = next(formatted)
formatted = itertools.chain([first_line], formatted)
if not expanded and max_width and len(first_line) > max_width and headers:
if (
not expanded
and max_width
and len(strip_ansi(first_line)) > max_width
and headers
):
formatted = formatter.format_output(
cur, headers, format_name="vertical", column_types=None, **output_kwargs
)
@ -1502,10 +1517,16 @@ def parse_service_info(service):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
service_file = expanduser("~/.pg_service.conf")
if not service:
if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file
service_file_config = ConfigObj(service_file)
with open(service_file, newline="") as f:
skipped_lines = skip_initial_comment(f)
try:
service_file_config = ConfigObj(f)
except ParseError as err:
err.line_number += skipped_lines
raise err
if service not in service_file_config:
return None, service_file
service_conf = service_file_config.get(service)

View file

@ -1,22 +1,34 @@
import sqlparse
def query_starts_with(query, prefixes):
def query_starts_with(formatted_sql, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
def queries_start_with(queries, prefixes):
"""Check if any queries start with any item from *prefixes*."""
for query in sqlparse.split(queries):
if query and query_starts_with(query, prefixes) is True:
return True
return False
def query_is_unconditional_update(formatted_sql):
"""Check if the query starts with UPDATE and contains no WHERE."""
tokens = formatted_sql.split()
return bool(tokens) and tokens[0] == "update" and "where" not in tokens
def is_destructive(queries):
def query_is_simple_update(formatted_sql):
"""Check if the query starts with UPDATE."""
tokens = formatted_sql.split()
return bool(tokens) and tokens[0] == "update"
def is_destructive(queries, warning_level="all"):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
return queries_start_with(queries, keywords)
for query in sqlparse.split(queries):
if query:
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
if query_starts_with(formatted_sql, keywords):
return True
if query_is_unconditional_update(formatted_sql):
return True
if warning_level == "all" and query_is_simple_update(formatted_sql):
return True
return False

View file

@ -50,7 +50,7 @@ def parse_defaults(defaults_string):
yield current
class FunctionMetadata(object):
class FunctionMetadata:
def __init__(
self,
schema_name,

View file

@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
for x in extract_from_part(item, stop_at_punctuation):
yield x
yield from extract_from_part(item, stop_at_punctuation)
elif stop_at_punctuation and item.ttype is Punctuation:
return
# An incomplete nested select won't be recognized correctly as a

View file

@ -392,6 +392,7 @@
"QUOTE_NULLABLE",
"RADIANS",
"RADIUS",
"RANDOM",
"RANK",
"REGEXP_MATCH",
"REGEXP_MATCHES",

View file

@ -16,10 +16,10 @@ def _compile_regex(keyword):
keywords = get_literals("keywords")
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}
class PrevalenceCounter(object):
class PrevalenceCounter:
def __init__(self):
self.keyword_counts = defaultdict(int)
self.name_counts = defaultdict(int)

View file

@ -3,7 +3,7 @@ import click
from .parseutils import is_destructive
def confirm_destructive_query(queries):
def confirm_destructive_query(queries, warning_level):
"""Check if the query is destructive and prompts the user to confirm.
Returns:
@ -15,7 +15,7 @@ def confirm_destructive_query(queries):
prompt_text = (
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
)
if is_destructive(queries) and sys.stdin.isatty():
if is_destructive(queries, warning_level) and sys.stdin.isatty():
return prompt(prompt_text, type=bool)

View file

@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", [])
class SqlStatement(object):
class SqlStatement:
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(

View file

@ -23,9 +23,13 @@ multi_line = False
multi_line_mode = psql
# Destructive warning mode will alert you before executing a sql statement
# that may cause harm to the database such as "drop table", "drop database"
# or "shutdown".
destructive_warning = True
# that may cause harm to the database such as "drop table", "drop database",
# "shutdown", "delete", or "update".
# Possible values:
# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE
# "moderate" - skip warning on UPDATE statements, except for unconditional updates
# "off" - skip all warnings
destructive_warning = all
# Enables expand mode, which is similar to `\x` in psql.
expand = False
@ -170,9 +174,12 @@ arg-toolbar = 'noinherit bold'
arg-toolbar.text = 'nobold'
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
literal.string = '#ba2121'
literal.number = '#666666'
keyword = 'bold #008000'
# These three values can be used to further refine the syntax highlighting.
# They are commented out by default, since they have priority over the theme set
# with the `syntax_style` setting and overriding its behavior can be confusing.
# literal.string = '#ba2121'
# literal.number = '#666666'
# keyword = 'bold #008000'
# style classes for colored table output
output.header = "#00ff5f bold"

View file

@ -83,7 +83,7 @@ class PGCompleter(Completer):
reserved_words = set(get_literals("reserved"))
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
super(PGCompleter, self).__init__()
super().__init__()
self.smart_completion = smart_completion
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
@ -140,7 +140,7 @@ class PGCompleter(Completer):
return "'{}'".format(self.unescape_name(name))
def unescape_name(self, name):
""" Unquote a string."""
"""Unquote a string."""
if name and name[0] == '"' and name[-1] == '"':
name = name[1:-1]
@ -177,7 +177,7 @@ class PGCompleter(Completer):
:return:
"""
# casing should be a dict {lowercasename:PreferredCasingName}
self.casing = dict((word.lower(), word) for word in words)
self.casing = {word.lower(): word for word in words}
def extend_relations(self, data, kind):
"""extend metadata for tables or views.
@ -279,8 +279,8 @@ class PGCompleter(Completer):
fk = ForeignKey(
parentschema, parenttable, parcol, childschema, childtable, childcol
)
childcolmeta.foreignkeys.append((fk))
parcolmeta.foreignkeys.append((fk))
childcolmeta.foreignkeys.append(fk)
parcolmeta.foreignkeys.append(fk)
def extend_datatypes(self, type_data):
@ -424,7 +424,7 @@ class PGCompleter(Completer):
# the same priority as unquoted names.
lexical_priority = (
tuple(
0 if c in (" _") else -ord(c)
0 if c in " _" else -ord(c)
for c in self.unescape_name(item.lower())
)
+ (1,)
@ -517,9 +517,9 @@ class PGCompleter(Completer):
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
# suggest only columns that appear in the last table and one more
ltbl = tables[-1].ref
other_tbl_cols = set(
other_tbl_cols = {
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
)
}
scoped_cols = {
t: [col for col in cols if col.name in other_tbl_cols]
for t, cols in scoped_cols.items()
@ -574,7 +574,7 @@ class PGCompleter(Completer):
tbls - TableReference iterable of tables already in query
"""
tbl = self.case(tbl)
tbls = set(normalize_ref(t.ref) for t in tbls)
tbls = {normalize_ref(t.ref) for t in tbls}
if self.generate_aliases:
tbl = generate_alias(self.unescape_name(tbl))
if normalize_ref(tbl) not in tbls:
@ -589,10 +589,10 @@ class PGCompleter(Completer):
tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
refs = set(normalize_ref(t.ref) for t in tbls)
other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
qualified = {normalize_ref(t.ref): t.schema for t in tbls}
ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
refs = {normalize_ref(t.ref) for t in tbls}
other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
joins = []
# Iterate over FKs in existing tables to find potential joins
fks = (
@ -667,7 +667,7 @@ class PGCompleter(Completer):
return d
# Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
# Map (schema, table, col) to tables
coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
@ -703,7 +703,11 @@ class PGCompleter(Completer):
not f.is_aggregate
and not f.is_window
and not f.is_extension
and (f.is_public or f.schema_name == suggestion.schema)
and (
f.is_public
or f.schema_name in self.search_path
or f.schema_name == suggestion.schema
)
)
else:
@ -721,9 +725,7 @@ class PGCompleter(Completer):
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
)
funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
matches = self.find_matches(word_before_cursor, funcs, meta="function")
@ -953,7 +955,7 @@ class PGCompleter(Completer):
:return: {TableReference:{colname:ColumnMetaData}}
"""
ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
columns = OrderedDict()
meta = self.dbmetadata

View file

@ -1,13 +1,15 @@
import traceback
import logging
import select
import traceback
import pgspecial as special
import psycopg2
import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
import psycopg2.extras
import sqlparse
import pgspecial as special
import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
_wait_callback_is_set = False
def _wait_select(conn):
@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except select.error as e:
errno = e.args[0]
if errno != 4:
raise
try:
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except OSError as e:
errno = e.args[0]
if errno != 4:
raise
except psycopg2.OperationalError:
pass
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def _set_wait_callback(is_virtual_database):
global _wait_callback_is_set
if _wait_callback_is_set:
return
_wait_callback_is_set = True
if is_virtual_database:
return
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
if cursor.description is None:
return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
except psycopg2.ProgrammingError:
except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@ -127,7 +142,39 @@ def register_hstore_typecaster(conn):
pass
class PGExecute(object):
class ProtocolSafeCursor(psycopg2.extensions.cursor):
def __init__(self, *args, **kwargs):
self.protocol_error = False
self.protocol_message = ""
super().__init__(*args, **kwargs)
def __iter__(self):
if self.protocol_error:
raise StopIteration
return super().__iter__()
def fetchall(self):
if self.protocol_error:
return [(self.protocol_message,)]
return super().fetchall()
def fetchone(self):
if self.protocol_error:
return (self.protocol_message,)
return super().fetchone()
def execute(self, sql, args=None):
try:
psycopg2.extensions.cursor.execute(self, sql, args)
self.protocol_error = False
self.protocol_message = ""
except psycopg2.errors.ProtocolViolation as ex:
self.protocol_error = True
self.protocol_message = ex.pgerror
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
class PGExecute:
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
@ -190,8 +237,6 @@ class PGExecute(object):
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
version_query = "SELECT version();"
def __init__(
self,
database=None,
@ -203,6 +248,7 @@ class PGExecute(object):
**kwargs,
):
self._conn_params = {}
self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@ -214,6 +260,11 @@ class PGExecute(object):
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
def is_virtual_database(self):
if self._is_virtual_database is None:
self._is_virtual_database = self.is_protocol_error()
return self._is_virtual_database
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@ -250,9 +301,9 @@ class PGExecute(object):
)
conn_params.update({k: v for k, v in new_params.items() if v})
conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@ -293,16 +344,22 @@ class PGExecute(object):
self.extra_args = kwargs
if not self.host:
self.host = self.get_socket_directory()
self.host = (
"pgbouncer"
if self.is_virtual_database()
else self.get_socket_directory()
)
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
self.pid = pid
self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
self.server_version = conn.get_parameter_status("server_version")
self.server_version = conn.get_parameter_status("server_version") or ""
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
_set_wait_callback(self.is_virtual_database())
if not self.is_virtual_database():
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
@property
def short_host(self):
@ -395,7 +452,13 @@ class PGExecute(object):
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
for result in pgspecial.execute(cur, sql):
response = pgspecial.execute(cur, sql)
if cur and cur.protocol_error:
yield None, None, None, cur.protocol_message, statement, False, False
# this would close connection. We should reconnect.
self.connect()
continue
for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@ -453,6 +516,9 @@ class PGExecute(object):
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
elif cur.protocol_error:
_logger.debug("Protocol error, unsupported command.")
return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@ -485,7 +551,7 @@ class PGExecute(object):
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
raise RuntimeError("View {} does not exist.".format(spec))
raise RuntimeError(f"View {spec} does not exist.")
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
@ -501,7 +567,7 @@ class PGExecute(object):
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
raise RuntimeError("Function {} does not exist.".format(spec))
raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
"""Returns a list of schema names in the database"""
@ -527,21 +593,18 @@ class PGExecute(object):
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
yield from cur
def tables(self):
"""Yields (schema_name, table_name) tuples"""
for row in self._relations(kinds=["r", "p", "f"]):
yield row
yield from self._relations(kinds=["r", "p", "f"])
def views(self):
"""Yields (schema_name, view_name) tuples.
Includes both views and and materialized views
"""
for row in self._relations(kinds=["v", "m"]):
yield row
yield from self._relations(kinds=["v", "m"])
def _columns(self, kinds=("r", "p", "f", "v", "m")):
"""Get column metadata for tables and views
@ -599,16 +662,13 @@ class PGExecute(object):
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
yield from cur
def table_columns(self):
for row in self._columns(kinds=["r", "p", "f"]):
yield row
yield from self._columns(kinds=["r", "p", "f"])
def view_columns(self):
for row in self._columns(kinds=["v", "m"]):
yield row
yield from self._columns(kinds=["v", "m"])
def databases(self):
with self.conn.cursor() as cur:
@ -623,6 +683,13 @@ class PGExecute(object):
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
def is_protocol_error(self):
query = "SELECT 1"
with self.conn.cursor() as cur:
_logger.debug("Simple Query. sql: %r", query)
cur.execute(query)
return bool(cur.protocol_error)
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
@ -804,8 +871,7 @@ class PGExecute(object):
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row
yield from cur
def casing(self):
"""Yields the most common casing for names used in db functions"""

View file

@ -1,15 +1,23 @@
from pkg_resources import packaging
import prompt_toolkit
from prompt_toolkit.key_binding.vi_state import InputMode
from prompt_toolkit.application import get_app
parse_version = packaging.version.parse
vi_modes = {
InputMode.INSERT: "I",
InputMode.NAVIGATION: "N",
InputMode.REPLACE: "R",
InputMode.INSERT_MULTIPLE: "M",
}
if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"):
vi_modes[InputMode.REPLACE_SINGLE] = "R"
def _get_vi_mode():
return {
InputMode.INSERT: "I",
InputMode.NAVIGATION: "N",
InputMode.REPLACE: "R",
InputMode.REPLACE_SINGLE: "R",
InputMode.INSERT_MULTIPLE: "M",
}[get_app().vi_state.input_mode]
return vi_modes[get_app().vi_state.input_mode]
def create_toolbar_tokens_func(pgcli):

View file

@ -1,5 +1,4 @@
pytest>=2.7.0
mock>=1.0.1
tox>=1.9.2
behave>=1.2.4
pexpect==3.3

View file

@ -12,7 +12,7 @@ from utils import (
import pgcli.pgexecute
@pytest.yield_fixture(scope="function")
@pytest.fixture(scope="function")
def connection():
create_db("_test_db")
connection = db_connection("_test_db")

View file

@ -44,7 +44,7 @@ def create_cn(hostname, password, username, dbname, port):
host=hostname, user=username, database=dbname, password=password, port=port
)
print("Created connection: {0}.".format(cn.dsn))
print(f"Created connection: {cn.dsn}.")
return cn
@ -75,4 +75,4 @@ def close_cn(cn=None):
"""
if cn:
cn.close()
print("Closed connection: {0}.".format(cn.dsn))
print(f"Closed connection: {cn.dsn}.")

View file

@ -38,7 +38,7 @@ def before_all(context):
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
db_name_full = "{0}_{1}".format(db_name, vi)
db_name_full = f"{db_name}_{vi}"
# Store get params from config.
context.conf = {
@ -63,7 +63,7 @@ def before_all(context):
"import coverage",
"coverage.process_startup()",
"import pgcli.main",
"pgcli.main.cli()",
"pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
]
),
)
@ -102,6 +102,7 @@ def before_all(context):
else:
if "PGPASSWORD" in os.environ:
del os.environ["PGPASSWORD"]
os.environ["BEHAVE_WARN"] = "moderate"
context.cn = dbutils.create_db(
context.conf["host"],
@ -122,12 +123,12 @@ def before_all(context):
def show_env_changes(env_old, env_new):
"""Print out all test-specific env values."""
print("--- os.environ changed values: ---")
all_keys = set(list(env_old.keys()) + list(env_new.keys()))
all_keys = env_old.keys() | env_new.keys()
for k in sorted(all_keys):
old_value = env_old.get(k, "")
new_value = env_new.get(k, "")
if new_value and old_value != new_value:
print('{}="{}"'.format(k, new_value))
print(f'{k}="{new_value}"')
print("-" * 20)
@ -173,13 +174,13 @@ def after_scenario(context, scenario):
# Quit nicely.
if not context.atprompt:
dbname = context.currentdb
context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
context.cli.expect_exact(f"{dbname}> ", timeout=15)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
try:
context.cli.expect_exact(pexpect.EOF, timeout=15)
except pexpect.TIMEOUT:
print("--- after_scenario {}: kill cli".format(scenario.name))
print(f"--- after_scenario {scenario.name}: kill cli")
context.cli.kill(signal.SIGKILL)
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
context.tmpfile_sql_help.close()

View file

@ -18,7 +18,7 @@ def read_fixture_files():
"""Read all files inside fixture_data directory."""
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, "fixture_data/")
print("reading fixture data: {}".format(fixture_dir))
print(f"reading fixture data: {fixture_dir}")
fixture_dict = {}
for filename in os.listdir(fixture_dir):
if filename not in [".", ".."]:

View file

@ -65,19 +65,20 @@ def step_ctrl_d(context):
Send Ctrl + D to hopefully exit.
"""
# turn off pager before exiting
context.cli.sendline("\pset pager off")
context.cli.sendcontrol("c")
context.cli.sendline(r"\pset pager off")
wrappers.wait_prompt(context)
context.cli.sendcontrol("d")
context.cli.expect(pexpect.EOF, timeout=15)
context.exit_sent = True
@when('we send "\?" command')
@when(r'we send "\?" command')
def step_send_help(context):
"""
r"""
Send \? to see help.
"""
context.cli.sendline("\?")
context.cli.sendline(r"\?")
@when("we send partial select command")
@ -96,9 +97,9 @@ def step_see_error_message(context):
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
context.tmpfile_sql_help.write(b"\?")
context.tmpfile_sql_help.write(br"\?")
context.tmpfile_sql_help.flush()
context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)

View file

@ -14,7 +14,7 @@ def step_db_create(context):
"""
Send create database.
"""
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.cli.sendline("create database {};".format(context.conf["dbname_tmp"]))
context.response = {"database_name": context.conf["dbname_tmp"]}
@ -24,7 +24,7 @@ def step_db_drop(context):
"""
Send drop database.
"""
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"]))
@when("we connect to test database")
@ -33,7 +33,7 @@ def step_db_connect_test(context):
Send connect to database.
"""
db_name = context.conf["dbname"]
context.cli.sendline("\\connect {0}".format(db_name))
context.cli.sendline(f"\\connect {db_name}")
@when("we connect to dbserver")
@ -59,7 +59,7 @@ def step_see_prompt(context):
Wait to see the prompt.
"""
db_name = getattr(context, "currentdb", context.conf["dbname"])
wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
wrappers.expect_exact(context, f"{db_name}> ", timeout=5)
context.atprompt = True

View file

@ -31,7 +31,7 @@ def step_prepare_data(context):
@when("we set expanded {mode}")
def step_set_expanded(context, mode):
"""Set expanded to mode."""
context.cli.sendline("\\" + "x {}".format(mode))
context.cli.sendline("\\" + f"x {mode}")
wrappers.expect_exact(context, "Expanded display is", timeout=2)
wrappers.wait_prompt(context)

View file

@ -13,7 +13,7 @@ def step_edit_file(context):
)
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
)
@ -53,7 +53,7 @@ def step_tee_ouptut(context):
)
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name)))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Writing to file", timeout=5)
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@ -67,7 +67,7 @@ def step_query_select_123456(context):
@when("we stop teeing output")
def step_notee_output(context):
context.cli.sendline("\o")
context.cli.sendline(r"\o")
wrappers.expect_exact(context, "Time", timeout=5)

View file

@ -22,5 +22,10 @@ def step_see_refresh_started(context):
Wait to see refresh output.
"""
wrappers.expect_pager(
context, "Auto-completion refresh started in the background.\r\n", timeout=2
context,
[
"Auto-completion refresh started in the background.\r\n",
"Auto-completion refresh restarted.\r\n",
],
timeout=2,
)

View file

@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout):
def expect_pager(context, expected, timeout):
formatted = expected if isinstance(expected, list) else [expected]
formatted = [
f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
for t in formatted
]
expect_exact(
context,
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
formatted,
timeout=timeout,
)
@ -57,7 +63,7 @@ def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = currentdb or context.conf["dbname"]
context.cli.sendline("\pset pager always")
context.cli.sendline(r"\pset pager always")
if prompt_check:
wait_prompt(context)

0
tests/features/wrappager.py Executable file → Normal file
View file

View file

@ -3,7 +3,7 @@ from itertools import product
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from mock import Mock
from unittest.mock import Mock
import pytest
parametrize = pytest.mark.parametrize
@ -59,7 +59,7 @@ def wildcard_expansion(cols, pos=-1):
return Completion(cols, start_position=pos, display_meta="columns", display="*")
class MetaData(object):
class MetaData:
def __init__(self, metadata):
self.metadata = metadata
@ -128,7 +128,7 @@ class MetaData(object):
]
def schemas(self, pos=0):
schemas = set(sch for schs in self.metadata.values() for sch in schs)
schemas = {sch for schs in self.metadata.values() for sch in schs}
return [schema(escape(s), pos=pos) for s in schemas]
def functions_and_keywords(self, parent="public", pos=0):

View file

@ -1,4 +1,5 @@
import pytest
from pgcli.packages.parseutils import is_destructive
from pgcli.packages.parseutils.tables import extract_tables
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
@ -34,12 +35,12 @@ def test_simple_select_single_table_double_quoted():
def test_simple_select_multiple_tables():
tables = extract_tables("select * from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_simple_select_multiple_tables_double_quoted():
tables = extract_tables('select * from "Abc", "Def"')
assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)}
def test_simple_select_single_table_deouble_quoted_aliased():
@ -49,14 +50,12 @@ def test_simple_select_single_table_deouble_quoted_aliased():
def test_simple_select_multiple_tables_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a, "Def" d')
assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)}
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables("select * from abc.def, ghi.jkl")
assert set(tables) == set(
[("abc", "def", None, False), ("ghi", "jkl", None, False)]
)
assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)}
def test_simple_select_with_cols_single_table():
@ -71,14 +70,12 @@ def test_simple_select_with_cols_single_table_schema_qualified():
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables("select a,b from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_simple_select_with_cols_multiple_qualified_tables():
tables = extract_tables("select a,b from abc.def, def.ghi")
assert set(tables) == set(
[("abc", "def", None, False), ("def", "ghi", None, False)]
)
assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)}
def test_select_with_hanging_comma_single_table():
@ -88,14 +85,12 @@ def test_select_with_hanging_comma_single_table():
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables("select a, from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert set(tables) == set(
[(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
)
assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)}
def test_simple_insert_single_table():
@ -126,14 +121,14 @@ def test_simple_update_table_with_schema():
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
def test_join_table(join_type):
sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num"
tables = extract_tables(sql)
assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)}
def test_join_table_schema_qualified():
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)}
def test_incomplete_join_clause():
@ -177,25 +172,25 @@ def test_extract_no_tables(text):
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
tables = extract_tables(f"SELECT * FROM foo({arg_list})")
assert tables == ((None, "foo", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_schema_qualified_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})")
assert tables == (("foo", "bar", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_aliased_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar")
assert tables == ((None, "foo", "bar", True),)
def test_simple_table_and_function():
tables = extract_tables("SELECT * FROM foo JOIN bar()")
assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)}
def test_complex_table_and_function():
@ -203,9 +198,7 @@ def test_complex_table_and_function():
"""SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux"""
)
assert set(tables) == set(
[("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
)
assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)}
def test_find_prev_keyword_using():
@ -267,3 +260,21 @@ def test_is_open_quote__closed(sql):
)
def test_is_open_quote__open(sql):
assert is_open_quote(sql)
@pytest.mark.parametrize(
("sql", "warning_level", "expected"),
[
("update abc set x = 1", "all", True),
("update abc set x = 1 where y = 2", "all", True),
("update abc set x = 1", "moderate", True),
("update abc set x = 1 where y = 2", "moderate", False),
("select x, y, z from abc", "all", False),
("drop abc", "all", True),
("alter abc", "all", True),
("delete abc", "all", True),
("truncate abc", "all", True),
],
)
def test_is_destructive(sql, warning_level, expected):
assert is_destructive(sql, warning_level=warning_level) == expected

View file

@ -1,6 +1,6 @@
import time
import pytest
from mock import Mock, patch
from unittest.mock import Mock, patch
@pytest.fixture
@ -37,7 +37,7 @@ def test_refresh_called_once(refresher):
:return:
"""
callbacks = Mock()
pgexecute = Mock()
pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
with patch.object(refresher, "_bg_refresh") as bg_refresh:
@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher):
"""
callbacks = Mock()
pgexecute = Mock()
pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
def dummy_bg_refresh(*args):
@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher):
:param refresher:
"""
callbacks = [Mock()]
pgexecute_class = Mock()
pgexecute = Mock()
pgexecute = Mock(**{"is_virtual_database.return_value": False})
pgexecute.extra_args = {}
special = Mock()
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert callbacks[0].call_count == 1
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert callbacks[0].call_count == 1

View file

@ -1,9 +1,10 @@
import io
import os
import stat
import pytest
from pgcli.config import ensure_dir_exists
from pgcli.config import ensure_dir_exists, skip_initial_comment
def test_ensure_file_parent(tmpdir):
@ -20,11 +21,23 @@ def test_ensure_existing_dir(tmpdir):
def test_ensure_other_create_error(tmpdir):
subdir = tmpdir.join("subdir")
subdir = tmpdir.join('subdir"')
rcfile = subdir.join("rcfile")
# trigger an oserror that isn't "directory already exists"
# trigger an oserror that isn't "directory already exists"
os.chmod(str(tmpdir), stat.S_IREAD)
with pytest.raises(OSError):
ensure_dir_exists(str(rcfile))
@pytest.mark.parametrize(
"text, skipped_lines",
(
("abc\n", 1),
("#[section]\ndef\n[section]", 2),
("[section]", 0),
),
)
def test_skip_initial_comment(text, skipped_lines):
assert skip_initial_comment(io.StringIO(text)) == skipped_lines

View file

@ -1,6 +1,6 @@
import os
import platform
import mock
from unittest import mock
import pytest
@ -288,7 +288,12 @@ def test_pg_service_file(tmpdir):
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
service_conf.write(
"""[myservice]
"""File begins with a comment
that is not a comment
# or maybe a comment after all
because psql is crazy
[myservice]
host=a_host
user=a_user
port=5433

View file

@ -13,7 +13,7 @@ def completer():
@pytest.fixture
def complete_event():
from mock import Mock
from unittest.mock import Mock
return Mock()

View file

@ -2,7 +2,7 @@ from textwrap import dedent
import psycopg2
import pytest
from mock import patch, MagicMock
from unittest.mock import patch, MagicMock
from pgspecial.main import PGSpecial, NO_QUERY
from utils import run, dbtest, requires_json, requires_jsonb
@ -89,7 +89,7 @@ def test_expanded_slash_G(executor, pgspecial):
# Tests whether we reset the expanded output after a \G.
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
results = run(executor, """select * from test \G""", pgspecial=pgspecial)
results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
assert pgspecial.expanded_output == False
@ -105,31 +105,35 @@ def test_schemata_table_views_and_columns_query(executor):
# schemata
# don't enforce all members of the schemas since they may include postgres
# temporary schemas
assert set(executor.schemata()) >= set(
["public", "pg_catalog", "information_schema", "schema1", "schema2"]
)
assert set(executor.schemata()) >= {
"public",
"pg_catalog",
"information_schema",
"schema1",
"schema2",
}
assert executor.search_path() == ["pg_catalog", "public"]
# tables
assert set(executor.tables()) >= set(
[("public", "a"), ("public", "b"), ("schema1", "c")]
)
assert set(executor.tables()) >= {
("public", "a"),
("public", "b"),
("schema1", "c"),
}
assert set(executor.table_columns()) >= set(
[
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
]
)
assert set(executor.table_columns()) >= {
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
}
# views
assert set(executor.views()) >= set([("public", "d")])
assert set(executor.views()) >= {("public", "d")}
assert set(executor.view_columns()) >= set(
[("public", "d", "e", "integer", False, None)]
)
assert set(executor.view_columns()) >= {
("public", "d", "e", "integer", False, None)
}
@dbtest
@ -142,9 +146,9 @@ def test_foreign_key_query(executor):
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
assert set(executor.foreignkeys()) >= set(
[("schema1", "parent", "parentid", "schema2", "child", "motherid")]
)
assert set(executor.foreignkeys()) >= {
("schema1", "parent", "parentid", "schema2", "child", "motherid")
}
@dbtest
@ -175,30 +179,28 @@ def test_functions_query(executor):
)
funcs = set(executor.functions())
assert funcs >= set(
[
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
]
)
assert funcs >= {
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
}
@dbtest
@ -257,8 +259,8 @@ def test_not_is_special(executor, pgspecial):
@dbtest
def test_execute_from_file_no_arg(executor, pgspecial):
"""\i without a filename returns an error."""
result = list(executor.run("\i", pgspecial=pgspecial))
r"""\i without a filename returns an error."""
result = list(executor.run(r"\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert "missing required argument" in status
assert success == False
@ -268,12 +270,12 @@ def test_execute_from_file_no_arg(executor, pgspecial):
@dbtest
@patch("pgcli.main.os")
def test_execute_from_file_io_error(os, executor, pgspecial):
"""\i with an io_error returns an error."""
# Inject an IOError.
os.path.expanduser.side_effect = IOError("test")
r"""\i with an os_error returns an error."""
# Inject an OSError.
os.path.expanduser.side_effect = OSError("test")
# Check the result.
result = list(executor.run("\i test", pgspecial=pgspecial))
result = list(executor.run(r"\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == "test"
assert success == False
@ -290,7 +292,7 @@ def test_multiple_queries_same_line(executor):
@dbtest
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
result = run(executor, "select 'foo'; \d", pgspecial=pgspecial)
result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial)
assert len(result) == 11 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
# This is a lame check. :(
@ -408,7 +410,7 @@ def test_date_time_types(executor):
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
def test_large_numbers_render_directly(executor, value):
run(executor, "create table numbertest(a numeric)")
run(executor, "insert into numbertest (a) values ({0})".format(value))
run(executor, f"insert into numbertest (a) values ({value})")
assert value in run(executor, "select * from numbertest", join=True)
@ -511,13 +513,28 @@ def test_short_host(executor):
assert executor.short_host == "localhost1"
class BrokenConnection(object):
class BrokenConnection:
"""Mock a connection that failed."""
def cursor(self):
raise psycopg2.InterfaceError("I'm broken!")
class VirtualCursor:
"""Mock a cursor to virtual database like pgbouncer."""
def __init__(self):
self.protocol_error = False
self.protocol_message = ""
self.description = None
self.status = None
self.statusmessage = "Error"
def execute(self, *args, **kwargs):
self.protocol_error = True
self.protocol_message = "Command not supported"
@dbtest
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
@ -540,3 +557,12 @@ def test_exit_without_active_connection(executor):
# an exception should be raised when running a query without active connection
with pytest.raises(psycopg2.InterfaceError):
run(executor, "select 1", pgspecial=pgspecial)
@dbtest
def test_virtual_database(executor):
virtual_connection = MagicMock()
virtual_connection.cursor.return_value = VirtualCursor()
with patch.object(executor, "conn", virtual_connection):
result = run(executor, "select 1")
assert "Command not supported" in result

View file

@ -13,12 +13,12 @@ from pgcli.packages.sqlcompletion import (
def test_slash_suggests_special():
suggestions = suggest_type("\\", "\\")
assert set(suggestions) == set([Special()])
assert set(suggestions) == {Special()}
def test_slash_d_suggests_special():
suggestions = suggest_type("\\d", "\\d")
assert set(suggestions) == set([Special()])
assert set(suggestions) == {Special()}
def test_dn_suggests_schemata():
@ -30,24 +30,24 @@ def test_dn_suggests_schemata():
def test_d_suggests_tables_views_and_schemas():
suggestions = suggest_type("\d ", "\d ")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type(r"\d ", r"\d ")
assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
suggestions = suggest_type("\d xxx", "\d xxx")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type(r"\d xxx", r"\d xxx")
assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
def test_d_dot_suggests_schema_qualified_tables_or_views():
suggestions = suggest_type("\d myschema.", "\d myschema.")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type(r"\d myschema.", r"\d myschema.")
assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx")
assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
def test_df_suggests_schema_or_function():
suggestions = suggest_type("\\df xxx", "\\df xxx")
assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
assert set(suggestions) == {Function(schema=None, usage="special"), Schema()}
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
assert suggestions == (Function(schema="myschema", usage="special"),)
@ -63,7 +63,7 @@ def test_leading_whitespace_ok():
def test_dT_suggests_schema_or_datatypes():
text = "\\dT "
suggestions = suggest_type(text, text)
assert set(suggestions) == set([Schema(), Datatype(schema=None)])
assert set(suggestions) == {Schema(), Datatype(schema=None)}
def test_schema_qualified_dT_suggests_datatypes():

View file

@ -7,4 +7,4 @@ def test_confirm_destructive_query_notty():
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = "drop database foo;"
assert confirm_destructive_query(sql) is None
assert confirm_destructive_query(sql, "all") is None

View file

@ -1,5 +1,5 @@
import pytest
from mock import Mock
from unittest.mock import Mock
from pgcli.main import PGCli

View file

@ -193,7 +193,7 @@ def test_suggested_joins(completer, query, tbl):
result = get_result(completer, query.format(tbl))
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
+ [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))]
+ [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
)
@ -350,6 +350,36 @@ def test_schema_qualified_function_name(completer):
)
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
def test_schema_qualified_function_name_after_from(completer):
text = "SELECT * FROM custom.set_r"
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[
function("set_returning_func()", -len("func")),
]
)
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
def test_unqualified_function_name_not_returned(completer):
text = "SELECT * FROM set_r"
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set([])
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
def test_unqualified_function_name_in_search_path(completer):
completer.search_path = ["public", "custom"]
text = "SELECT * FROM set_r"
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[
function("set_returning_func()", -len("func")),
]
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",

View file

@ -53,7 +53,7 @@ metadata = {
],
}
metadata = dict((k, {"public": v}) for k, v in metadata.items())
metadata = {k: {"public": v} for k, v in metadata.items()}
testdata = MetaData(metadata)
@ -296,7 +296,7 @@ def test_suggested_cased_always_qualified_column_names(completer):
def test_suggested_column_names_in_function(completer):
result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
assert completions_to_set(result) == completions_to_set(
(testdata.columns_functions_and_keywords("users"))
testdata.columns_functions_and_keywords("users")
)
@ -316,7 +316,7 @@ def test_suggested_column_names_with_alias(completer):
def test_suggested_multiple_column_names(completer):
result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
assert completions_to_set(result) == completions_to_set(
(testdata.columns_functions_and_keywords("users"))
testdata.columns_functions_and_keywords("users")
)

View file

@ -23,16 +23,14 @@ def cols_etc(
):
"""Returns the expected select-clause suggestions for a single-table
select."""
return set(
[
Column(
table_refs=(TableReference(schema, table, alias, is_function),),
qualifiable=True,
),
Function(schema=parent),
Keyword(last_keyword),
]
)
return {
Column(
table_refs=(TableReference(schema, table, alias, is_function),),
qualifiable=True,
),
Function(schema=parent),
Keyword(last_keyword),
}
def test_select_suggests_cols_with_visible_table_scope():
@ -103,24 +101,20 @@ def test_where_equals_any_suggests_columns_or_keywords():
def test_lparen_suggests_cols_and_funcs():
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
assert set(suggestion) == set(
[
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("("),
]
)
assert set(suggestion) == {
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("("),
}
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type("SELECT ", "SELECT ")
assert set(suggestions) == set(
[
Column(table_refs=(), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=(), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
@pytest.mark.parametrize(
@ -128,13 +122,13 @@ def test_select_suggests_cols_and_funcs():
)
def test_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()])
assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()}
@pytest.mark.parametrize("expression", ["SELECT * FROM "])
def test_suggest_tables_views_schemas_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
@pytest.mark.parametrize(
@ -147,9 +141,11 @@ def test_suggest_tables_views_schemas_and_functions(expression):
def test_suggest_after_join_with_two_tables(expression):
suggestions = suggest_type(expression, expression)
tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
assert set(suggestions) == set(
[FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()]
)
assert set(suggestions) == {
FromClauseItem(schema=None, table_refs=tables),
Join(tables, None),
Schema(),
}
@pytest.mark.parametrize(
@ -158,13 +154,11 @@ def test_suggest_after_join_with_two_tables(expression):
def test_suggest_after_join_with_one_table(expression):
suggestions = suggest_type(expression, expression)
tables = ((None, "foo", None, False),)
assert set(suggestions) == set(
[
FromClauseItem(schema=None, table_refs=tables),
Join(((None, "foo", None, False),), None),
Schema(),
]
)
assert set(suggestions) == {
FromClauseItem(schema=None, table_refs=tables),
Join(((None, "foo", None, False),), None),
Schema(),
}
@pytest.mark.parametrize(
@ -172,13 +166,13 @@ def test_suggest_after_join_with_one_table(expression):
)
def test_suggest_qualified_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
@pytest.mark.parametrize("expression", ["UPDATE sch."])
def test_suggest_qualified_aliasable_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
@pytest.mark.parametrize(
@ -193,26 +187,27 @@ def test_suggest_qualified_aliasable_tables_and_views(expression):
)
def test_suggest_qualified_tables_views_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([FromClauseItem(schema="sch")])
assert set(suggestions) == {FromClauseItem(schema="sch")}
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
def test_suggest_qualified_tables_views_functions_and_joins(expression):
suggestions = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
assert set(suggestions) == set(
[FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")]
)
assert set(suggestions) == {
FromClauseItem(schema="sch", table_refs=tbls),
Join(tbls, "sch"),
}
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
assert set(suggestions) == set([Table(schema=None), Schema()])
assert set(suggestions) == {Table(schema=None), Schema()}
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
assert set(suggestions) == set([Table(schema="sch")])
assert set(suggestions) == {Table(schema="sch")}
@pytest.mark.parametrize(
@ -220,13 +215,11 @@ def test_truncate_suggests_qualified_tables():
)
def test_distinct_suggests_cols(text):
suggestions = suggest_type(text, text)
assert set(suggestions) == set(
[
Column(table_refs=(), local_tables=(), qualifiable=True),
Function(schema=None),
Keyword("DISTINCT"),
]
)
assert set(suggestions) == {
Column(table_refs=(), local_tables=(), qualifiable=True),
Function(schema=None),
Keyword("DISTINCT"),
}
@pytest.mark.parametrize(
@ -244,20 +237,18 @@ def test_distinct_and_order_by_suggestions_with_aliases(
text, text_before, last_keyword
):
suggestions = suggest_type(text, text_before)
assert set(suggestions) == set(
[
Column(
table_refs=(
TableReference(None, "tbl", "x", False),
TableReference(None, "tbl1", "y", False),
),
local_tables=(),
qualifiable=True,
assert set(suggestions) == {
Column(
table_refs=(
TableReference(None, "tbl", "x", False),
TableReference(None, "tbl1", "y", False),
),
Function(schema=None),
Keyword(last_keyword),
]
)
local_tables=(),
qualifiable=True,
),
Function(schema=None),
Keyword(last_keyword),
}
@pytest.mark.parametrize(
@ -272,56 +263,50 @@ def test_distinct_and_order_by_suggestions_with_aliases(
)
def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
suggestions = suggest_type(text, text_before)
assert set(suggestions) == set(
[
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
]
)
assert set(suggestions) == {
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
}
def test_function_arguments_with_alias_given():
suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
assert set(suggestions) == set(
[
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
]
)
assert set(suggestions) == {
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
}
def test_col_comma_suggests_cols():
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
assert set(suggestions) == set(
[
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()])
assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()}
@pytest.mark.parametrize(
@ -357,14 +342,12 @@ def test_partially_typed_col_name_suggests_col_names():
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl", None, False),)),
Table(schema="tabl"),
View(schema="tabl"),
Function(schema="tabl"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tabl", None, False),)),
Table(schema="tabl"),
View(schema="tabl"),
Function(schema="tabl"),
}
@pytest.mark.parametrize(
@ -378,14 +361,12 @@ def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
)
def test_dot_suggests_cols_of_an_alias(sql):
suggestions = suggest_type(sql, "SELECT t1.")
assert set(suggestions) == set(
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
]
)
assert set(suggestions) == {
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
}
@pytest.mark.parametrize(
@ -399,28 +380,24 @@ def test_dot_suggests_cols_of_an_alias(sql):
)
def test_dot_suggests_cols_of_an_alias_where(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set(
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
]
)
assert set(suggestions) == {
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
}
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type(
"SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl2", "t2", False),)),
Table(schema="t2"),
View(schema="t2"),
Function(schema="t2"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tabl2", "t2", False),)),
Table(schema="t2"),
View(schema="t2"),
Function(schema="t2"),
}
@pytest.mark.parametrize(
@ -452,20 +429,18 @@ def test_sub_select_partial_text_suggests_keyword(expression):
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
assert set(suggestions) == set(
[
Column(table_refs=((None, "foo", "f", False),)),
Table(schema="f"),
View(schema="f"),
Function(schema="f"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "foo", "f", False),)),
Table(schema="f"),
View(schema="f"),
Function(schema="f"),
}
@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert set(suggestion) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestion) == {FromClauseItem(schema=None), Schema()}
@pytest.mark.parametrize(
@ -478,22 +453,18 @@ def test_sub_select_table_name_completion(expression):
def test_sub_select_table_name_completion_with_outer_table(expression):
suggestion = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
assert set(suggestion) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
)
assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
def test_sub_select_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
@pytest.mark.xfail
@ -508,25 +479,25 @@ def test_sub_select_dot_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl", "t", False),)),
Table(schema="t"),
View(schema="t"),
Function(schema="t"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tabl", "t", False),)),
Table(schema="t"),
View(schema="t"),
Function(schema="t"),
}
@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
@pytest.mark.parametrize("tbl_alias", ("", "foo"))
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
suggestion = suggest_type(text, text)
tbls = tuple([(None, "abc", tbl_alias or None, False)])
assert set(suggestion) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)]
)
assert set(suggestion) == {
FromClauseItem(schema=None, table_refs=tbls),
Schema(),
Join(tbls, None),
}
def test_left_join_with_comma():
@ -535,9 +506,7 @@ def test_left_join_with_comma():
# tbls should also include (None, 'bar', 'b', False)
# but there's a bug with commas
tbls = tuple([(None, "foo", "f", False)])
assert set(suggestions) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
)
assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
@pytest.mark.parametrize(
@ -550,15 +519,13 @@ def test_left_join_with_comma():
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "def", "d", False))
assert set(suggestions) == set(
[
Column(table_refs=((None, "abc", "a", False),)),
Table(schema="a"),
View(schema="a"),
Function(schema="a"),
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "abc", "a", False),)),
Table(schema="a"),
View(schema="a"),
Function(schema="a"),
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
}
@pytest.mark.parametrize(
@ -570,14 +537,12 @@ def test_join_alias_dot_suggests_cols1(sql):
)
def test_join_alias_dot_suggests_cols2(sql):
suggestion = suggest_type(sql, sql)
assert set(suggestion) == set(
[
Column(table_refs=((None, "def", "d", False),)),
Table(schema="d"),
View(schema="d"),
Function(schema="d"),
]
)
assert set(suggestion) == {
Column(table_refs=((None, "def", "d", False),)),
Table(schema="d"),
View(schema="d"),
Function(schema="d"),
}
@pytest.mark.parametrize(
@ -598,9 +563,10 @@ on """,
def test_on_suggests_aliases_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b")))
)
assert set(suggestions) == {
JoinCondition(table_refs=tables, parent=None),
Alias(aliases=("a", "b")),
}
@pytest.mark.parametrize(
@ -613,9 +579,10 @@ def test_on_suggests_aliases_and_join_conditions(sql):
def test_on_suggests_tables_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
)
assert set(suggestions) == {
JoinCondition(table_refs=tables, parent=None),
Alias(aliases=("abc", "bcd")),
}
@pytest.mark.parametrize(
@ -640,9 +607,10 @@ def test_on_suggests_aliases_right_side(sql):
def test_on_suggests_tables_and_join_conditions_right_side(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
)
assert set(suggestions) == {
JoinCondition(table_refs=tables, parent=None),
Alias(aliases=("abc", "bcd")),
}
@pytest.mark.parametrize(
@ -659,9 +627,9 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql):
)
def test_join_using_suggests_common_columns(text):
tables = ((None, "abc", None, False), (None, "def", None, False))
assert set(suggest_type(text, text)) == set(
[Column(table_refs=tables, require_last_table=True)]
)
assert set(suggest_type(text, text)) == {
Column(table_refs=tables, require_last_table=True)
}
def test_suggest_columns_after_multiple_joins():
@ -678,29 +646,27 @@ def test_2_statements_2nd_current():
suggestions = suggest_type(
"select * from a; select * from ", "select * from a; select * from "
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type(
"select * from a; select from b", "select * from a; select "
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "b", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "b", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
# Should work even if first statement is invalid
suggestions = suggest_type(
"select * from; select * from ", "select * from; select * from "
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
def test_2_statements_1st_current():
suggestions = suggest_type("select * from ; select * from b", "select * from ")
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type("select from a; select * from b", "select ")
assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
@ -711,7 +677,7 @@ def test_3_statements_2nd_current():
"select * from a; select * from ; select * from c",
"select * from a; select * from ",
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type(
"select * from a; select from b; select * from c", "select * from a; select "
@ -768,13 +734,11 @@ SELECT * FROM qux;
)
def test_statements_in_function_body(text):
suggestions = suggest_type(text, text[: text.find(" ") + 1])
assert set(suggestions) == set(
[
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
functions = [
@ -799,13 +763,13 @@ SELECT 1 FROM foo;
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_after_function_body(text):
suggestions = suggest_type(text, text[: text.find("; ") + 1])
assert set(suggestions) == set([Keyword(), Special()])
assert set(suggestions) == {Keyword(), Special()}
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_before_function_body(text):
suggestions = suggest_type(text, "")
assert set(suggestions) == set([Keyword(), Special()])
assert set(suggestions) == {Keyword(), Special()}
def test_create_db_with_template():
@ -813,14 +777,14 @@ def test_create_db_with_template():
"create database foo with template ", "create database foo with template "
)
assert set(suggestions) == set((Database(),))
assert set(suggestions) == {Database()}
@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
assert set(suggestions) == set([Keyword(), Special()])
assert set(suggestions) == {Keyword(), Special()}
def test_drop_schema_qualified_table_suggests_only_tables():
@ -843,25 +807,30 @@ def test_drop_schema_suggests_schemas():
@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
def test_cast_operator_suggests_types(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
assert set(suggest_type(text, text)) == {
Datatype(schema=None),
Table(schema=None),
Schema(),
}
@pytest.mark.parametrize(
"text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
)
def test_cast_operator_suggests_schema_qualified_types(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema="bar"), Table(schema="bar")]
)
assert set(suggest_type(text, text)) == {
Datatype(schema="bar"),
Table(schema="bar"),
}
def test_alter_column_type_suggests_types():
q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
assert set(suggest_type(q, q)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
assert set(suggest_type(q, q)) == {
Datatype(schema=None),
Table(schema=None),
Schema(),
}
@pytest.mark.parametrize(
@ -880,9 +849,11 @@ def test_alter_column_type_suggests_types():
],
)
def test_identifier_suggests_types_in_parentheses(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
assert set(suggest_type(text, text)) == {
Datatype(schema=None),
Table(schema=None),
Schema(),
}
@pytest.mark.parametrize(
@ -977,7 +948,7 @@ def test_ignore_leading_double_quotes(sql):
)
def test_column_keyword_suggests_columns(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))])
assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))}
def test_handle_unrecognized_kw_generously():

View file

@ -8,7 +8,7 @@ from os import getenv
POSTGRES_USER = getenv("PGUSER", "postgres")
POSTGRES_HOST = getenv("PGHOST", "localhost")
POSTGRES_PORT = getenv("PGPORT", 5432)
POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres")
def db_connection(dbname=None):
@ -73,7 +73,7 @@ def drop_tables(conn):
def run(
executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
):
" Return string output for the sql to be run "
"Return string output for the sql to be run"
results = executor.run(sql, pgspecial, exception_formatter)
formatted = []
@ -89,7 +89,7 @@ def run(
def completions_to_set(completions):
return set(
return {
(completion.display_text, completion.display_meta_text)
for completion in completions
)
}