Merging upstream version 3.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
a868bb3d29
commit
39b7cc8559
50 changed files with 952 additions and 634 deletions
66
.github/workflows/ci.yml
vendored
Normal file
66
.github/workflows/ci.yml
vendored
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
name: pgcli
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths-ignore:
|
||||||
|
- '**.rst'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:9.6
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd pg_isready
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
pip install -U pip setuptools
|
||||||
|
pip install --no-cache-dir .
|
||||||
|
pip install -r requirements-dev.txt
|
||||||
|
pip install keyrings.alt>=3.1
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
|
run: coverage run --source pgcli -m py.test
|
||||||
|
|
||||||
|
- name: Run integration tests
|
||||||
|
env:
|
||||||
|
PGUSER: postgres
|
||||||
|
PGPASSWORD: postgres
|
||||||
|
|
||||||
|
run: behave tests/features --no-capture
|
||||||
|
|
||||||
|
- name: Check changelog for ReST compliance
|
||||||
|
run: rst2html.py --halt=warning changelog.rst >/dev/null
|
||||||
|
|
||||||
|
- name: Run Black
|
||||||
|
run: pip install black && black --check .
|
||||||
|
if: matrix.python-version == '3.6'
|
||||||
|
|
||||||
|
- name: Coverage
|
||||||
|
run: |
|
||||||
|
coverage combine
|
||||||
|
coverage report
|
||||||
|
codecov
|
|
@ -1,6 +1,6 @@
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: stable
|
rev: 21.5b0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
language_version: python3.7
|
language_version: python3.7
|
||||||
|
|
51
.travis.yml
51
.travis.yml
|
@ -1,51 +0,0 @@
|
||||||
dist: xenial
|
|
||||||
|
|
||||||
sudo: required
|
|
||||||
|
|
||||||
language: python
|
|
||||||
|
|
||||||
python:
|
|
||||||
- "3.6"
|
|
||||||
- "3.7"
|
|
||||||
- "3.8"
|
|
||||||
- "3.9-dev"
|
|
||||||
|
|
||||||
before_install:
|
|
||||||
- which python
|
|
||||||
- which pip
|
|
||||||
- pip install -U setuptools
|
|
||||||
|
|
||||||
install:
|
|
||||||
- pip install --no-cache-dir .
|
|
||||||
- pip install -r requirements-dev.txt
|
|
||||||
- pip install keyrings.alt>=3.1
|
|
||||||
|
|
||||||
script:
|
|
||||||
- set -e
|
|
||||||
- coverage run --source pgcli -m py.test
|
|
||||||
- cd tests
|
|
||||||
- behave --no-capture
|
|
||||||
- cd ..
|
|
||||||
# check for changelog ReST compliance
|
|
||||||
- rst2html.py --halt=warning changelog.rst >/dev/null
|
|
||||||
# check for black code compliance, 3.6 only
|
|
||||||
- if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi
|
|
||||||
- set +e
|
|
||||||
|
|
||||||
after_success:
|
|
||||||
- coverage combine
|
|
||||||
- codecov
|
|
||||||
|
|
||||||
notifications:
|
|
||||||
webhooks:
|
|
||||||
urls:
|
|
||||||
- YOUR_WEBHOOK_URL
|
|
||||||
on_success: change # options: [always|never|change] default: always
|
|
||||||
on_failure: always # options: [always|never|change] default: always
|
|
||||||
on_start: false # default: false
|
|
||||||
|
|
||||||
services:
|
|
||||||
- postgresql
|
|
||||||
|
|
||||||
addons:
|
|
||||||
postgresql: "9.6"
|
|
4
AUTHORS
4
AUTHORS
|
@ -114,6 +114,10 @@ Contributors:
|
||||||
* Tom Caruso (tomplex)
|
* Tom Caruso (tomplex)
|
||||||
* Jan Brun Rasmussen (janbrunrasmussen)
|
* Jan Brun Rasmussen (janbrunrasmussen)
|
||||||
* Kevin Marsh (kevinmarsh)
|
* Kevin Marsh (kevinmarsh)
|
||||||
|
* Eero Ruohola (ruohola)
|
||||||
|
* Miroslav Šedivý (eumiro)
|
||||||
|
* Eric R Young (ERYoung11)
|
||||||
|
* Paweł Sacawa (psacawa)
|
||||||
|
|
||||||
Creator:
|
Creator:
|
||||||
--------
|
--------
|
||||||
|
|
|
@ -170,7 +170,7 @@ Troubleshooting the integration tests
|
||||||
- Make sure postgres instance on localhost is running
|
- Make sure postgres instance on localhost is running
|
||||||
- Check your ``pg_hba.conf`` file to verify local connections are enabled
|
- Check your ``pg_hba.conf`` file to verify local connections are enabled
|
||||||
- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
|
- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
|
||||||
- Contact us on `gitter <https://gitter.im/dbcli/pgcli/>`_ or `file an issue <https://github.com/dbcli/pgcli/issues/new>`_.
|
- `File an issue <https://github.com/dbcli/pgcli/issues/new>`_.
|
||||||
|
|
||||||
Coding Style
|
Coding Style
|
||||||
------------
|
------------
|
||||||
|
|
28
README.rst
28
README.rst
|
@ -1,7 +1,7 @@
|
||||||
A REPL for Postgres
|
A REPL for Postgres
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
|Build Status| |CodeCov| |PyPI| |Landscape| |Gitter|
|
|Build Status| |CodeCov| |PyPI| |Landscape|
|
||||||
|
|
||||||
This is a postgres client that does auto-completion and syntax highlighting.
|
This is a postgres client that does auto-completion and syntax highlighting.
|
||||||
|
|
||||||
|
@ -72,21 +72,21 @@ For more details:
|
||||||
--single-connection Do not use a separate connection for completions.
|
--single-connection Do not use a separate connection for completions.
|
||||||
-v, --version Version of pgcli.
|
-v, --version Version of pgcli.
|
||||||
-d, --dbname TEXT database name to connect to.
|
-d, --dbname TEXT database name to connect to.
|
||||||
--pgclirc PATH Location of pgclirc file.
|
--pgclirc FILE Location of pgclirc file.
|
||||||
-D, --dsn TEXT Use DSN configured into the [alias_dsn] section of
|
-D, --dsn TEXT Use DSN configured into the [alias_dsn] section
|
||||||
pgclirc file.
|
|
||||||
--list-dsn list of DSN configured into the [alias_dsn] section
|
|
||||||
of pgclirc file.
|
of pgclirc file.
|
||||||
--row-limit INTEGER Set threshold for row limit prompt. Use 0 to disable
|
--list-dsn list of DSN configured into the [alias_dsn]
|
||||||
prompt.
|
section of pgclirc file.
|
||||||
|
--row-limit INTEGER Set threshold for row limit prompt. Use 0 to
|
||||||
|
disable prompt.
|
||||||
--less-chatty Skip intro on startup and goodbye on exit.
|
--less-chatty Skip intro on startup and goodbye on exit.
|
||||||
--prompt TEXT Prompt format (Default: "\u@\h:\d> ").
|
--prompt TEXT Prompt format (Default: "\u@\h:\d> ").
|
||||||
--prompt-dsn TEXT Prompt format for connections using DSN aliases
|
--prompt-dsn TEXT Prompt format for connections using DSN aliases
|
||||||
(Default: "\u@\h:\d> ").
|
(Default: "\u@\h:\d> ").
|
||||||
-l, --list list available databases, then exit.
|
-l, --list list available databases, then exit.
|
||||||
--auto-vertical-output Automatically switch to vertical output mode if the
|
--auto-vertical-output Automatically switch to vertical output mode if
|
||||||
result is wider than the terminal width.
|
the result is wider than the terminal width.
|
||||||
--warn / --no-warn Warn before running a destructive query.
|
--warn [all|moderate|off] Warn before running a destructive query.
|
||||||
--help Show this message and exit.
|
--help Show this message and exit.
|
||||||
|
|
||||||
``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``).
|
``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``).
|
||||||
|
@ -352,8 +352,8 @@ interface to Postgres database.
|
||||||
Thanks to all the beta testers and contributors for your time and patience. :)
|
Thanks to all the beta testers and contributors for your time and patience. :)
|
||||||
|
|
||||||
|
|
||||||
.. |Build Status| image:: https://api.travis-ci.org/dbcli/pgcli.svg?branch=master
|
.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg
|
||||||
:target: https://travis-ci.org/dbcli/pgcli
|
:target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli
|
||||||
|
|
||||||
.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
|
.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
|
||||||
:target: https://codecov.io/gh/dbcli/pgcli
|
:target: https://codecov.io/gh/dbcli/pgcli
|
||||||
|
@ -366,7 +366,3 @@ Thanks to all the beta testers and contributors for your time and patience. :)
|
||||||
.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
|
.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
|
||||||
:target: https://pypi.python.org/pypi/pgcli/
|
:target: https://pypi.python.org/pypi/pgcli/
|
||||||
:alt: Latest Version
|
:alt: Latest Version
|
||||||
|
|
||||||
.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg
|
|
||||||
:target: https://gitter.im/dbcli/pgcli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
|
|
||||||
:alt: Gitter Chat
|
|
||||||
|
|
79
Vagrantfile
vendored
79
Vagrantfile
vendored
|
@ -1,5 +1,7 @@
|
||||||
# -*- mode: ruby -*-
|
# -*- mode: ruby -*-
|
||||||
# vi: set ft=ruby :
|
# vi: set ft=ruby :
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
Vagrant.configure(2) do |config|
|
Vagrant.configure(2) do |config|
|
||||||
|
|
||||||
|
@ -9,20 +11,23 @@ Vagrant.configure(2) do |config|
|
||||||
pgcli_description = "Postgres CLI with autocompletion and syntax highlighting"
|
pgcli_description = "Postgres CLI with autocompletion and syntax highlighting"
|
||||||
|
|
||||||
config.vm.define "debian" do |debian|
|
config.vm.define "debian" do |debian|
|
||||||
debian.vm.box = "chef/debian-7.8"
|
debian.vm.box = "bento/debian-10.8"
|
||||||
debian.vm.provision "shell", inline: <<-SHELL
|
debian.vm.provision "shell", inline: <<-SHELL
|
||||||
echo "-> Building DEB on `lsb_release -s`"
|
echo "-> Building DEB on `lsb_release -d`"
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems
|
sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems
|
||||||
sudo easy_install pip
|
sudo apt install -y python3-pip
|
||||||
sudo pip install virtualenv virtualenv-tools
|
sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3
|
||||||
|
sudo apt-get install -y ruby-dev
|
||||||
|
sudo apt-get install -y git
|
||||||
|
sudo apt-get install -y rpm librpmbuild8
|
||||||
|
|
||||||
sudo gem install fpm
|
sudo gem install fpm
|
||||||
|
|
||||||
echo "-> Cleaning up old workspace"
|
echo "-> Cleaning up old workspace"
|
||||||
rm -rf build
|
sudo rm -rf build
|
||||||
mkdir -p build/usr/share
|
mkdir -p build/usr/share
|
||||||
virtualenv build/usr/share/pgcli
|
virtualenv build/usr/share/pgcli
|
||||||
build/usr/share/pgcli/bin/pip install -U pip distribute
|
|
||||||
build/usr/share/pgcli/bin/pip uninstall -y distribute
|
|
||||||
build/usr/share/pgcli/bin/pip install /pgcli
|
build/usr/share/pgcli/bin/pip install /pgcli
|
||||||
|
|
||||||
echo "-> Cleaning Virtualenv"
|
echo "-> Cleaning Virtualenv"
|
||||||
|
@ -45,24 +50,59 @@ Vagrant.configure(2) do |config|
|
||||||
--url https://github.com/dbcli/pgcli \
|
--url https://github.com/dbcli/pgcli \
|
||||||
--description "#{pgcli_description}" \
|
--description "#{pgcli_description}" \
|
||||||
--license 'BSD'
|
--license 'BSD'
|
||||||
|
|
||||||
SHELL
|
SHELL
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
# This is considerably more messy than the debian section. I had to go off-standard to update
|
||||||
|
# some packages to get this to work.
|
||||||
|
|
||||||
config.vm.define "centos" do |centos|
|
config.vm.define "centos" do |centos|
|
||||||
centos.vm.box = "chef/centos-7.0"
|
|
||||||
|
centos.vm.box = "bento/centos-7.9"
|
||||||
|
centos.vm.box_version = "202012.21.0"
|
||||||
centos.vm.provision "shell", inline: <<-SHELL
|
centos.vm.provision "shell", inline: <<-SHELL
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
echo "-> Building RPM on `lsb_release -s`"
|
echo "-> Building RPM on `hostnamectl | grep "Operating System"`"
|
||||||
sudo yum install -y rpm-build gcc ruby-devel postgresql-devel python-devel rubygems
|
export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin
|
||||||
sudo easy_install pip
|
echo "PATH -> " $PATH
|
||||||
sudo pip install virtualenv virtualenv-tools
|
|
||||||
sudo gem install fpm
|
#####
|
||||||
|
### get base updates
|
||||||
|
|
||||||
|
sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel
|
||||||
|
|
||||||
|
######
|
||||||
|
### install FPM, which we need to install to get an up-to-date version of ruby, which we need for git
|
||||||
|
|
||||||
|
echo "-> Get FPM installed"
|
||||||
|
# import the necessary GPG keys
|
||||||
|
gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
|
||||||
|
sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
|
||||||
|
# install RVM
|
||||||
|
sudo curl -sSL https://get.rvm.io | sudo bash -s stable
|
||||||
|
sudo usermod -aG rvm vagrant
|
||||||
|
sudo usermod -aG rvm root
|
||||||
|
sudo /usr/local/rvm/bin/rvm alias create default 2.6.3
|
||||||
|
source /etc/profile.d/rvm.sh
|
||||||
|
|
||||||
|
# install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git.
|
||||||
|
sudo yum install -y ruby-devel
|
||||||
|
sudo /usr/local/rvm/bin/rvm install 2.6.3
|
||||||
|
|
||||||
|
#
|
||||||
|
# yes,this gives an error about generating doc but we don't need the doc.
|
||||||
|
|
||||||
|
/usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm
|
||||||
|
|
||||||
|
######
|
||||||
|
|
||||||
|
sudo pip3 install virtualenv virtualenv-tools3
|
||||||
echo "-> Cleaning up old workspace"
|
echo "-> Cleaning up old workspace"
|
||||||
rm -rf build
|
rm -rf build
|
||||||
mkdir -p build/usr/share
|
mkdir -p build/usr/share
|
||||||
virtualenv build/usr/share/pgcli
|
virtualenv build/usr/share/pgcli
|
||||||
build/usr/share/pgcli/bin/pip install -U pip distribute
|
|
||||||
build/usr/share/pgcli/bin/pip uninstall -y distribute
|
|
||||||
build/usr/share/pgcli/bin/pip install /pgcli
|
build/usr/share/pgcli/bin/pip install /pgcli
|
||||||
|
|
||||||
echo "-> Cleaning Virtualenv"
|
echo "-> Cleaning Virtualenv"
|
||||||
|
@ -74,9 +114,9 @@ Vagrant.configure(2) do |config|
|
||||||
find build -iname '*.pyc' -delete
|
find build -iname '*.pyc' -delete
|
||||||
find build -iname '*.pyo' -delete
|
find build -iname '*.pyo' -delete
|
||||||
|
|
||||||
|
cd /home/vagrant
|
||||||
echo "-> Creating PgCLI RPM"
|
echo "-> Creating PgCLI RPM"
|
||||||
echo $PATH
|
/usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
|
||||||
sudo /usr/local/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
|
|
||||||
-a all \
|
-a all \
|
||||||
-d postgresql-devel \
|
-d postgresql-devel \
|
||||||
-d python-devel \
|
-d python-devel \
|
||||||
|
@ -86,8 +126,13 @@ Vagrant.configure(2) do |config|
|
||||||
--url https://github.com/dbcli/pgcli \
|
--url https://github.com/dbcli/pgcli \
|
||||||
--description "#{pgcli_description}" \
|
--description "#{pgcli_description}" \
|
||||||
--license 'BSD'
|
--license 'BSD'
|
||||||
|
|
||||||
|
|
||||||
SHELL
|
SHELL
|
||||||
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,37 @@
|
||||||
|
TBD
|
||||||
|
=====
|
||||||
|
|
||||||
|
Features:
|
||||||
|
---------
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
----------
|
||||||
|
|
||||||
|
3.2.0
|
||||||
|
=====
|
||||||
|
|
||||||
|
Release date: 2021/08/23
|
||||||
|
|
||||||
|
Features:
|
||||||
|
---------
|
||||||
|
|
||||||
|
* Consider `update` queries destructive and issue a warning. Change
|
||||||
|
`destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239)
|
||||||
|
* Skip initial comment in .pg_session even if it doesn't start with '#'
|
||||||
|
* Include functions from schemas in search_path. (`Amjith Ramanujam`_)
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
----------
|
||||||
|
|
||||||
|
* Fix issue where `syntax_style` config value would not have any effect. (#1212)
|
||||||
|
* Fix crash because of not found `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6
|
||||||
|
* Fix comments being lost in config when saving a named query. (#1240)
|
||||||
|
* Fix IPython magic for ipython-sql >= 0.4.0
|
||||||
|
* Fix pager not being used when output format is set to csv. (#1238)
|
||||||
|
* Add function literals random, generate_series, generate_subscripts
|
||||||
|
* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly
|
||||||
|
* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
|
||||||
|
|
||||||
3.1.0
|
3.1.0
|
||||||
=====
|
=====
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = "3.1.0"
|
__version__ = "3.2.0"
|
||||||
|
|
|
@ -3,10 +3,9 @@ import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .pgcompleter import PGCompleter
|
from .pgcompleter import PGCompleter
|
||||||
from .pgexecute import PGExecute
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionRefresher(object):
|
class CompletionRefresher:
|
||||||
|
|
||||||
refreshers = OrderedDict()
|
refreshers = OrderedDict()
|
||||||
|
|
||||||
|
@ -27,6 +26,10 @@ class CompletionRefresher(object):
|
||||||
has completed the refresh. The newly created completion
|
has completed the refresh. The newly created completion
|
||||||
object will be passed in as an argument to each callback.
|
object will be passed in as an argument to each callback.
|
||||||
"""
|
"""
|
||||||
|
if executor.is_virtual_database():
|
||||||
|
# do nothing
|
||||||
|
return [(None, None, None, "Auto-completion refresh can't be started.")]
|
||||||
|
|
||||||
if self.is_refreshing():
|
if self.is_refreshing():
|
||||||
self._restart_refresh.set()
|
self._restart_refresh.set()
|
||||||
return [(None, None, None, "Auto-completion refresh restarted.")]
|
return [(None, None, None, "Auto-completion refresh restarted.")]
|
||||||
|
@ -141,7 +144,7 @@ def refresh_casing(completer, executor):
|
||||||
with open(casing_file, "w") as f:
|
with open(casing_file, "w") as f:
|
||||||
f.write(casing_prefs)
|
f.write(casing_prefs)
|
||||||
if os.path.isfile(casing_file):
|
if os.path.isfile(casing_file):
|
||||||
with open(casing_file, "r") as f:
|
with open(casing_file) as f:
|
||||||
completer.extend_casing([line.strip() for line in f])
|
completer.extend_casing([line.strip() for line in f])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,8 @@ import shutil
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
from os.path import expanduser, exists, dirname
|
from os.path import expanduser, exists, dirname
|
||||||
|
import re
|
||||||
|
from typing import TextIO
|
||||||
from configobj import ConfigObj
|
from configobj import ConfigObj
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,11 +18,15 @@ def config_location():
|
||||||
|
|
||||||
|
|
||||||
def load_config(usr_cfg, def_cfg=None):
|
def load_config(usr_cfg, def_cfg=None):
|
||||||
|
# avoid config merges when possible. For writing, we need an umerged config instance.
|
||||||
|
# see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171
|
||||||
|
if def_cfg:
|
||||||
cfg = ConfigObj()
|
cfg = ConfigObj()
|
||||||
cfg.merge(ConfigObj(def_cfg, interpolation=False))
|
cfg.merge(ConfigObj(def_cfg, interpolation=False))
|
||||||
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
|
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
|
||||||
|
else:
|
||||||
|
cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")
|
||||||
cfg.filename = expanduser(usr_cfg)
|
cfg.filename = expanduser(usr_cfg)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,12 +50,16 @@ def upgrade_config(config, def_config):
|
||||||
cfg.write()
|
cfg.write()
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_filename(pgclirc_file=None):
|
||||||
|
return pgclirc_file or "%sconfig" % config_location()
|
||||||
|
|
||||||
|
|
||||||
def get_config(pgclirc_file=None):
|
def get_config(pgclirc_file=None):
|
||||||
from pgcli import __file__ as package_root
|
from pgcli import __file__ as package_root
|
||||||
|
|
||||||
package_root = os.path.dirname(package_root)
|
package_root = os.path.dirname(package_root)
|
||||||
|
|
||||||
pgclirc_file = pgclirc_file or "%sconfig" % config_location()
|
pgclirc_file = get_config_filename(pgclirc_file)
|
||||||
|
|
||||||
default_config = os.path.join(package_root, "pgclirc")
|
default_config = os.path.join(package_root, "pgclirc")
|
||||||
write_default_config(default_config, pgclirc_file)
|
write_default_config(default_config, pgclirc_file)
|
||||||
|
@ -62,3 +72,28 @@ def get_casing_file(config):
|
||||||
if casing_file == "default":
|
if casing_file == "default":
|
||||||
casing_file = config_location() + "casing"
|
casing_file = config_location() + "casing"
|
||||||
return casing_file
|
return casing_file
|
||||||
|
|
||||||
|
|
||||||
|
def skip_initial_comment(f_stream: TextIO) -> int:
|
||||||
|
"""
|
||||||
|
Initial comment in ~/.pg_service.conf is not always marked with '#'
|
||||||
|
which crashes the parser. This function takes a file object and
|
||||||
|
"rewinds" it to the beginning of the first section,
|
||||||
|
from where on it can be parsed safely
|
||||||
|
|
||||||
|
:return: number of skipped lines
|
||||||
|
"""
|
||||||
|
section_regex = r"\s*\["
|
||||||
|
pos = f_stream.tell()
|
||||||
|
lines_skipped = 0
|
||||||
|
while True:
|
||||||
|
line = f_stream.readline()
|
||||||
|
if line == "":
|
||||||
|
break
|
||||||
|
if re.match(section_regex, line) is not None:
|
||||||
|
f_stream.seek(pos)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pos += len(line)
|
||||||
|
lines_skipped += 1
|
||||||
|
return lines_skipped
|
||||||
|
|
|
@ -25,7 +25,11 @@ def pgcli_line_magic(line):
|
||||||
if hasattr(sql.connection.Connection, "get"):
|
if hasattr(sql.connection.Connection, "get"):
|
||||||
conn = sql.connection.Connection.get(parsed["connection"])
|
conn = sql.connection.Connection.get(parsed["connection"])
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
conn = sql.connection.Connection.set(parsed["connection"])
|
conn = sql.connection.Connection.set(parsed["connection"])
|
||||||
|
# a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql
|
||||||
|
except TypeError:
|
||||||
|
conn = sql.connection.Connection.set(parsed["connection"], False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# A corresponding pgcli object already exists
|
# A corresponding pgcli object already exists
|
||||||
|
@ -43,7 +47,7 @@ def pgcli_line_magic(line):
|
||||||
conn._pgcli = pgcli
|
conn._pgcli = pgcli
|
||||||
|
|
||||||
# For convenience, print the connection alias
|
# For convenience, print the connection alias
|
||||||
print("Connected: {}".format(conn.name))
|
print(f"Connected: {conn.name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pgcli.run_cli()
|
pgcli.run_cli()
|
||||||
|
|
|
@ -2,8 +2,9 @@ import platform
|
||||||
import warnings
|
import warnings
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
|
|
||||||
from configobj import ConfigObj
|
from configobj import ConfigObj, ParseError
|
||||||
from pgspecial.namedqueries import NamedQueries
|
from pgspecial.namedqueries import NamedQueries
|
||||||
|
from .config import skip_initial_comment
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
|
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
|
||||||
|
|
||||||
|
@ -20,12 +21,12 @@ import datetime as dt
|
||||||
import itertools
|
import itertools
|
||||||
import platform
|
import platform
|
||||||
from time import time, sleep
|
from time import time, sleep
|
||||||
from codecs import open
|
|
||||||
|
|
||||||
keyring = None # keyring will be loaded later
|
keyring = None # keyring will be loaded later
|
||||||
|
|
||||||
from cli_helpers.tabular_output import TabularOutputFormatter
|
from cli_helpers.tabular_output import TabularOutputFormatter
|
||||||
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
|
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
|
||||||
|
from cli_helpers.utils import strip_ansi
|
||||||
import click
|
import click
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -62,6 +63,7 @@ from .config import (
|
||||||
config_location,
|
config_location,
|
||||||
ensure_dir_exists,
|
ensure_dir_exists,
|
||||||
get_config,
|
get_config,
|
||||||
|
get_config_filename,
|
||||||
)
|
)
|
||||||
from .key_bindings import pgcli_bindings
|
from .key_bindings import pgcli_bindings
|
||||||
from .packages.prompt_utils import confirm_destructive_query
|
from .packages.prompt_utils import confirm_destructive_query
|
||||||
|
@ -122,7 +124,7 @@ class PgCliQuitError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PGCli(object):
|
class PGCli:
|
||||||
default_prompt = "\\u@\\h:\\d> "
|
default_prompt = "\\u@\\h:\\d> "
|
||||||
max_len_prompt = 30
|
max_len_prompt = 30
|
||||||
|
|
||||||
|
@ -175,7 +177,11 @@ class PGCli(object):
|
||||||
# Load config.
|
# Load config.
|
||||||
c = self.config = get_config(pgclirc_file)
|
c = self.config = get_config(pgclirc_file)
|
||||||
|
|
||||||
NamedQueries.instance = NamedQueries.from_config(self.config)
|
# at this point, config should be written to pgclirc_file if it did not exist. Read it.
|
||||||
|
self.config_writer = load_config(get_config_filename(pgclirc_file))
|
||||||
|
|
||||||
|
# make sure to use self.config_writer, not self.config
|
||||||
|
NamedQueries.instance = NamedQueries.from_config(self.config_writer)
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.initialize_logging()
|
self.initialize_logging()
|
||||||
|
@ -201,8 +207,11 @@ class PGCli(object):
|
||||||
self.syntax_style = c["main"]["syntax_style"]
|
self.syntax_style = c["main"]["syntax_style"]
|
||||||
self.cli_style = c["colors"]
|
self.cli_style = c["colors"]
|
||||||
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
|
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
|
||||||
c_dest_warning = c["main"].as_bool("destructive_warning")
|
self.destructive_warning = warn or c["main"]["destructive_warning"]
|
||||||
self.destructive_warning = c_dest_warning if warn is None else warn
|
# also handle boolean format of destructive warning
|
||||||
|
self.destructive_warning = {"true": "all", "false": "off"}.get(
|
||||||
|
self.destructive_warning.lower(), self.destructive_warning
|
||||||
|
)
|
||||||
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
|
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
|
||||||
self.null_string = c["main"].get("null_string", "<null>")
|
self.null_string = c["main"].get("null_string", "<null>")
|
||||||
self.prompt_format = (
|
self.prompt_format = (
|
||||||
|
@ -325,11 +334,11 @@ class PGCli(object):
|
||||||
if pattern not in TabularOutputFormatter().supported_formats:
|
if pattern not in TabularOutputFormatter().supported_formats:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
self.table_format = pattern
|
self.table_format = pattern
|
||||||
yield (None, None, None, "Changed table format to {}".format(pattern))
|
yield (None, None, None, f"Changed table format to {pattern}")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
msg = "Table format {} not recognized. Allowed formats:".format(pattern)
|
msg = f"Table format {pattern} not recognized. Allowed formats:"
|
||||||
for table_type in TabularOutputFormatter().supported_formats:
|
for table_type in TabularOutputFormatter().supported_formats:
|
||||||
msg += "\n\t{}".format(table_type)
|
msg += f"\n\t{table_type}"
|
||||||
msg += "\nCurrently set to: %s" % self.table_format
|
msg += "\nCurrently set to: %s" % self.table_format
|
||||||
yield (None, None, None, msg)
|
yield (None, None, None, msg)
|
||||||
|
|
||||||
|
@ -386,10 +395,13 @@ class PGCli(object):
|
||||||
try:
|
try:
|
||||||
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
|
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
|
||||||
query = f.read()
|
query = f.read()
|
||||||
except IOError as e:
|
except OSError as e:
|
||||||
return [(None, None, None, str(e), "", False, True)]
|
return [(None, None, None, str(e), "", False, True)]
|
||||||
|
|
||||||
if self.destructive_warning and confirm_destructive_query(query) is False:
|
if (
|
||||||
|
self.destructive_warning != "off"
|
||||||
|
and confirm_destructive_query(query, self.destructive_warning) is False
|
||||||
|
):
|
||||||
message = "Wise choice. Command execution stopped."
|
message = "Wise choice. Command execution stopped."
|
||||||
return [(None, None, None, message)]
|
return [(None, None, None, message)]
|
||||||
|
|
||||||
|
@ -407,7 +419,7 @@ class PGCli(object):
|
||||||
if not os.path.isfile(filename):
|
if not os.path.isfile(filename):
|
||||||
try:
|
try:
|
||||||
open(filename, "w").close()
|
open(filename, "w").close()
|
||||||
except IOError as e:
|
except OSError as e:
|
||||||
self.output_file = None
|
self.output_file = None
|
||||||
message = str(e) + "\nFile output disabled"
|
message = str(e) + "\nFile output disabled"
|
||||||
return [(None, None, None, message, "", False, True)]
|
return [(None, None, None, message, "", False, True)]
|
||||||
|
@ -479,7 +491,7 @@ class PGCli(object):
|
||||||
service_config, file = parse_service_info(service)
|
service_config, file = parse_service_info(service)
|
||||||
if service_config is None:
|
if service_config is None:
|
||||||
click.secho(
|
click.secho(
|
||||||
"service '%s' was not found in %s" % (service, file), err=True, fg="red"
|
f"service '{service}' was not found in {file}", err=True, fg="red"
|
||||||
)
|
)
|
||||||
exit(1)
|
exit(1)
|
||||||
self.connect(
|
self.connect(
|
||||||
|
@ -515,7 +527,7 @@ class PGCli(object):
|
||||||
passwd = os.environ.get("PGPASSWORD", "")
|
passwd = os.environ.get("PGPASSWORD", "")
|
||||||
|
|
||||||
# Find password from store
|
# Find password from store
|
||||||
key = "%s@%s" % (user, host)
|
key = f"{user}@{host}"
|
||||||
keyring_error_message = dedent(
|
keyring_error_message = dedent(
|
||||||
"""\
|
"""\
|
||||||
{}
|
{}
|
||||||
|
@ -644,8 +656,10 @@ class PGCli(object):
|
||||||
query = MetaQuery(query=text, successful=False)
|
query = MetaQuery(query=text, successful=False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.destructive_warning:
|
if self.destructive_warning != "off":
|
||||||
destroy = confirm = confirm_destructive_query(text)
|
destroy = confirm = confirm_destructive_query(
|
||||||
|
text, self.destructive_warning
|
||||||
|
)
|
||||||
if destroy is False:
|
if destroy is False:
|
||||||
click.secho("Wise choice!")
|
click.secho("Wise choice!")
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
|
@ -677,7 +691,7 @@ class PGCli(object):
|
||||||
click.echo(text, file=f)
|
click.echo(text, file=f)
|
||||||
click.echo("\n".join(output), file=f)
|
click.echo("\n".join(output), file=f)
|
||||||
click.echo("", file=f) # extra newline
|
click.echo("", file=f) # extra newline
|
||||||
except IOError as e:
|
except OSError as e:
|
||||||
click.secho(str(e), err=True, fg="red")
|
click.secho(str(e), err=True, fg="red")
|
||||||
else:
|
else:
|
||||||
if output:
|
if output:
|
||||||
|
@ -729,7 +743,6 @@ class PGCli(object):
|
||||||
if not self.less_chatty:
|
if not self.less_chatty:
|
||||||
print("Server: PostgreSQL", self.pgexecute.server_version)
|
print("Server: PostgreSQL", self.pgexecute.server_version)
|
||||||
print("Version:", __version__)
|
print("Version:", __version__)
|
||||||
print("Chat: https://gitter.im/dbcli/pgcli")
|
|
||||||
print("Home: http://pgcli.com")
|
print("Home: http://pgcli.com")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -753,11 +766,7 @@ class PGCli(object):
|
||||||
while self.watch_command:
|
while self.watch_command:
|
||||||
try:
|
try:
|
||||||
query = self.execute_command(self.watch_command)
|
query = self.execute_command(self.watch_command)
|
||||||
click.echo(
|
click.echo(f"Waiting for {timing} seconds before repeating")
|
||||||
"Waiting for {0} seconds before repeating".format(
|
|
||||||
timing
|
|
||||||
)
|
|
||||||
)
|
|
||||||
sleep(timing)
|
sleep(timing)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.watch_command = None
|
self.watch_command = None
|
||||||
|
@ -979,16 +988,13 @@ class PGCli(object):
|
||||||
callback = functools.partial(
|
callback = functools.partial(
|
||||||
self._on_completions_refreshed, persist_priorities=persist_priorities
|
self._on_completions_refreshed, persist_priorities=persist_priorities
|
||||||
)
|
)
|
||||||
self.completion_refresher.refresh(
|
return self.completion_refresher.refresh(
|
||||||
self.pgexecute,
|
self.pgexecute,
|
||||||
self.pgspecial,
|
self.pgspecial,
|
||||||
callback,
|
callback,
|
||||||
history=history,
|
history=history,
|
||||||
settings=self.settings,
|
settings=self.settings,
|
||||||
)
|
)
|
||||||
return [
|
|
||||||
(None, None, None, "Auto-completion refresh started in the background.")
|
|
||||||
]
|
|
||||||
|
|
||||||
def _on_completions_refreshed(self, new_completer, persist_priorities):
|
def _on_completions_refreshed(self, new_completer, persist_priorities):
|
||||||
self._swap_completer_objects(new_completer, persist_priorities)
|
self._swap_completer_objects(new_completer, persist_priorities)
|
||||||
|
@ -1049,7 +1055,7 @@ class PGCli(object):
|
||||||
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
|
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
|
||||||
)
|
)
|
||||||
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
|
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
|
||||||
string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
|
string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
|
||||||
string = string.replace("\\n", "\n")
|
string = string.replace("\\n", "\n")
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
@ -1075,9 +1081,10 @@ class PGCli(object):
|
||||||
def echo_via_pager(self, text, color=None):
|
def echo_via_pager(self, text, color=None):
|
||||||
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
|
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
|
||||||
click.echo(text, color=color)
|
click.echo(text, color=color)
|
||||||
elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
|
elif (
|
||||||
click.echo_via_pager(text, color)
|
self.pgspecial.pager_config == PAGER_LONG_OUTPUT
|
||||||
elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT:
|
and self.table_format != "csv"
|
||||||
|
):
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
|
|
||||||
# The last 4 lines are reserved for the pgcli menu and padding
|
# The last 4 lines are reserved for the pgcli menu and padding
|
||||||
|
@ -1192,7 +1199,10 @@ class PGCli(object):
|
||||||
help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
|
help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--warn/--no-warn", default=None, help="Warn before running a destructive query."
|
"--warn",
|
||||||
|
default=None,
|
||||||
|
type=click.Choice(["all", "moderate", "off"]),
|
||||||
|
help="Warn before running a destructive query.",
|
||||||
)
|
)
|
||||||
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
|
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
|
||||||
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
|
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
|
||||||
|
@ -1384,7 +1394,7 @@ def is_mutating(status):
|
||||||
if not status:
|
if not status:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
mutating = set(["insert", "update", "delete"])
|
mutating = {"insert", "update", "delete"}
|
||||||
return status.split(None, 1)[0].lower() in mutating
|
return status.split(None, 1)[0].lower() in mutating
|
||||||
|
|
||||||
|
|
||||||
|
@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings):
|
||||||
formatted = iter(formatted.splitlines())
|
formatted = iter(formatted.splitlines())
|
||||||
first_line = next(formatted)
|
first_line = next(formatted)
|
||||||
formatted = itertools.chain([first_line], formatted)
|
formatted = itertools.chain([first_line], formatted)
|
||||||
if not expanded and max_width and len(first_line) > max_width and headers:
|
if (
|
||||||
|
not expanded
|
||||||
|
and max_width
|
||||||
|
and len(strip_ansi(first_line)) > max_width
|
||||||
|
and headers
|
||||||
|
):
|
||||||
formatted = formatter.format_output(
|
formatted = formatter.format_output(
|
||||||
cur, headers, format_name="vertical", column_types=None, **output_kwargs
|
cur, headers, format_name="vertical", column_types=None, **output_kwargs
|
||||||
)
|
)
|
||||||
|
@ -1502,10 +1517,16 @@ def parse_service_info(service):
|
||||||
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
|
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
|
||||||
else:
|
else:
|
||||||
service_file = expanduser("~/.pg_service.conf")
|
service_file = expanduser("~/.pg_service.conf")
|
||||||
if not service:
|
if not service or not os.path.exists(service_file):
|
||||||
# nothing to do
|
# nothing to do
|
||||||
return None, service_file
|
return None, service_file
|
||||||
service_file_config = ConfigObj(service_file)
|
with open(service_file, newline="") as f:
|
||||||
|
skipped_lines = skip_initial_comment(f)
|
||||||
|
try:
|
||||||
|
service_file_config = ConfigObj(f)
|
||||||
|
except ParseError as err:
|
||||||
|
err.line_number += skipped_lines
|
||||||
|
raise err
|
||||||
if service not in service_file_config:
|
if service not in service_file_config:
|
||||||
return None, service_file
|
return None, service_file
|
||||||
service_conf = service_file_config.get(service)
|
service_conf = service_file_config.get(service)
|
||||||
|
|
|
@ -1,22 +1,34 @@
|
||||||
import sqlparse
|
import sqlparse
|
||||||
|
|
||||||
|
|
||||||
def query_starts_with(query, prefixes):
|
def query_starts_with(formatted_sql, prefixes):
|
||||||
"""Check if the query starts with any item from *prefixes*."""
|
"""Check if the query starts with any item from *prefixes*."""
|
||||||
prefixes = [prefix.lower() for prefix in prefixes]
|
prefixes = [prefix.lower() for prefix in prefixes]
|
||||||
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
|
|
||||||
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
||||||
|
|
||||||
|
|
||||||
def queries_start_with(queries, prefixes):
|
def query_is_unconditional_update(formatted_sql):
|
||||||
"""Check if any queries start with any item from *prefixes*."""
|
"""Check if the query starts with UPDATE and contains no WHERE."""
|
||||||
for query in sqlparse.split(queries):
|
tokens = formatted_sql.split()
|
||||||
if query and query_starts_with(query, prefixes) is True:
|
return bool(tokens) and tokens[0] == "update" and "where" not in tokens
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_destructive(queries):
|
def query_is_simple_update(formatted_sql):
|
||||||
|
"""Check if the query starts with UPDATE."""
|
||||||
|
tokens = formatted_sql.split()
|
||||||
|
return bool(tokens) and tokens[0] == "update"
|
||||||
|
|
||||||
|
|
||||||
|
def is_destructive(queries, warning_level="all"):
|
||||||
"""Returns if any of the queries in *queries* is destructive."""
|
"""Returns if any of the queries in *queries* is destructive."""
|
||||||
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
|
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
|
||||||
return queries_start_with(queries, keywords)
|
for query in sqlparse.split(queries):
|
||||||
|
if query:
|
||||||
|
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
|
||||||
|
if query_starts_with(formatted_sql, keywords):
|
||||||
|
return True
|
||||||
|
if query_is_unconditional_update(formatted_sql):
|
||||||
|
return True
|
||||||
|
if warning_level == "all" and query_is_simple_update(formatted_sql):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
|
@ -50,7 +50,7 @@ def parse_defaults(defaults_string):
|
||||||
yield current
|
yield current
|
||||||
|
|
||||||
|
|
||||||
class FunctionMetadata(object):
|
class FunctionMetadata:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
schema_name,
|
schema_name,
|
||||||
|
|
|
@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
||||||
for item in parsed.tokens:
|
for item in parsed.tokens:
|
||||||
if tbl_prefix_seen:
|
if tbl_prefix_seen:
|
||||||
if is_subselect(item):
|
if is_subselect(item):
|
||||||
for x in extract_from_part(item, stop_at_punctuation):
|
yield from extract_from_part(item, stop_at_punctuation)
|
||||||
yield x
|
|
||||||
elif stop_at_punctuation and item.ttype is Punctuation:
|
elif stop_at_punctuation and item.ttype is Punctuation:
|
||||||
return
|
return
|
||||||
# An incomplete nested select won't be recognized correctly as a
|
# An incomplete nested select won't be recognized correctly as a
|
||||||
|
|
|
@ -392,6 +392,7 @@
|
||||||
"QUOTE_NULLABLE",
|
"QUOTE_NULLABLE",
|
||||||
"RADIANS",
|
"RADIANS",
|
||||||
"RADIUS",
|
"RADIUS",
|
||||||
|
"RANDOM",
|
||||||
"RANK",
|
"RANK",
|
||||||
"REGEXP_MATCH",
|
"REGEXP_MATCH",
|
||||||
"REGEXP_MATCHES",
|
"REGEXP_MATCHES",
|
||||||
|
|
|
@ -16,10 +16,10 @@ def _compile_regex(keyword):
|
||||||
|
|
||||||
|
|
||||||
keywords = get_literals("keywords")
|
keywords = get_literals("keywords")
|
||||||
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
|
keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}
|
||||||
|
|
||||||
|
|
||||||
class PrevalenceCounter(object):
|
class PrevalenceCounter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.keyword_counts = defaultdict(int)
|
self.keyword_counts = defaultdict(int)
|
||||||
self.name_counts = defaultdict(int)
|
self.name_counts = defaultdict(int)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import click
|
||||||
from .parseutils import is_destructive
|
from .parseutils import is_destructive
|
||||||
|
|
||||||
|
|
||||||
def confirm_destructive_query(queries):
|
def confirm_destructive_query(queries, warning_level):
|
||||||
"""Check if the query is destructive and prompts the user to confirm.
|
"""Check if the query is destructive and prompts the user to confirm.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -15,7 +15,7 @@ def confirm_destructive_query(queries):
|
||||||
prompt_text = (
|
prompt_text = (
|
||||||
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
|
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
|
||||||
)
|
)
|
||||||
if is_destructive(queries) and sys.stdin.isatty():
|
if is_destructive(queries, warning_level) and sys.stdin.isatty():
|
||||||
return prompt(prompt_text, type=bool)
|
return prompt(prompt_text, type=bool)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"])
|
||||||
Path = namedtuple("Path", [])
|
Path = namedtuple("Path", [])
|
||||||
|
|
||||||
|
|
||||||
class SqlStatement(object):
|
class SqlStatement:
|
||||||
def __init__(self, full_text, text_before_cursor):
|
def __init__(self, full_text, text_before_cursor):
|
||||||
self.identifier = None
|
self.identifier = None
|
||||||
self.word_before_cursor = word_before_cursor = last_word(
|
self.word_before_cursor = word_before_cursor = last_word(
|
||||||
|
|
|
@ -23,9 +23,13 @@ multi_line = False
|
||||||
multi_line_mode = psql
|
multi_line_mode = psql
|
||||||
|
|
||||||
# Destructive warning mode will alert you before executing a sql statement
|
# Destructive warning mode will alert you before executing a sql statement
|
||||||
# that may cause harm to the database such as "drop table", "drop database"
|
# that may cause harm to the database such as "drop table", "drop database",
|
||||||
# or "shutdown".
|
# "shutdown", "delete", or "update".
|
||||||
destructive_warning = True
|
# Possible values:
|
||||||
|
# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE
|
||||||
|
# "moderate" - skip warning on UPDATE statements, except for unconditional updates
|
||||||
|
# "off" - skip all warnings
|
||||||
|
destructive_warning = all
|
||||||
|
|
||||||
# Enables expand mode, which is similar to `\x` in psql.
|
# Enables expand mode, which is similar to `\x` in psql.
|
||||||
expand = False
|
expand = False
|
||||||
|
@ -170,9 +174,12 @@ arg-toolbar = 'noinherit bold'
|
||||||
arg-toolbar.text = 'nobold'
|
arg-toolbar.text = 'nobold'
|
||||||
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
|
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
|
||||||
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
|
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
|
||||||
literal.string = '#ba2121'
|
# These three values can be used to further refine the syntax highlighting.
|
||||||
literal.number = '#666666'
|
# They are commented out by default, since they have priority over the theme set
|
||||||
keyword = 'bold #008000'
|
# with the `syntax_style` setting and overriding its behavior can be confusing.
|
||||||
|
# literal.string = '#ba2121'
|
||||||
|
# literal.number = '#666666'
|
||||||
|
# keyword = 'bold #008000'
|
||||||
|
|
||||||
# style classes for colored table output
|
# style classes for colored table output
|
||||||
output.header = "#00ff5f bold"
|
output.header = "#00ff5f bold"
|
||||||
|
|
|
@ -83,7 +83,7 @@ class PGCompleter(Completer):
|
||||||
reserved_words = set(get_literals("reserved"))
|
reserved_words = set(get_literals("reserved"))
|
||||||
|
|
||||||
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
|
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
|
||||||
super(PGCompleter, self).__init__()
|
super().__init__()
|
||||||
self.smart_completion = smart_completion
|
self.smart_completion = smart_completion
|
||||||
self.pgspecial = pgspecial
|
self.pgspecial = pgspecial
|
||||||
self.prioritizer = PrevalenceCounter()
|
self.prioritizer = PrevalenceCounter()
|
||||||
|
@ -140,7 +140,7 @@ class PGCompleter(Completer):
|
||||||
return "'{}'".format(self.unescape_name(name))
|
return "'{}'".format(self.unescape_name(name))
|
||||||
|
|
||||||
def unescape_name(self, name):
|
def unescape_name(self, name):
|
||||||
""" Unquote a string."""
|
"""Unquote a string."""
|
||||||
if name and name[0] == '"' and name[-1] == '"':
|
if name and name[0] == '"' and name[-1] == '"':
|
||||||
name = name[1:-1]
|
name = name[1:-1]
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ class PGCompleter(Completer):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# casing should be a dict {lowercasename:PreferredCasingName}
|
# casing should be a dict {lowercasename:PreferredCasingName}
|
||||||
self.casing = dict((word.lower(), word) for word in words)
|
self.casing = {word.lower(): word for word in words}
|
||||||
|
|
||||||
def extend_relations(self, data, kind):
|
def extend_relations(self, data, kind):
|
||||||
"""extend metadata for tables or views.
|
"""extend metadata for tables or views.
|
||||||
|
@ -279,8 +279,8 @@ class PGCompleter(Completer):
|
||||||
fk = ForeignKey(
|
fk = ForeignKey(
|
||||||
parentschema, parenttable, parcol, childschema, childtable, childcol
|
parentschema, parenttable, parcol, childschema, childtable, childcol
|
||||||
)
|
)
|
||||||
childcolmeta.foreignkeys.append((fk))
|
childcolmeta.foreignkeys.append(fk)
|
||||||
parcolmeta.foreignkeys.append((fk))
|
parcolmeta.foreignkeys.append(fk)
|
||||||
|
|
||||||
def extend_datatypes(self, type_data):
|
def extend_datatypes(self, type_data):
|
||||||
|
|
||||||
|
@ -424,7 +424,7 @@ class PGCompleter(Completer):
|
||||||
# the same priority as unquoted names.
|
# the same priority as unquoted names.
|
||||||
lexical_priority = (
|
lexical_priority = (
|
||||||
tuple(
|
tuple(
|
||||||
0 if c in (" _") else -ord(c)
|
0 if c in " _" else -ord(c)
|
||||||
for c in self.unescape_name(item.lower())
|
for c in self.unescape_name(item.lower())
|
||||||
)
|
)
|
||||||
+ (1,)
|
+ (1,)
|
||||||
|
@ -517,9 +517,9 @@ class PGCompleter(Completer):
|
||||||
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
|
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
|
||||||
# suggest only columns that appear in the last table and one more
|
# suggest only columns that appear in the last table and one more
|
||||||
ltbl = tables[-1].ref
|
ltbl = tables[-1].ref
|
||||||
other_tbl_cols = set(
|
other_tbl_cols = {
|
||||||
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
|
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
|
||||||
)
|
}
|
||||||
scoped_cols = {
|
scoped_cols = {
|
||||||
t: [col for col in cols if col.name in other_tbl_cols]
|
t: [col for col in cols if col.name in other_tbl_cols]
|
||||||
for t, cols in scoped_cols.items()
|
for t, cols in scoped_cols.items()
|
||||||
|
@ -574,7 +574,7 @@ class PGCompleter(Completer):
|
||||||
tbls - TableReference iterable of tables already in query
|
tbls - TableReference iterable of tables already in query
|
||||||
"""
|
"""
|
||||||
tbl = self.case(tbl)
|
tbl = self.case(tbl)
|
||||||
tbls = set(normalize_ref(t.ref) for t in tbls)
|
tbls = {normalize_ref(t.ref) for t in tbls}
|
||||||
if self.generate_aliases:
|
if self.generate_aliases:
|
||||||
tbl = generate_alias(self.unescape_name(tbl))
|
tbl = generate_alias(self.unescape_name(tbl))
|
||||||
if normalize_ref(tbl) not in tbls:
|
if normalize_ref(tbl) not in tbls:
|
||||||
|
@ -589,10 +589,10 @@ class PGCompleter(Completer):
|
||||||
tbls = suggestion.table_refs
|
tbls = suggestion.table_refs
|
||||||
cols = self.populate_scoped_cols(tbls)
|
cols = self.populate_scoped_cols(tbls)
|
||||||
# Set up some data structures for efficient access
|
# Set up some data structures for efficient access
|
||||||
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
|
qualified = {normalize_ref(t.ref): t.schema for t in tbls}
|
||||||
ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
|
ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
|
||||||
refs = set(normalize_ref(t.ref) for t in tbls)
|
refs = {normalize_ref(t.ref) for t in tbls}
|
||||||
other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
|
other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
|
||||||
joins = []
|
joins = []
|
||||||
# Iterate over FKs in existing tables to find potential joins
|
# Iterate over FKs in existing tables to find potential joins
|
||||||
fks = (
|
fks = (
|
||||||
|
@ -667,7 +667,7 @@ class PGCompleter(Completer):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
# Tables that are closer to the cursor get higher prio
|
# Tables that are closer to the cursor get higher prio
|
||||||
ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
|
ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
|
||||||
# Map (schema, table, col) to tables
|
# Map (schema, table, col) to tables
|
||||||
coldict = list_dict(
|
coldict = list_dict(
|
||||||
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
|
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
|
||||||
|
@ -703,7 +703,11 @@ class PGCompleter(Completer):
|
||||||
not f.is_aggregate
|
not f.is_aggregate
|
||||||
and not f.is_window
|
and not f.is_window
|
||||||
and not f.is_extension
|
and not f.is_extension
|
||||||
and (f.is_public or f.schema_name == suggestion.schema)
|
and (
|
||||||
|
f.is_public
|
||||||
|
or f.schema_name in self.search_path
|
||||||
|
or f.schema_name == suggestion.schema
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -721,9 +725,7 @@ class PGCompleter(Completer):
|
||||||
# Function overloading means we way have multiple functions of the same
|
# Function overloading means we way have multiple functions of the same
|
||||||
# name at this point, so keep unique names only
|
# name at this point, so keep unique names only
|
||||||
all_functions = self.populate_functions(suggestion.schema, filt)
|
all_functions = self.populate_functions(suggestion.schema, filt)
|
||||||
funcs = set(
|
funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
|
||||||
self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
|
|
||||||
)
|
|
||||||
|
|
||||||
matches = self.find_matches(word_before_cursor, funcs, meta="function")
|
matches = self.find_matches(word_before_cursor, funcs, meta="function")
|
||||||
|
|
||||||
|
@ -953,7 +955,7 @@ class PGCompleter(Completer):
|
||||||
:return: {TableReference:{colname:ColumnMetaData}}
|
:return: {TableReference:{colname:ColumnMetaData}}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
|
ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
|
||||||
columns = OrderedDict()
|
columns = OrderedDict()
|
||||||
meta = self.dbmetadata
|
meta = self.dbmetadata
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
import traceback
|
|
||||||
import logging
|
import logging
|
||||||
|
import select
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import pgspecial as special
|
||||||
import psycopg2
|
import psycopg2
|
||||||
import psycopg2.extras
|
|
||||||
import psycopg2.errorcodes
|
import psycopg2.errorcodes
|
||||||
import psycopg2.extensions as ext
|
import psycopg2.extensions as ext
|
||||||
|
import psycopg2.extras
|
||||||
import sqlparse
|
import sqlparse
|
||||||
import pgspecial as special
|
|
||||||
import select
|
|
||||||
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
|
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
|
||||||
|
|
||||||
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
|
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
|
||||||
|
|
||||||
# TODO: Get default timeout from pgclirc?
|
# TODO: Get default timeout from pgclirc?
|
||||||
_WAIT_SELECT_TIMEOUT = 1
|
_WAIT_SELECT_TIMEOUT = 1
|
||||||
|
_wait_callback_is_set = False
|
||||||
|
|
||||||
|
|
||||||
def _wait_select(conn):
|
def _wait_select(conn):
|
||||||
|
@ -34,6 +37,7 @@ def _wait_select(conn):
|
||||||
copy-pasted from psycopg2.extras.wait_select
|
copy-pasted from psycopg2.extras.wait_select
|
||||||
the default implementation doesn't define a timeout in the select calls
|
the default implementation doesn't define a timeout in the select calls
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
try:
|
||||||
state = conn.poll()
|
state = conn.poll()
|
||||||
|
@ -49,16 +53,25 @@ def _wait_select(conn):
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
# the loop will be broken by a server error
|
# the loop will be broken by a server error
|
||||||
continue
|
continue
|
||||||
except select.error as e:
|
except OSError as e:
|
||||||
errno = e.args[0]
|
errno = e.args[0]
|
||||||
if errno != 4:
|
if errno != 4:
|
||||||
raise
|
raise
|
||||||
|
except psycopg2.OperationalError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
|
def _set_wait_callback(is_virtual_database):
|
||||||
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
|
global _wait_callback_is_set
|
||||||
# See also https://github.com/psycopg/psycopg2/issues/468
|
if _wait_callback_is_set:
|
||||||
ext.set_wait_callback(_wait_select)
|
return
|
||||||
|
_wait_callback_is_set = True
|
||||||
|
if is_virtual_database:
|
||||||
|
return
|
||||||
|
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
|
||||||
|
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
|
||||||
|
# See also https://github.com/psycopg/psycopg2/issues/468
|
||||||
|
ext.set_wait_callback(_wait_select)
|
||||||
|
|
||||||
|
|
||||||
def register_date_typecasters(connection):
|
def register_date_typecasters(connection):
|
||||||
|
@ -72,6 +85,8 @@ def register_date_typecasters(connection):
|
||||||
|
|
||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
cursor.execute("SELECT NULL::date")
|
cursor.execute("SELECT NULL::date")
|
||||||
|
if cursor.description is None:
|
||||||
|
return
|
||||||
date_oid = cursor.description[0][1]
|
date_oid = cursor.description[0][1]
|
||||||
cursor.execute("SELECT NULL::timestamp")
|
cursor.execute("SELECT NULL::timestamp")
|
||||||
timestamp_oid = cursor.description[0][1]
|
timestamp_oid = cursor.description[0][1]
|
||||||
|
@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
|
||||||
try:
|
try:
|
||||||
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
|
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
|
||||||
available.add(name)
|
available.add(name)
|
||||||
except psycopg2.ProgrammingError:
|
except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return available
|
return available
|
||||||
|
@ -127,7 +142,39 @@ def register_hstore_typecaster(conn):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PGExecute(object):
|
class ProtocolSafeCursor(psycopg2.extensions.cursor):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.protocol_error = False
|
||||||
|
self.protocol_message = ""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.protocol_error:
|
||||||
|
raise StopIteration
|
||||||
|
return super().__iter__()
|
||||||
|
|
||||||
|
def fetchall(self):
|
||||||
|
if self.protocol_error:
|
||||||
|
return [(self.protocol_message,)]
|
||||||
|
return super().fetchall()
|
||||||
|
|
||||||
|
def fetchone(self):
|
||||||
|
if self.protocol_error:
|
||||||
|
return (self.protocol_message,)
|
||||||
|
return super().fetchone()
|
||||||
|
|
||||||
|
def execute(self, sql, args=None):
|
||||||
|
try:
|
||||||
|
psycopg2.extensions.cursor.execute(self, sql, args)
|
||||||
|
self.protocol_error = False
|
||||||
|
self.protocol_message = ""
|
||||||
|
except psycopg2.errors.ProtocolViolation as ex:
|
||||||
|
self.protocol_error = True
|
||||||
|
self.protocol_message = ex.pgerror
|
||||||
|
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
|
||||||
|
|
||||||
|
|
||||||
|
class PGExecute:
|
||||||
|
|
||||||
# The boolean argument to the current_schemas function indicates whether
|
# The boolean argument to the current_schemas function indicates whether
|
||||||
# implicit schemas, e.g. pg_catalog
|
# implicit schemas, e.g. pg_catalog
|
||||||
|
@ -190,8 +237,6 @@ class PGExecute(object):
|
||||||
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
|
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
|
||||||
FROM f"""
|
FROM f"""
|
||||||
|
|
||||||
version_query = "SELECT version();"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database=None,
|
database=None,
|
||||||
|
@ -203,6 +248,7 @@ class PGExecute(object):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._conn_params = {}
|
self._conn_params = {}
|
||||||
|
self._is_virtual_database = None
|
||||||
self.conn = None
|
self.conn = None
|
||||||
self.dbname = None
|
self.dbname = None
|
||||||
self.user = None
|
self.user = None
|
||||||
|
@ -214,6 +260,11 @@ class PGExecute(object):
|
||||||
self.connect(database, user, password, host, port, dsn, **kwargs)
|
self.connect(database, user, password, host, port, dsn, **kwargs)
|
||||||
self.reset_expanded = None
|
self.reset_expanded = None
|
||||||
|
|
||||||
|
def is_virtual_database(self):
|
||||||
|
if self._is_virtual_database is None:
|
||||||
|
self._is_virtual_database = self.is_protocol_error()
|
||||||
|
return self._is_virtual_database
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
"""Returns a clone of the current executor."""
|
"""Returns a clone of the current executor."""
|
||||||
return self.__class__(**self._conn_params)
|
return self.__class__(**self._conn_params)
|
||||||
|
@ -250,9 +301,9 @@ class PGExecute(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
conn_params.update({k: v for k, v in new_params.items() if v})
|
conn_params.update({k: v for k, v in new_params.items() if v})
|
||||||
|
conn_params["cursor_factory"] = ProtocolSafeCursor
|
||||||
|
|
||||||
conn = psycopg2.connect(**conn_params)
|
conn = psycopg2.connect(**conn_params)
|
||||||
cursor = conn.cursor()
|
|
||||||
conn.set_client_encoding("utf8")
|
conn.set_client_encoding("utf8")
|
||||||
|
|
||||||
self._conn_params = conn_params
|
self._conn_params = conn_params
|
||||||
|
@ -293,13 +344,19 @@ class PGExecute(object):
|
||||||
self.extra_args = kwargs
|
self.extra_args = kwargs
|
||||||
|
|
||||||
if not self.host:
|
if not self.host:
|
||||||
self.host = self.get_socket_directory()
|
self.host = (
|
||||||
|
"pgbouncer"
|
||||||
|
if self.is_virtual_database()
|
||||||
|
else self.get_socket_directory()
|
||||||
|
)
|
||||||
|
|
||||||
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
|
self.pid = conn.get_backend_pid()
|
||||||
self.pid = pid
|
|
||||||
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
|
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
|
||||||
self.server_version = conn.get_parameter_status("server_version")
|
self.server_version = conn.get_parameter_status("server_version") or ""
|
||||||
|
|
||||||
|
_set_wait_callback(self.is_virtual_database())
|
||||||
|
|
||||||
|
if not self.is_virtual_database():
|
||||||
register_date_typecasters(conn)
|
register_date_typecasters(conn)
|
||||||
register_json_typecasters(self.conn, self._json_typecaster)
|
register_json_typecasters(self.conn, self._json_typecaster)
|
||||||
register_hstore_typecaster(self.conn)
|
register_hstore_typecaster(self.conn)
|
||||||
|
@ -395,7 +452,13 @@ class PGExecute(object):
|
||||||
# See https://github.com/dbcli/pgcli/issues/1014.
|
# See https://github.com/dbcli/pgcli/issues/1014.
|
||||||
cur = None
|
cur = None
|
||||||
try:
|
try:
|
||||||
for result in pgspecial.execute(cur, sql):
|
response = pgspecial.execute(cur, sql)
|
||||||
|
if cur and cur.protocol_error:
|
||||||
|
yield None, None, None, cur.protocol_message, statement, False, False
|
||||||
|
# this would close connection. We should reconnect.
|
||||||
|
self.connect()
|
||||||
|
continue
|
||||||
|
for result in response:
|
||||||
# e.g. execute_from_file already appends these
|
# e.g. execute_from_file already appends these
|
||||||
if len(result) < 7:
|
if len(result) < 7:
|
||||||
yield result + (sql, True, True)
|
yield result + (sql, True, True)
|
||||||
|
@ -453,6 +516,9 @@ class PGExecute(object):
|
||||||
if cur.description:
|
if cur.description:
|
||||||
headers = [x[0] for x in cur.description]
|
headers = [x[0] for x in cur.description]
|
||||||
return title, cur, headers, cur.statusmessage
|
return title, cur, headers, cur.statusmessage
|
||||||
|
elif cur.protocol_error:
|
||||||
|
_logger.debug("Protocol error, unsupported command.")
|
||||||
|
return title, None, None, cur.protocol_message
|
||||||
else:
|
else:
|
||||||
_logger.debug("No rows in result.")
|
_logger.debug("No rows in result.")
|
||||||
return title, None, None, cur.statusmessage
|
return title, None, None, cur.statusmessage
|
||||||
|
@ -485,7 +551,7 @@ class PGExecute(object):
|
||||||
try:
|
try:
|
||||||
cur.execute(sql, (spec,))
|
cur.execute(sql, (spec,))
|
||||||
except psycopg2.ProgrammingError:
|
except psycopg2.ProgrammingError:
|
||||||
raise RuntimeError("View {} does not exist.".format(spec))
|
raise RuntimeError(f"View {spec} does not exist.")
|
||||||
result = cur.fetchone()
|
result = cur.fetchone()
|
||||||
view_type = "MATERIALIZED" if result[2] == "m" else ""
|
view_type = "MATERIALIZED" if result[2] == "m" else ""
|
||||||
return template.format(*result + (view_type,))
|
return template.format(*result + (view_type,))
|
||||||
|
@ -501,7 +567,7 @@ class PGExecute(object):
|
||||||
result = cur.fetchone()
|
result = cur.fetchone()
|
||||||
return result[0]
|
return result[0]
|
||||||
except psycopg2.ProgrammingError:
|
except psycopg2.ProgrammingError:
|
||||||
raise RuntimeError("Function {} does not exist.".format(spec))
|
raise RuntimeError(f"Function {spec} does not exist.")
|
||||||
|
|
||||||
def schemata(self):
|
def schemata(self):
|
||||||
"""Returns a list of schema names in the database"""
|
"""Returns a list of schema names in the database"""
|
||||||
|
@ -527,21 +593,18 @@ class PGExecute(object):
|
||||||
sql = cur.mogrify(self.tables_query, [kinds])
|
sql = cur.mogrify(self.tables_query, [kinds])
|
||||||
_logger.debug("Tables Query. sql: %r", sql)
|
_logger.debug("Tables Query. sql: %r", sql)
|
||||||
cur.execute(sql)
|
cur.execute(sql)
|
||||||
for row in cur:
|
yield from cur
|
||||||
yield row
|
|
||||||
|
|
||||||
def tables(self):
|
def tables(self):
|
||||||
"""Yields (schema_name, table_name) tuples"""
|
"""Yields (schema_name, table_name) tuples"""
|
||||||
for row in self._relations(kinds=["r", "p", "f"]):
|
yield from self._relations(kinds=["r", "p", "f"])
|
||||||
yield row
|
|
||||||
|
|
||||||
def views(self):
|
def views(self):
|
||||||
"""Yields (schema_name, view_name) tuples.
|
"""Yields (schema_name, view_name) tuples.
|
||||||
|
|
||||||
Includes both views and and materialized views
|
Includes both views and and materialized views
|
||||||
"""
|
"""
|
||||||
for row in self._relations(kinds=["v", "m"]):
|
yield from self._relations(kinds=["v", "m"])
|
||||||
yield row
|
|
||||||
|
|
||||||
def _columns(self, kinds=("r", "p", "f", "v", "m")):
|
def _columns(self, kinds=("r", "p", "f", "v", "m")):
|
||||||
"""Get column metadata for tables and views
|
"""Get column metadata for tables and views
|
||||||
|
@ -599,16 +662,13 @@ class PGExecute(object):
|
||||||
sql = cur.mogrify(columns_query, [kinds])
|
sql = cur.mogrify(columns_query, [kinds])
|
||||||
_logger.debug("Columns Query. sql: %r", sql)
|
_logger.debug("Columns Query. sql: %r", sql)
|
||||||
cur.execute(sql)
|
cur.execute(sql)
|
||||||
for row in cur:
|
yield from cur
|
||||||
yield row
|
|
||||||
|
|
||||||
def table_columns(self):
|
def table_columns(self):
|
||||||
for row in self._columns(kinds=["r", "p", "f"]):
|
yield from self._columns(kinds=["r", "p", "f"])
|
||||||
yield row
|
|
||||||
|
|
||||||
def view_columns(self):
|
def view_columns(self):
|
||||||
for row in self._columns(kinds=["v", "m"]):
|
yield from self._columns(kinds=["v", "m"])
|
||||||
yield row
|
|
||||||
|
|
||||||
def databases(self):
|
def databases(self):
|
||||||
with self.conn.cursor() as cur:
|
with self.conn.cursor() as cur:
|
||||||
|
@ -623,6 +683,13 @@ class PGExecute(object):
|
||||||
headers = [x[0] for x in cur.description]
|
headers = [x[0] for x in cur.description]
|
||||||
return cur.fetchall(), headers, cur.statusmessage
|
return cur.fetchall(), headers, cur.statusmessage
|
||||||
|
|
||||||
|
def is_protocol_error(self):
|
||||||
|
query = "SELECT 1"
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
_logger.debug("Simple Query. sql: %r", query)
|
||||||
|
cur.execute(query)
|
||||||
|
return bool(cur.protocol_error)
|
||||||
|
|
||||||
def get_socket_directory(self):
|
def get_socket_directory(self):
|
||||||
with self.conn.cursor() as cur:
|
with self.conn.cursor() as cur:
|
||||||
_logger.debug(
|
_logger.debug(
|
||||||
|
@ -804,8 +871,7 @@ class PGExecute(object):
|
||||||
"""
|
"""
|
||||||
_logger.debug("Datatypes Query. sql: %r", query)
|
_logger.debug("Datatypes Query. sql: %r", query)
|
||||||
cur.execute(query)
|
cur.execute(query)
|
||||||
for row in cur:
|
yield from cur
|
||||||
yield row
|
|
||||||
|
|
||||||
def casing(self):
|
def casing(self):
|
||||||
"""Yields the most common casing for names used in db functions"""
|
"""Yields the most common casing for names used in db functions"""
|
||||||
|
|
|
@ -1,15 +1,23 @@
|
||||||
|
from pkg_resources import packaging
|
||||||
|
|
||||||
|
import prompt_toolkit
|
||||||
from prompt_toolkit.key_binding.vi_state import InputMode
|
from prompt_toolkit.key_binding.vi_state import InputMode
|
||||||
from prompt_toolkit.application import get_app
|
from prompt_toolkit.application import get_app
|
||||||
|
|
||||||
|
parse_version = packaging.version.parse
|
||||||
|
|
||||||
def _get_vi_mode():
|
vi_modes = {
|
||||||
return {
|
|
||||||
InputMode.INSERT: "I",
|
InputMode.INSERT: "I",
|
||||||
InputMode.NAVIGATION: "N",
|
InputMode.NAVIGATION: "N",
|
||||||
InputMode.REPLACE: "R",
|
InputMode.REPLACE: "R",
|
||||||
InputMode.REPLACE_SINGLE: "R",
|
|
||||||
InputMode.INSERT_MULTIPLE: "M",
|
InputMode.INSERT_MULTIPLE: "M",
|
||||||
}[get_app().vi_state.input_mode]
|
}
|
||||||
|
if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"):
|
||||||
|
vi_modes[InputMode.REPLACE_SINGLE] = "R"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vi_mode():
|
||||||
|
return vi_modes[get_app().vi_state.input_mode]
|
||||||
|
|
||||||
|
|
||||||
def create_toolbar_tokens_func(pgcli):
|
def create_toolbar_tokens_func(pgcli):
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
pytest>=2.7.0
|
pytest>=2.7.0
|
||||||
mock>=1.0.1
|
|
||||||
tox>=1.9.2
|
tox>=1.9.2
|
||||||
behave>=1.2.4
|
behave>=1.2.4
|
||||||
pexpect==3.3
|
pexpect==3.3
|
||||||
|
|
|
@ -12,7 +12,7 @@ from utils import (
|
||||||
import pgcli.pgexecute
|
import pgcli.pgexecute
|
||||||
|
|
||||||
|
|
||||||
@pytest.yield_fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def connection():
|
def connection():
|
||||||
create_db("_test_db")
|
create_db("_test_db")
|
||||||
connection = db_connection("_test_db")
|
connection = db_connection("_test_db")
|
||||||
|
|
|
@ -44,7 +44,7 @@ def create_cn(hostname, password, username, dbname, port):
|
||||||
host=hostname, user=username, database=dbname, password=password, port=port
|
host=hostname, user=username, database=dbname, password=password, port=port
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Created connection: {0}.".format(cn.dsn))
|
print(f"Created connection: {cn.dsn}.")
|
||||||
return cn
|
return cn
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,4 +75,4 @@ def close_cn(cn=None):
|
||||||
"""
|
"""
|
||||||
if cn:
|
if cn:
|
||||||
cn.close()
|
cn.close()
|
||||||
print("Closed connection: {0}.".format(cn.dsn))
|
print(f"Closed connection: {cn.dsn}.")
|
||||||
|
|
|
@ -38,7 +38,7 @@ def before_all(context):
|
||||||
|
|
||||||
vi = "_".join([str(x) for x in sys.version_info[:3]])
|
vi = "_".join([str(x) for x in sys.version_info[:3]])
|
||||||
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
|
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
|
||||||
db_name_full = "{0}_{1}".format(db_name, vi)
|
db_name_full = f"{db_name}_{vi}"
|
||||||
|
|
||||||
# Store get params from config.
|
# Store get params from config.
|
||||||
context.conf = {
|
context.conf = {
|
||||||
|
@ -63,7 +63,7 @@ def before_all(context):
|
||||||
"import coverage",
|
"import coverage",
|
||||||
"coverage.process_startup()",
|
"coverage.process_startup()",
|
||||||
"import pgcli.main",
|
"import pgcli.main",
|
||||||
"pgcli.main.cli()",
|
"pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -102,6 +102,7 @@ def before_all(context):
|
||||||
else:
|
else:
|
||||||
if "PGPASSWORD" in os.environ:
|
if "PGPASSWORD" in os.environ:
|
||||||
del os.environ["PGPASSWORD"]
|
del os.environ["PGPASSWORD"]
|
||||||
|
os.environ["BEHAVE_WARN"] = "moderate"
|
||||||
|
|
||||||
context.cn = dbutils.create_db(
|
context.cn = dbutils.create_db(
|
||||||
context.conf["host"],
|
context.conf["host"],
|
||||||
|
@ -122,12 +123,12 @@ def before_all(context):
|
||||||
def show_env_changes(env_old, env_new):
|
def show_env_changes(env_old, env_new):
|
||||||
"""Print out all test-specific env values."""
|
"""Print out all test-specific env values."""
|
||||||
print("--- os.environ changed values: ---")
|
print("--- os.environ changed values: ---")
|
||||||
all_keys = set(list(env_old.keys()) + list(env_new.keys()))
|
all_keys = env_old.keys() | env_new.keys()
|
||||||
for k in sorted(all_keys):
|
for k in sorted(all_keys):
|
||||||
old_value = env_old.get(k, "")
|
old_value = env_old.get(k, "")
|
||||||
new_value = env_new.get(k, "")
|
new_value = env_new.get(k, "")
|
||||||
if new_value and old_value != new_value:
|
if new_value and old_value != new_value:
|
||||||
print('{}="{}"'.format(k, new_value))
|
print(f'{k}="{new_value}"')
|
||||||
print("-" * 20)
|
print("-" * 20)
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,13 +174,13 @@ def after_scenario(context, scenario):
|
||||||
# Quit nicely.
|
# Quit nicely.
|
||||||
if not context.atprompt:
|
if not context.atprompt:
|
||||||
dbname = context.currentdb
|
dbname = context.currentdb
|
||||||
context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
|
context.cli.expect_exact(f"{dbname}> ", timeout=15)
|
||||||
context.cli.sendcontrol("c")
|
context.cli.sendcontrol("c")
|
||||||
context.cli.sendcontrol("d")
|
context.cli.sendcontrol("d")
|
||||||
try:
|
try:
|
||||||
context.cli.expect_exact(pexpect.EOF, timeout=15)
|
context.cli.expect_exact(pexpect.EOF, timeout=15)
|
||||||
except pexpect.TIMEOUT:
|
except pexpect.TIMEOUT:
|
||||||
print("--- after_scenario {}: kill cli".format(scenario.name))
|
print(f"--- after_scenario {scenario.name}: kill cli")
|
||||||
context.cli.kill(signal.SIGKILL)
|
context.cli.kill(signal.SIGKILL)
|
||||||
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
|
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
|
||||||
context.tmpfile_sql_help.close()
|
context.tmpfile_sql_help.close()
|
||||||
|
|
|
@ -18,7 +18,7 @@ def read_fixture_files():
|
||||||
"""Read all files inside fixture_data directory."""
|
"""Read all files inside fixture_data directory."""
|
||||||
current_dir = os.path.dirname(__file__)
|
current_dir = os.path.dirname(__file__)
|
||||||
fixture_dir = os.path.join(current_dir, "fixture_data/")
|
fixture_dir = os.path.join(current_dir, "fixture_data/")
|
||||||
print("reading fixture data: {}".format(fixture_dir))
|
print(f"reading fixture data: {fixture_dir}")
|
||||||
fixture_dict = {}
|
fixture_dict = {}
|
||||||
for filename in os.listdir(fixture_dir):
|
for filename in os.listdir(fixture_dir):
|
||||||
if filename not in [".", ".."]:
|
if filename not in [".", ".."]:
|
||||||
|
|
|
@ -65,19 +65,20 @@ def step_ctrl_d(context):
|
||||||
Send Ctrl + D to hopefully exit.
|
Send Ctrl + D to hopefully exit.
|
||||||
"""
|
"""
|
||||||
# turn off pager before exiting
|
# turn off pager before exiting
|
||||||
context.cli.sendline("\pset pager off")
|
context.cli.sendcontrol("c")
|
||||||
|
context.cli.sendline(r"\pset pager off")
|
||||||
wrappers.wait_prompt(context)
|
wrappers.wait_prompt(context)
|
||||||
context.cli.sendcontrol("d")
|
context.cli.sendcontrol("d")
|
||||||
context.cli.expect(pexpect.EOF, timeout=15)
|
context.cli.expect(pexpect.EOF, timeout=15)
|
||||||
context.exit_sent = True
|
context.exit_sent = True
|
||||||
|
|
||||||
|
|
||||||
@when('we send "\?" command')
|
@when(r'we send "\?" command')
|
||||||
def step_send_help(context):
|
def step_send_help(context):
|
||||||
"""
|
r"""
|
||||||
Send \? to see help.
|
Send \? to see help.
|
||||||
"""
|
"""
|
||||||
context.cli.sendline("\?")
|
context.cli.sendline(r"\?")
|
||||||
|
|
||||||
|
|
||||||
@when("we send partial select command")
|
@when("we send partial select command")
|
||||||
|
@ -96,9 +97,9 @@ def step_see_error_message(context):
|
||||||
@when("we send source command")
|
@when("we send source command")
|
||||||
def step_send_source_command(context):
|
def step_send_source_command(context):
|
||||||
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
|
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
|
||||||
context.tmpfile_sql_help.write(b"\?")
|
context.tmpfile_sql_help.write(br"\?")
|
||||||
context.tmpfile_sql_help.flush()
|
context.tmpfile_sql_help.flush()
|
||||||
context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
|
context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
|
||||||
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ def step_db_create(context):
|
||||||
"""
|
"""
|
||||||
Send create database.
|
Send create database.
|
||||||
"""
|
"""
|
||||||
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
|
context.cli.sendline("create database {};".format(context.conf["dbname_tmp"]))
|
||||||
|
|
||||||
context.response = {"database_name": context.conf["dbname_tmp"]}
|
context.response = {"database_name": context.conf["dbname_tmp"]}
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ def step_db_drop(context):
|
||||||
"""
|
"""
|
||||||
Send drop database.
|
Send drop database.
|
||||||
"""
|
"""
|
||||||
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
|
context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"]))
|
||||||
|
|
||||||
|
|
||||||
@when("we connect to test database")
|
@when("we connect to test database")
|
||||||
|
@ -33,7 +33,7 @@ def step_db_connect_test(context):
|
||||||
Send connect to database.
|
Send connect to database.
|
||||||
"""
|
"""
|
||||||
db_name = context.conf["dbname"]
|
db_name = context.conf["dbname"]
|
||||||
context.cli.sendline("\\connect {0}".format(db_name))
|
context.cli.sendline(f"\\connect {db_name}")
|
||||||
|
|
||||||
|
|
||||||
@when("we connect to dbserver")
|
@when("we connect to dbserver")
|
||||||
|
@ -59,7 +59,7 @@ def step_see_prompt(context):
|
||||||
Wait to see the prompt.
|
Wait to see the prompt.
|
||||||
"""
|
"""
|
||||||
db_name = getattr(context, "currentdb", context.conf["dbname"])
|
db_name = getattr(context, "currentdb", context.conf["dbname"])
|
||||||
wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
|
wrappers.expect_exact(context, f"{db_name}> ", timeout=5)
|
||||||
context.atprompt = True
|
context.atprompt = True
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ def step_prepare_data(context):
|
||||||
@when("we set expanded {mode}")
|
@when("we set expanded {mode}")
|
||||||
def step_set_expanded(context, mode):
|
def step_set_expanded(context, mode):
|
||||||
"""Set expanded to mode."""
|
"""Set expanded to mode."""
|
||||||
context.cli.sendline("\\" + "x {}".format(mode))
|
context.cli.sendline("\\" + f"x {mode}")
|
||||||
wrappers.expect_exact(context, "Expanded display is", timeout=2)
|
wrappers.expect_exact(context, "Expanded display is", timeout=2)
|
||||||
wrappers.wait_prompt(context)
|
wrappers.wait_prompt(context)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ def step_edit_file(context):
|
||||||
)
|
)
|
||||||
if os.path.exists(context.editor_file_name):
|
if os.path.exists(context.editor_file_name):
|
||||||
os.remove(context.editor_file_name)
|
os.remove(context.editor_file_name)
|
||||||
context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
|
context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name)))
|
||||||
wrappers.expect_exact(
|
wrappers.expect_exact(
|
||||||
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
|
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
|
||||||
)
|
)
|
||||||
|
@ -53,7 +53,7 @@ def step_tee_ouptut(context):
|
||||||
)
|
)
|
||||||
if os.path.exists(context.tee_file_name):
|
if os.path.exists(context.tee_file_name):
|
||||||
os.remove(context.tee_file_name)
|
os.remove(context.tee_file_name)
|
||||||
context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
|
context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name)))
|
||||||
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
||||||
wrappers.expect_exact(context, "Writing to file", timeout=5)
|
wrappers.expect_exact(context, "Writing to file", timeout=5)
|
||||||
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
||||||
|
@ -67,7 +67,7 @@ def step_query_select_123456(context):
|
||||||
|
|
||||||
@when("we stop teeing output")
|
@when("we stop teeing output")
|
||||||
def step_notee_output(context):
|
def step_notee_output(context):
|
||||||
context.cli.sendline("\o")
|
context.cli.sendline(r"\o")
|
||||||
wrappers.expect_exact(context, "Time", timeout=5)
|
wrappers.expect_exact(context, "Time", timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,5 +22,10 @@ def step_see_refresh_started(context):
|
||||||
Wait to see refresh output.
|
Wait to see refresh output.
|
||||||
"""
|
"""
|
||||||
wrappers.expect_pager(
|
wrappers.expect_pager(
|
||||||
context, "Auto-completion refresh started in the background.\r\n", timeout=2
|
context,
|
||||||
|
[
|
||||||
|
"Auto-completion refresh started in the background.\r\n",
|
||||||
|
"Auto-completion refresh restarted.\r\n",
|
||||||
|
],
|
||||||
|
timeout=2,
|
||||||
)
|
)
|
||||||
|
|
|
@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout):
|
||||||
|
|
||||||
|
|
||||||
def expect_pager(context, expected, timeout):
|
def expect_pager(context, expected, timeout):
|
||||||
|
formatted = expected if isinstance(expected, list) else [expected]
|
||||||
|
formatted = [
|
||||||
|
f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
|
||||||
|
for t in formatted
|
||||||
|
]
|
||||||
|
|
||||||
expect_exact(
|
expect_exact(
|
||||||
context,
|
context,
|
||||||
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
|
formatted,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -57,7 +63,7 @@ def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
|
||||||
context.cli.logfile = context.logfile
|
context.cli.logfile = context.logfile
|
||||||
context.exit_sent = False
|
context.exit_sent = False
|
||||||
context.currentdb = currentdb or context.conf["dbname"]
|
context.currentdb = currentdb or context.conf["dbname"]
|
||||||
context.cli.sendline("\pset pager always")
|
context.cli.sendline(r"\pset pager always")
|
||||||
if prompt_check:
|
if prompt_check:
|
||||||
wait_prompt(context)
|
wait_prompt(context)
|
||||||
|
|
||||||
|
|
0
tests/features/wrappager.py
Executable file → Normal file
0
tests/features/wrappager.py
Executable file → Normal file
|
@ -3,7 +3,7 @@ from itertools import product
|
||||||
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
|
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
|
||||||
from prompt_toolkit.completion import Completion
|
from prompt_toolkit.completion import Completion
|
||||||
from prompt_toolkit.document import Document
|
from prompt_toolkit.document import Document
|
||||||
from mock import Mock
|
from unittest.mock import Mock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
parametrize = pytest.mark.parametrize
|
parametrize = pytest.mark.parametrize
|
||||||
|
@ -59,7 +59,7 @@ def wildcard_expansion(cols, pos=-1):
|
||||||
return Completion(cols, start_position=pos, display_meta="columns", display="*")
|
return Completion(cols, start_position=pos, display_meta="columns", display="*")
|
||||||
|
|
||||||
|
|
||||||
class MetaData(object):
|
class MetaData:
|
||||||
def __init__(self, metadata):
|
def __init__(self, metadata):
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ class MetaData(object):
|
||||||
]
|
]
|
||||||
|
|
||||||
def schemas(self, pos=0):
|
def schemas(self, pos=0):
|
||||||
schemas = set(sch for schs in self.metadata.values() for sch in schs)
|
schemas = {sch for schs in self.metadata.values() for sch in schs}
|
||||||
return [schema(escape(s), pos=pos) for s in schemas]
|
return [schema(escape(s), pos=pos) for s in schemas]
|
||||||
|
|
||||||
def functions_and_keywords(self, parent="public", pos=0):
|
def functions_and_keywords(self, parent="public", pos=0):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
from pgcli.packages.parseutils import is_destructive
|
||||||
from pgcli.packages.parseutils.tables import extract_tables
|
from pgcli.packages.parseutils.tables import extract_tables
|
||||||
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
|
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
|
||||||
|
|
||||||
|
@ -34,12 +35,12 @@ def test_simple_select_single_table_double_quoted():
|
||||||
|
|
||||||
def test_simple_select_multiple_tables():
|
def test_simple_select_multiple_tables():
|
||||||
tables = extract_tables("select * from abc, def")
|
tables = extract_tables("select * from abc, def")
|
||||||
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
|
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
|
||||||
|
|
||||||
|
|
||||||
def test_simple_select_multiple_tables_double_quoted():
|
def test_simple_select_multiple_tables_double_quoted():
|
||||||
tables = extract_tables('select * from "Abc", "Def"')
|
tables = extract_tables('select * from "Abc", "Def"')
|
||||||
assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
|
assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)}
|
||||||
|
|
||||||
|
|
||||||
def test_simple_select_single_table_deouble_quoted_aliased():
|
def test_simple_select_single_table_deouble_quoted_aliased():
|
||||||
|
@ -49,14 +50,12 @@ def test_simple_select_single_table_deouble_quoted_aliased():
|
||||||
|
|
||||||
def test_simple_select_multiple_tables_deouble_quoted_aliased():
|
def test_simple_select_multiple_tables_deouble_quoted_aliased():
|
||||||
tables = extract_tables('select * from "Abc" a, "Def" d')
|
tables = extract_tables('select * from "Abc" a, "Def" d')
|
||||||
assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
|
assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)}
|
||||||
|
|
||||||
|
|
||||||
def test_simple_select_multiple_tables_schema_qualified():
|
def test_simple_select_multiple_tables_schema_qualified():
|
||||||
tables = extract_tables("select * from abc.def, ghi.jkl")
|
tables = extract_tables("select * from abc.def, ghi.jkl")
|
||||||
assert set(tables) == set(
|
assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)}
|
||||||
[("abc", "def", None, False), ("ghi", "jkl", None, False)]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_select_with_cols_single_table():
|
def test_simple_select_with_cols_single_table():
|
||||||
|
@ -71,14 +70,12 @@ def test_simple_select_with_cols_single_table_schema_qualified():
|
||||||
|
|
||||||
def test_simple_select_with_cols_multiple_tables():
|
def test_simple_select_with_cols_multiple_tables():
|
||||||
tables = extract_tables("select a,b from abc, def")
|
tables = extract_tables("select a,b from abc, def")
|
||||||
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
|
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
|
||||||
|
|
||||||
|
|
||||||
def test_simple_select_with_cols_multiple_qualified_tables():
|
def test_simple_select_with_cols_multiple_qualified_tables():
|
||||||
tables = extract_tables("select a,b from abc.def, def.ghi")
|
tables = extract_tables("select a,b from abc.def, def.ghi")
|
||||||
assert set(tables) == set(
|
assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)}
|
||||||
[("abc", "def", None, False), ("def", "ghi", None, False)]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_select_with_hanging_comma_single_table():
|
def test_select_with_hanging_comma_single_table():
|
||||||
|
@ -88,14 +85,12 @@ def test_select_with_hanging_comma_single_table():
|
||||||
|
|
||||||
def test_select_with_hanging_comma_multiple_tables():
|
def test_select_with_hanging_comma_multiple_tables():
|
||||||
tables = extract_tables("select a, from abc, def")
|
tables = extract_tables("select a, from abc, def")
|
||||||
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
|
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
|
||||||
|
|
||||||
|
|
||||||
def test_select_with_hanging_period_multiple_tables():
|
def test_select_with_hanging_period_multiple_tables():
|
||||||
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
|
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
|
||||||
assert set(tables) == set(
|
assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)}
|
||||||
[(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_insert_single_table():
|
def test_simple_insert_single_table():
|
||||||
|
@ -126,14 +121,14 @@ def test_simple_update_table_with_schema():
|
||||||
|
|
||||||
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
|
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
|
||||||
def test_join_table(join_type):
|
def test_join_table(join_type):
|
||||||
sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
|
sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num"
|
||||||
tables = extract_tables(sql)
|
tables = extract_tables(sql)
|
||||||
assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
|
assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)}
|
||||||
|
|
||||||
|
|
||||||
def test_join_table_schema_qualified():
|
def test_join_table_schema_qualified():
|
||||||
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
|
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
|
||||||
assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
|
assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)}
|
||||||
|
|
||||||
|
|
||||||
def test_incomplete_join_clause():
|
def test_incomplete_join_clause():
|
||||||
|
@ -177,25 +172,25 @@ def test_extract_no_tables(text):
|
||||||
|
|
||||||
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
|
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
|
||||||
def test_simple_function_as_table(arg_list):
|
def test_simple_function_as_table(arg_list):
|
||||||
tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
|
tables = extract_tables(f"SELECT * FROM foo({arg_list})")
|
||||||
assert tables == ((None, "foo", None, True),)
|
assert tables == ((None, "foo", None, True),)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
|
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
|
||||||
def test_simple_schema_qualified_function_as_table(arg_list):
|
def test_simple_schema_qualified_function_as_table(arg_list):
|
||||||
tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
|
tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})")
|
||||||
assert tables == (("foo", "bar", None, True),)
|
assert tables == (("foo", "bar", None, True),)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
|
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
|
||||||
def test_simple_aliased_function_as_table(arg_list):
|
def test_simple_aliased_function_as_table(arg_list):
|
||||||
tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
|
tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar")
|
||||||
assert tables == ((None, "foo", "bar", True),)
|
assert tables == ((None, "foo", "bar", True),)
|
||||||
|
|
||||||
|
|
||||||
def test_simple_table_and_function():
|
def test_simple_table_and_function():
|
||||||
tables = extract_tables("SELECT * FROM foo JOIN bar()")
|
tables = extract_tables("SELECT * FROM foo JOIN bar()")
|
||||||
assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
|
assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)}
|
||||||
|
|
||||||
|
|
||||||
def test_complex_table_and_function():
|
def test_complex_table_and_function():
|
||||||
|
@ -203,9 +198,7 @@ def test_complex_table_and_function():
|
||||||
"""SELECT * FROM foo.bar baz
|
"""SELECT * FROM foo.bar baz
|
||||||
JOIN bar.qux(x, y, z) quux"""
|
JOIN bar.qux(x, y, z) quux"""
|
||||||
)
|
)
|
||||||
assert set(tables) == set(
|
assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)}
|
||||||
[("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_find_prev_keyword_using():
|
def test_find_prev_keyword_using():
|
||||||
|
@ -267,3 +260,21 @@ def test_is_open_quote__closed(sql):
|
||||||
)
|
)
|
||||||
def test_is_open_quote__open(sql):
|
def test_is_open_quote__open(sql):
|
||||||
assert is_open_quote(sql)
|
assert is_open_quote(sql)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("sql", "warning_level", "expected"),
|
||||||
|
[
|
||||||
|
("update abc set x = 1", "all", True),
|
||||||
|
("update abc set x = 1 where y = 2", "all", True),
|
||||||
|
("update abc set x = 1", "moderate", True),
|
||||||
|
("update abc set x = 1 where y = 2", "moderate", False),
|
||||||
|
("select x, y, z from abc", "all", False),
|
||||||
|
("drop abc", "all", True),
|
||||||
|
("alter abc", "all", True),
|
||||||
|
("delete abc", "all", True),
|
||||||
|
("truncate abc", "all", True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_destructive(sql, warning_level, expected):
|
||||||
|
assert is_destructive(sql, warning_level=warning_level) == expected
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import time
|
import time
|
||||||
import pytest
|
import pytest
|
||||||
from mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -37,7 +37,7 @@ def test_refresh_called_once(refresher):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
callbacks = Mock()
|
callbacks = Mock()
|
||||||
pgexecute = Mock()
|
pgexecute = Mock(**{"is_virtual_database.return_value": False})
|
||||||
special = Mock()
|
special = Mock()
|
||||||
|
|
||||||
with patch.object(refresher, "_bg_refresh") as bg_refresh:
|
with patch.object(refresher, "_bg_refresh") as bg_refresh:
|
||||||
|
@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher):
|
||||||
"""
|
"""
|
||||||
callbacks = Mock()
|
callbacks = Mock()
|
||||||
|
|
||||||
pgexecute = Mock()
|
pgexecute = Mock(**{"is_virtual_database.return_value": False})
|
||||||
special = Mock()
|
special = Mock()
|
||||||
|
|
||||||
def dummy_bg_refresh(*args):
|
def dummy_bg_refresh(*args):
|
||||||
|
@ -84,12 +84,10 @@ def test_refresh_with_callbacks(refresher):
|
||||||
:param refresher:
|
:param refresher:
|
||||||
"""
|
"""
|
||||||
callbacks = [Mock()]
|
callbacks = [Mock()]
|
||||||
pgexecute_class = Mock()
|
pgexecute = Mock(**{"is_virtual_database.return_value": False})
|
||||||
pgexecute = Mock()
|
|
||||||
pgexecute.extra_args = {}
|
pgexecute.extra_args = {}
|
||||||
special = Mock()
|
special = Mock()
|
||||||
|
|
||||||
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
|
|
||||||
# Set refreshers to 0: we're not testing refresh logic here
|
# Set refreshers to 0: we're not testing refresh logic here
|
||||||
refresher.refreshers = {}
|
refresher.refreshers = {}
|
||||||
refresher.refresh(pgexecute, special, callbacks)
|
refresher.refresh(pgexecute, special, callbacks)
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import stat
|
import stat
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pgcli.config import ensure_dir_exists
|
from pgcli.config import ensure_dir_exists, skip_initial_comment
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_file_parent(tmpdir):
|
def test_ensure_file_parent(tmpdir):
|
||||||
|
@ -20,7 +21,7 @@ def test_ensure_existing_dir(tmpdir):
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_other_create_error(tmpdir):
|
def test_ensure_other_create_error(tmpdir):
|
||||||
subdir = tmpdir.join("subdir")
|
subdir = tmpdir.join('subdir"')
|
||||||
rcfile = subdir.join("rcfile")
|
rcfile = subdir.join("rcfile")
|
||||||
|
|
||||||
# trigger an oserror that isn't "directory already exists"
|
# trigger an oserror that isn't "directory already exists"
|
||||||
|
@ -28,3 +29,15 @@ def test_ensure_other_create_error(tmpdir):
|
||||||
|
|
||||||
with pytest.raises(OSError):
|
with pytest.raises(OSError):
|
||||||
ensure_dir_exists(str(rcfile))
|
ensure_dir_exists(str(rcfile))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"text, skipped_lines",
|
||||||
|
(
|
||||||
|
("abc\n", 1),
|
||||||
|
("#[section]\ndef\n[section]", 2),
|
||||||
|
("[section]", 0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_skip_initial_comment(text, skipped_lines):
|
||||||
|
assert skip_initial_comment(io.StringIO(text)) == skipped_lines
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -288,7 +288,12 @@ def test_pg_service_file(tmpdir):
|
||||||
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
|
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
|
||||||
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
|
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
|
||||||
service_conf.write(
|
service_conf.write(
|
||||||
"""[myservice]
|
"""File begins with a comment
|
||||||
|
that is not a comment
|
||||||
|
# or maybe a comment after all
|
||||||
|
because psql is crazy
|
||||||
|
|
||||||
|
[myservice]
|
||||||
host=a_host
|
host=a_host
|
||||||
user=a_user
|
user=a_user
|
||||||
port=5433
|
port=5433
|
||||||
|
|
|
@ -13,7 +13,7 @@ def completer():
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def complete_event():
|
def complete_event():
|
||||||
from mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
return Mock()
|
return Mock()
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
import pytest
|
import pytest
|
||||||
from mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
from pgspecial.main import PGSpecial, NO_QUERY
|
from pgspecial.main import PGSpecial, NO_QUERY
|
||||||
from utils import run, dbtest, requires_json, requires_jsonb
|
from utils import run, dbtest, requires_json, requires_jsonb
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def test_expanded_slash_G(executor, pgspecial):
|
||||||
# Tests whether we reset the expanded output after a \G.
|
# Tests whether we reset the expanded output after a \G.
|
||||||
run(executor, """create table test(a boolean)""")
|
run(executor, """create table test(a boolean)""")
|
||||||
run(executor, """insert into test values(True)""")
|
run(executor, """insert into test values(True)""")
|
||||||
results = run(executor, """select * from test \G""", pgspecial=pgspecial)
|
results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
|
||||||
assert pgspecial.expanded_output == False
|
assert pgspecial.expanded_output == False
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,31 +105,35 @@ def test_schemata_table_views_and_columns_query(executor):
|
||||||
# schemata
|
# schemata
|
||||||
# don't enforce all members of the schemas since they may include postgres
|
# don't enforce all members of the schemas since they may include postgres
|
||||||
# temporary schemas
|
# temporary schemas
|
||||||
assert set(executor.schemata()) >= set(
|
assert set(executor.schemata()) >= {
|
||||||
["public", "pg_catalog", "information_schema", "schema1", "schema2"]
|
"public",
|
||||||
)
|
"pg_catalog",
|
||||||
|
"information_schema",
|
||||||
|
"schema1",
|
||||||
|
"schema2",
|
||||||
|
}
|
||||||
assert executor.search_path() == ["pg_catalog", "public"]
|
assert executor.search_path() == ["pg_catalog", "public"]
|
||||||
|
|
||||||
# tables
|
# tables
|
||||||
assert set(executor.tables()) >= set(
|
assert set(executor.tables()) >= {
|
||||||
[("public", "a"), ("public", "b"), ("schema1", "c")]
|
("public", "a"),
|
||||||
)
|
("public", "b"),
|
||||||
|
("schema1", "c"),
|
||||||
|
}
|
||||||
|
|
||||||
assert set(executor.table_columns()) >= set(
|
assert set(executor.table_columns()) >= {
|
||||||
[
|
|
||||||
("public", "a", "x", "text", False, None),
|
("public", "a", "x", "text", False, None),
|
||||||
("public", "a", "y", "text", False, None),
|
("public", "a", "y", "text", False, None),
|
||||||
("public", "b", "z", "text", False, None),
|
("public", "b", "z", "text", False, None),
|
||||||
("schema1", "c", "w", "text", True, "'meow'::text"),
|
("schema1", "c", "w", "text", True, "'meow'::text"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
# views
|
# views
|
||||||
assert set(executor.views()) >= set([("public", "d")])
|
assert set(executor.views()) >= {("public", "d")}
|
||||||
|
|
||||||
assert set(executor.view_columns()) >= set(
|
assert set(executor.view_columns()) >= {
|
||||||
[("public", "d", "e", "integer", False, None)]
|
("public", "d", "e", "integer", False, None)
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
@dbtest
|
@dbtest
|
||||||
|
@ -142,9 +146,9 @@ def test_foreign_key_query(executor):
|
||||||
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
|
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert set(executor.foreignkeys()) >= set(
|
assert set(executor.foreignkeys()) >= {
|
||||||
[("schema1", "parent", "parentid", "schema2", "child", "motherid")]
|
("schema1", "parent", "parentid", "schema2", "child", "motherid")
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
@dbtest
|
@dbtest
|
||||||
|
@ -175,8 +179,7 @@ def test_functions_query(executor):
|
||||||
)
|
)
|
||||||
|
|
||||||
funcs = set(executor.functions())
|
funcs = set(executor.functions())
|
||||||
assert funcs >= set(
|
assert funcs >= {
|
||||||
[
|
|
||||||
function_meta_data(func_name="func1", return_type="integer"),
|
function_meta_data(func_name="func1", return_type="integer"),
|
||||||
function_meta_data(
|
function_meta_data(
|
||||||
func_name="func3",
|
func_name="func3",
|
||||||
|
@ -197,8 +200,7 @@ def test_functions_query(executor):
|
||||||
function_meta_data(
|
function_meta_data(
|
||||||
schema_name="schema1", func_name="func2", return_type="integer"
|
schema_name="schema1", func_name="func2", return_type="integer"
|
||||||
),
|
),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dbtest
|
@dbtest
|
||||||
|
@ -257,8 +259,8 @@ def test_not_is_special(executor, pgspecial):
|
||||||
|
|
||||||
@dbtest
|
@dbtest
|
||||||
def test_execute_from_file_no_arg(executor, pgspecial):
|
def test_execute_from_file_no_arg(executor, pgspecial):
|
||||||
"""\i without a filename returns an error."""
|
r"""\i without a filename returns an error."""
|
||||||
result = list(executor.run("\i", pgspecial=pgspecial))
|
result = list(executor.run(r"\i", pgspecial=pgspecial))
|
||||||
status, sql, success, is_special = result[0][3:]
|
status, sql, success, is_special = result[0][3:]
|
||||||
assert "missing required argument" in status
|
assert "missing required argument" in status
|
||||||
assert success == False
|
assert success == False
|
||||||
|
@ -268,12 +270,12 @@ def test_execute_from_file_no_arg(executor, pgspecial):
|
||||||
@dbtest
|
@dbtest
|
||||||
@patch("pgcli.main.os")
|
@patch("pgcli.main.os")
|
||||||
def test_execute_from_file_io_error(os, executor, pgspecial):
|
def test_execute_from_file_io_error(os, executor, pgspecial):
|
||||||
"""\i with an io_error returns an error."""
|
r"""\i with an os_error returns an error."""
|
||||||
# Inject an IOError.
|
# Inject an OSError.
|
||||||
os.path.expanduser.side_effect = IOError("test")
|
os.path.expanduser.side_effect = OSError("test")
|
||||||
|
|
||||||
# Check the result.
|
# Check the result.
|
||||||
result = list(executor.run("\i test", pgspecial=pgspecial))
|
result = list(executor.run(r"\i test", pgspecial=pgspecial))
|
||||||
status, sql, success, is_special = result[0][3:]
|
status, sql, success, is_special = result[0][3:]
|
||||||
assert status == "test"
|
assert status == "test"
|
||||||
assert success == False
|
assert success == False
|
||||||
|
@ -290,7 +292,7 @@ def test_multiple_queries_same_line(executor):
|
||||||
|
|
||||||
@dbtest
|
@dbtest
|
||||||
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
|
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
|
||||||
result = run(executor, "select 'foo'; \d", pgspecial=pgspecial)
|
result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial)
|
||||||
assert len(result) == 11 # 2 * (output+status) * 3 lines
|
assert len(result) == 11 # 2 * (output+status) * 3 lines
|
||||||
assert "foo" in result[3]
|
assert "foo" in result[3]
|
||||||
# This is a lame check. :(
|
# This is a lame check. :(
|
||||||
|
@ -408,7 +410,7 @@ def test_date_time_types(executor):
|
||||||
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
|
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
|
||||||
def test_large_numbers_render_directly(executor, value):
|
def test_large_numbers_render_directly(executor, value):
|
||||||
run(executor, "create table numbertest(a numeric)")
|
run(executor, "create table numbertest(a numeric)")
|
||||||
run(executor, "insert into numbertest (a) values ({0})".format(value))
|
run(executor, f"insert into numbertest (a) values ({value})")
|
||||||
assert value in run(executor, "select * from numbertest", join=True)
|
assert value in run(executor, "select * from numbertest", join=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -511,13 +513,28 @@ def test_short_host(executor):
|
||||||
assert executor.short_host == "localhost1"
|
assert executor.short_host == "localhost1"
|
||||||
|
|
||||||
|
|
||||||
class BrokenConnection(object):
|
class BrokenConnection:
|
||||||
"""Mock a connection that failed."""
|
"""Mock a connection that failed."""
|
||||||
|
|
||||||
def cursor(self):
|
def cursor(self):
|
||||||
raise psycopg2.InterfaceError("I'm broken!")
|
raise psycopg2.InterfaceError("I'm broken!")
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualCursor:
|
||||||
|
"""Mock a cursor to virtual database like pgbouncer."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.protocol_error = False
|
||||||
|
self.protocol_message = ""
|
||||||
|
self.description = None
|
||||||
|
self.status = None
|
||||||
|
self.statusmessage = "Error"
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
self.protocol_error = True
|
||||||
|
self.protocol_message = "Command not supported"
|
||||||
|
|
||||||
|
|
||||||
@dbtest
|
@dbtest
|
||||||
def test_exit_without_active_connection(executor):
|
def test_exit_without_active_connection(executor):
|
||||||
quit_handler = MagicMock()
|
quit_handler = MagicMock()
|
||||||
|
@ -540,3 +557,12 @@ def test_exit_without_active_connection(executor):
|
||||||
# an exception should be raised when running a query without active connection
|
# an exception should be raised when running a query without active connection
|
||||||
with pytest.raises(psycopg2.InterfaceError):
|
with pytest.raises(psycopg2.InterfaceError):
|
||||||
run(executor, "select 1", pgspecial=pgspecial)
|
run(executor, "select 1", pgspecial=pgspecial)
|
||||||
|
|
||||||
|
|
||||||
|
@dbtest
|
||||||
|
def test_virtual_database(executor):
|
||||||
|
virtual_connection = MagicMock()
|
||||||
|
virtual_connection.cursor.return_value = VirtualCursor()
|
||||||
|
with patch.object(executor, "conn", virtual_connection):
|
||||||
|
result = run(executor, "select 1")
|
||||||
|
assert "Command not supported" in result
|
||||||
|
|
|
@ -13,12 +13,12 @@ from pgcli.packages.sqlcompletion import (
|
||||||
|
|
||||||
def test_slash_suggests_special():
|
def test_slash_suggests_special():
|
||||||
suggestions = suggest_type("\\", "\\")
|
suggestions = suggest_type("\\", "\\")
|
||||||
assert set(suggestions) == set([Special()])
|
assert set(suggestions) == {Special()}
|
||||||
|
|
||||||
|
|
||||||
def test_slash_d_suggests_special():
|
def test_slash_d_suggests_special():
|
||||||
suggestions = suggest_type("\\d", "\\d")
|
suggestions = suggest_type("\\d", "\\d")
|
||||||
assert set(suggestions) == set([Special()])
|
assert set(suggestions) == {Special()}
|
||||||
|
|
||||||
|
|
||||||
def test_dn_suggests_schemata():
|
def test_dn_suggests_schemata():
|
||||||
|
@ -30,24 +30,24 @@ def test_dn_suggests_schemata():
|
||||||
|
|
||||||
|
|
||||||
def test_d_suggests_tables_views_and_schemas():
|
def test_d_suggests_tables_views_and_schemas():
|
||||||
suggestions = suggest_type("\d ", "\d ")
|
suggestions = suggest_type(r"\d ", r"\d ")
|
||||||
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
|
assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
|
||||||
|
|
||||||
suggestions = suggest_type("\d xxx", "\d xxx")
|
suggestions = suggest_type(r"\d xxx", r"\d xxx")
|
||||||
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
|
assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
|
||||||
|
|
||||||
|
|
||||||
def test_d_dot_suggests_schema_qualified_tables_or_views():
|
def test_d_dot_suggests_schema_qualified_tables_or_views():
|
||||||
suggestions = suggest_type("\d myschema.", "\d myschema.")
|
suggestions = suggest_type(r"\d myschema.", r"\d myschema.")
|
||||||
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
|
assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
|
||||||
|
|
||||||
suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
|
suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx")
|
||||||
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
|
assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
|
||||||
|
|
||||||
|
|
||||||
def test_df_suggests_schema_or_function():
|
def test_df_suggests_schema_or_function():
|
||||||
suggestions = suggest_type("\\df xxx", "\\df xxx")
|
suggestions = suggest_type("\\df xxx", "\\df xxx")
|
||||||
assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
|
assert set(suggestions) == {Function(schema=None, usage="special"), Schema()}
|
||||||
|
|
||||||
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
|
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
|
||||||
assert suggestions == (Function(schema="myschema", usage="special"),)
|
assert suggestions == (Function(schema="myschema", usage="special"),)
|
||||||
|
@ -63,7 +63,7 @@ def test_leading_whitespace_ok():
|
||||||
def test_dT_suggests_schema_or_datatypes():
|
def test_dT_suggests_schema_or_datatypes():
|
||||||
text = "\\dT "
|
text = "\\dT "
|
||||||
suggestions = suggest_type(text, text)
|
suggestions = suggest_type(text, text)
|
||||||
assert set(suggestions) == set([Schema(), Datatype(schema=None)])
|
assert set(suggestions) == {Schema(), Datatype(schema=None)}
|
||||||
|
|
||||||
|
|
||||||
def test_schema_qualified_dT_suggests_datatypes():
|
def test_schema_qualified_dT_suggests_datatypes():
|
||||||
|
|
|
@ -7,4 +7,4 @@ def test_confirm_destructive_query_notty():
|
||||||
stdin = click.get_text_stream("stdin")
|
stdin = click.get_text_stream("stdin")
|
||||||
if not stdin.isatty():
|
if not stdin.isatty():
|
||||||
sql = "drop database foo;"
|
sql = "drop database foo;"
|
||||||
assert confirm_destructive_query(sql) is None
|
assert confirm_destructive_query(sql, "all") is None
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from pgcli.main import PGCli
|
from pgcli.main import PGCli
|
||||||
|
|
||||||
|
|
|
@ -193,7 +193,7 @@ def test_suggested_joins(completer, query, tbl):
|
||||||
result = get_result(completer, query.format(tbl))
|
result = get_result(completer, query.format(tbl))
|
||||||
assert completions_to_set(result) == completions_to_set(
|
assert completions_to_set(result) == completions_to_set(
|
||||||
testdata.schemas_and_from_clause_items()
|
testdata.schemas_and_from_clause_items()
|
||||||
+ [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))]
|
+ [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -350,6 +350,36 @@ def test_schema_qualified_function_name(completer):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
|
||||||
|
def test_schema_qualified_function_name_after_from(completer):
|
||||||
|
text = "SELECT * FROM custom.set_r"
|
||||||
|
result = get_result(completer, text)
|
||||||
|
assert completions_to_set(result) == completions_to_set(
|
||||||
|
[
|
||||||
|
function("set_returning_func()", -len("func")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
|
||||||
|
def test_unqualified_function_name_not_returned(completer):
|
||||||
|
text = "SELECT * FROM set_r"
|
||||||
|
result = get_result(completer, text)
|
||||||
|
assert completions_to_set(result) == completions_to_set([])
|
||||||
|
|
||||||
|
|
||||||
|
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
|
||||||
|
def test_unqualified_function_name_in_search_path(completer):
|
||||||
|
completer.search_path = ["public", "custom"]
|
||||||
|
text = "SELECT * FROM set_r"
|
||||||
|
result = get_result(completer, text)
|
||||||
|
assert completions_to_set(result) == completions_to_set(
|
||||||
|
[
|
||||||
|
function("set_returning_func()", -len("func")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@parametrize("completer", completers(filtr=True, casing=False))
|
@parametrize("completer", completers(filtr=True, casing=False))
|
||||||
@parametrize(
|
@parametrize(
|
||||||
"text",
|
"text",
|
||||||
|
|
|
@ -53,7 +53,7 @@ metadata = {
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata = dict((k, {"public": v}) for k, v in metadata.items())
|
metadata = {k: {"public": v} for k, v in metadata.items()}
|
||||||
|
|
||||||
testdata = MetaData(metadata)
|
testdata = MetaData(metadata)
|
||||||
|
|
||||||
|
@ -296,7 +296,7 @@ def test_suggested_cased_always_qualified_column_names(completer):
|
||||||
def test_suggested_column_names_in_function(completer):
|
def test_suggested_column_names_in_function(completer):
|
||||||
result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
|
result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
|
||||||
assert completions_to_set(result) == completions_to_set(
|
assert completions_to_set(result) == completions_to_set(
|
||||||
(testdata.columns_functions_and_keywords("users"))
|
testdata.columns_functions_and_keywords("users")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -316,7 +316,7 @@ def test_suggested_column_names_with_alias(completer):
|
||||||
def test_suggested_multiple_column_names(completer):
|
def test_suggested_multiple_column_names(completer):
|
||||||
result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
|
result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
|
||||||
assert completions_to_set(result) == completions_to_set(
|
assert completions_to_set(result) == completions_to_set(
|
||||||
(testdata.columns_functions_and_keywords("users"))
|
testdata.columns_functions_and_keywords("users")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,16 +23,14 @@ def cols_etc(
|
||||||
):
|
):
|
||||||
"""Returns the expected select-clause suggestions for a single-table
|
"""Returns the expected select-clause suggestions for a single-table
|
||||||
select."""
|
select."""
|
||||||
return set(
|
return {
|
||||||
[
|
|
||||||
Column(
|
Column(
|
||||||
table_refs=(TableReference(schema, table, alias, is_function),),
|
table_refs=(TableReference(schema, table, alias, is_function),),
|
||||||
qualifiable=True,
|
qualifiable=True,
|
||||||
),
|
),
|
||||||
Function(schema=parent),
|
Function(schema=parent),
|
||||||
Keyword(last_keyword),
|
Keyword(last_keyword),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_select_suggests_cols_with_visible_table_scope():
|
def test_select_suggests_cols_with_visible_table_scope():
|
||||||
|
@ -103,24 +101,20 @@ def test_where_equals_any_suggests_columns_or_keywords():
|
||||||
|
|
||||||
def test_lparen_suggests_cols_and_funcs():
|
def test_lparen_suggests_cols_and_funcs():
|
||||||
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
|
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
|
||||||
assert set(suggestion) == set(
|
assert set(suggestion) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
|
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("("),
|
Keyword("("),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_select_suggests_cols_and_funcs():
|
def test_select_suggests_cols_and_funcs():
|
||||||
suggestions = suggest_type("SELECT ", "SELECT ")
|
suggestions = suggest_type("SELECT ", "SELECT ")
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=(), qualifiable=True),
|
Column(table_refs=(), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("SELECT"),
|
Keyword("SELECT"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -128,13 +122,13 @@ def test_select_suggests_cols_and_funcs():
|
||||||
)
|
)
|
||||||
def test_suggests_tables_views_and_schemas(expression):
|
def test_suggests_tables_views_and_schemas(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()])
|
assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expression", ["SELECT * FROM "])
|
@pytest.mark.parametrize("expression", ["SELECT * FROM "])
|
||||||
def test_suggest_tables_views_schemas_and_functions(expression):
|
def test_suggest_tables_views_schemas_and_functions(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -147,9 +141,11 @@ def test_suggest_tables_views_schemas_and_functions(expression):
|
||||||
def test_suggest_after_join_with_two_tables(expression):
|
def test_suggest_after_join_with_two_tables(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
|
tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()]
|
FromClauseItem(schema=None, table_refs=tables),
|
||||||
)
|
Join(tables, None),
|
||||||
|
Schema(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -158,13 +154,11 @@ def test_suggest_after_join_with_two_tables(expression):
|
||||||
def test_suggest_after_join_with_one_table(expression):
|
def test_suggest_after_join_with_one_table(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
tables = ((None, "foo", None, False),)
|
tables = ((None, "foo", None, False),)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
FromClauseItem(schema=None, table_refs=tables),
|
FromClauseItem(schema=None, table_refs=tables),
|
||||||
Join(((None, "foo", None, False),), None),
|
Join(((None, "foo", None, False),), None),
|
||||||
Schema(),
|
Schema(),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -172,13 +166,13 @@ def test_suggest_after_join_with_one_table(expression):
|
||||||
)
|
)
|
||||||
def test_suggest_qualified_tables_and_views(expression):
|
def test_suggest_qualified_tables_and_views(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
|
assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expression", ["UPDATE sch."])
|
@pytest.mark.parametrize("expression", ["UPDATE sch."])
|
||||||
def test_suggest_qualified_aliasable_tables_and_views(expression):
|
def test_suggest_qualified_aliasable_tables_and_views(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
|
assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -193,26 +187,27 @@ def test_suggest_qualified_aliasable_tables_and_views(expression):
|
||||||
)
|
)
|
||||||
def test_suggest_qualified_tables_views_and_functions(expression):
|
def test_suggest_qualified_tables_views_and_functions(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
assert set(suggestions) == set([FromClauseItem(schema="sch")])
|
assert set(suggestions) == {FromClauseItem(schema="sch")}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
|
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
|
||||||
def test_suggest_qualified_tables_views_functions_and_joins(expression):
|
def test_suggest_qualified_tables_views_functions_and_joins(expression):
|
||||||
suggestions = suggest_type(expression, expression)
|
suggestions = suggest_type(expression, expression)
|
||||||
tbls = tuple([(None, "foo", None, False)])
|
tbls = tuple([(None, "foo", None, False)])
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")]
|
FromClauseItem(schema="sch", table_refs=tbls),
|
||||||
)
|
Join(tbls, "sch"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_truncate_suggests_tables_and_schemas():
|
def test_truncate_suggests_tables_and_schemas():
|
||||||
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
|
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
|
||||||
assert set(suggestions) == set([Table(schema=None), Schema()])
|
assert set(suggestions) == {Table(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
def test_truncate_suggests_qualified_tables():
|
def test_truncate_suggests_qualified_tables():
|
||||||
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
|
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
|
||||||
assert set(suggestions) == set([Table(schema="sch")])
|
assert set(suggestions) == {Table(schema="sch")}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -220,13 +215,11 @@ def test_truncate_suggests_qualified_tables():
|
||||||
)
|
)
|
||||||
def test_distinct_suggests_cols(text):
|
def test_distinct_suggests_cols(text):
|
||||||
suggestions = suggest_type(text, text)
|
suggestions = suggest_type(text, text)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=(), local_tables=(), qualifiable=True),
|
Column(table_refs=(), local_tables=(), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("DISTINCT"),
|
Keyword("DISTINCT"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -244,8 +237,7 @@ def test_distinct_and_order_by_suggestions_with_aliases(
|
||||||
text, text_before, last_keyword
|
text, text_before, last_keyword
|
||||||
):
|
):
|
||||||
suggestions = suggest_type(text, text_before)
|
suggestions = suggest_type(text, text_before)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(
|
Column(
|
||||||
table_refs=(
|
table_refs=(
|
||||||
TableReference(None, "tbl", "x", False),
|
TableReference(None, "tbl", "x", False),
|
||||||
|
@ -256,8 +248,7 @@ def test_distinct_and_order_by_suggestions_with_aliases(
|
||||||
),
|
),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword(last_keyword),
|
Keyword(last_keyword),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -272,8 +263,7 @@ def test_distinct_and_order_by_suggestions_with_aliases(
|
||||||
)
|
)
|
||||||
def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
|
def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
|
||||||
suggestions = suggest_type(text, text_before)
|
suggestions = suggest_type(text, text_before)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(
|
Column(
|
||||||
table_refs=(TableReference(None, "tbl", "x", False),),
|
table_refs=(TableReference(None, "tbl", "x", False),),
|
||||||
local_tables=(),
|
local_tables=(),
|
||||||
|
@ -282,15 +272,13 @@ def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
|
||||||
Table(schema="x"),
|
Table(schema="x"),
|
||||||
View(schema="x"),
|
View(schema="x"),
|
||||||
Function(schema="x"),
|
Function(schema="x"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_function_arguments_with_alias_given():
|
def test_function_arguments_with_alias_given():
|
||||||
suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
|
suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
|
||||||
|
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(
|
Column(
|
||||||
table_refs=(TableReference(None, "tbl", "x", False),),
|
table_refs=(TableReference(None, "tbl", "x", False),),
|
||||||
local_tables=(),
|
local_tables=(),
|
||||||
|
@ -299,29 +287,26 @@ def test_function_arguments_with_alias_given():
|
||||||
Table(schema="x"),
|
Table(schema="x"),
|
||||||
View(schema="x"),
|
View(schema="x"),
|
||||||
Function(schema="x"),
|
Function(schema="x"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_col_comma_suggests_cols():
|
def test_col_comma_suggests_cols():
|
||||||
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
|
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
|
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("SELECT"),
|
Keyword("SELECT"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_table_comma_suggests_tables_and_schemas():
|
def test_table_comma_suggests_tables_and_schemas():
|
||||||
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
|
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
|
||||||
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
def test_into_suggests_tables_and_schemas():
|
def test_into_suggests_tables_and_schemas():
|
||||||
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
|
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
|
||||||
assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()])
|
assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -357,14 +342,12 @@ def test_partially_typed_col_name_suggests_col_names():
|
||||||
|
|
||||||
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
|
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
|
||||||
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
|
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "tabl", None, False),)),
|
Column(table_refs=((None, "tabl", None, False),)),
|
||||||
Table(schema="tabl"),
|
Table(schema="tabl"),
|
||||||
View(schema="tabl"),
|
View(schema="tabl"),
|
||||||
Function(schema="tabl"),
|
Function(schema="tabl"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -378,14 +361,12 @@ def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
|
||||||
)
|
)
|
||||||
def test_dot_suggests_cols_of_an_alias(sql):
|
def test_dot_suggests_cols_of_an_alias(sql):
|
||||||
suggestions = suggest_type(sql, "SELECT t1.")
|
suggestions = suggest_type(sql, "SELECT t1.")
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Table(schema="t1"),
|
Table(schema="t1"),
|
||||||
View(schema="t1"),
|
View(schema="t1"),
|
||||||
Column(table_refs=((None, "tabl1", "t1", False),)),
|
Column(table_refs=((None, "tabl1", "t1", False),)),
|
||||||
Function(schema="t1"),
|
Function(schema="t1"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -399,28 +380,24 @@ def test_dot_suggests_cols_of_an_alias(sql):
|
||||||
)
|
)
|
||||||
def test_dot_suggests_cols_of_an_alias_where(sql):
|
def test_dot_suggests_cols_of_an_alias_where(sql):
|
||||||
suggestions = suggest_type(sql, sql)
|
suggestions = suggest_type(sql, sql)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Table(schema="t1"),
|
Table(schema="t1"),
|
||||||
View(schema="t1"),
|
View(schema="t1"),
|
||||||
Column(table_refs=((None, "tabl1", "t1", False),)),
|
Column(table_refs=((None, "tabl1", "t1", False),)),
|
||||||
Function(schema="t1"),
|
Function(schema="t1"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
|
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
|
"SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "tabl2", "t2", False),)),
|
Column(table_refs=((None, "tabl2", "t2", False),)),
|
||||||
Table(schema="t2"),
|
Table(schema="t2"),
|
||||||
View(schema="t2"),
|
View(schema="t2"),
|
||||||
Function(schema="t2"),
|
Function(schema="t2"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -452,20 +429,18 @@ def test_sub_select_partial_text_suggests_keyword(expression):
|
||||||
def test_outer_table_reference_in_exists_subquery_suggests_columns():
|
def test_outer_table_reference_in_exists_subquery_suggests_columns():
|
||||||
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
|
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
|
||||||
suggestions = suggest_type(q, q)
|
suggestions = suggest_type(q, q)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "foo", "f", False),)),
|
Column(table_refs=((None, "foo", "f", False),)),
|
||||||
Table(schema="f"),
|
Table(schema="f"),
|
||||||
View(schema="f"),
|
View(schema="f"),
|
||||||
Function(schema="f"),
|
Function(schema="f"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
|
@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
|
||||||
def test_sub_select_table_name_completion(expression):
|
def test_sub_select_table_name_completion(expression):
|
||||||
suggestion = suggest_type(expression, expression)
|
suggestion = suggest_type(expression, expression)
|
||||||
assert set(suggestion) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestion) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -478,22 +453,18 @@ def test_sub_select_table_name_completion(expression):
|
||||||
def test_sub_select_table_name_completion_with_outer_table(expression):
|
def test_sub_select_table_name_completion_with_outer_table(expression):
|
||||||
suggestion = suggest_type(expression, expression)
|
suggestion = suggest_type(expression, expression)
|
||||||
tbls = tuple([(None, "foo", None, False)])
|
tbls = tuple([(None, "foo", None, False)])
|
||||||
assert set(suggestion) == set(
|
assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
|
||||||
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sub_select_col_name_completion():
|
def test_sub_select_col_name_completion():
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
|
"SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
|
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("SELECT"),
|
Keyword("SELECT"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
@pytest.mark.xfail
|
||||||
|
@ -508,25 +479,25 @@ def test_sub_select_dot_col_name_completion():
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
|
"SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "tabl", "t", False),)),
|
Column(table_refs=((None, "tabl", "t", False),)),
|
||||||
Table(schema="t"),
|
Table(schema="t"),
|
||||||
View(schema="t"),
|
View(schema="t"),
|
||||||
Function(schema="t"),
|
Function(schema="t"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
|
@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
|
||||||
@pytest.mark.parametrize("tbl_alias", ("", "foo"))
|
@pytest.mark.parametrize("tbl_alias", ("", "foo"))
|
||||||
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
|
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
|
||||||
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
|
text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
|
||||||
suggestion = suggest_type(text, text)
|
suggestion = suggest_type(text, text)
|
||||||
tbls = tuple([(None, "abc", tbl_alias or None, False)])
|
tbls = tuple([(None, "abc", tbl_alias or None, False)])
|
||||||
assert set(suggestion) == set(
|
assert set(suggestion) == {
|
||||||
[FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)]
|
FromClauseItem(schema=None, table_refs=tbls),
|
||||||
)
|
Schema(),
|
||||||
|
Join(tbls, None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_left_join_with_comma():
|
def test_left_join_with_comma():
|
||||||
|
@ -535,9 +506,7 @@ def test_left_join_with_comma():
|
||||||
# tbls should also include (None, 'bar', 'b', False)
|
# tbls should also include (None, 'bar', 'b', False)
|
||||||
# but there's a bug with commas
|
# but there's a bug with commas
|
||||||
tbls = tuple([(None, "foo", "f", False)])
|
tbls = tuple([(None, "foo", "f", False)])
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
|
||||||
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -550,15 +519,13 @@ def test_left_join_with_comma():
|
||||||
def test_join_alias_dot_suggests_cols1(sql):
|
def test_join_alias_dot_suggests_cols1(sql):
|
||||||
suggestions = suggest_type(sql, sql)
|
suggestions = suggest_type(sql, sql)
|
||||||
tables = ((None, "abc", "a", False), (None, "def", "d", False))
|
tables = ((None, "abc", "a", False), (None, "def", "d", False))
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "abc", "a", False),)),
|
Column(table_refs=((None, "abc", "a", False),)),
|
||||||
Table(schema="a"),
|
Table(schema="a"),
|
||||||
View(schema="a"),
|
View(schema="a"),
|
||||||
Function(schema="a"),
|
Function(schema="a"),
|
||||||
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
|
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -570,14 +537,12 @@ def test_join_alias_dot_suggests_cols1(sql):
|
||||||
)
|
)
|
||||||
def test_join_alias_dot_suggests_cols2(sql):
|
def test_join_alias_dot_suggests_cols2(sql):
|
||||||
suggestion = suggest_type(sql, sql)
|
suggestion = suggest_type(sql, sql)
|
||||||
assert set(suggestion) == set(
|
assert set(suggestion) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "def", "d", False),)),
|
Column(table_refs=((None, "def", "d", False),)),
|
||||||
Table(schema="d"),
|
Table(schema="d"),
|
||||||
View(schema="d"),
|
View(schema="d"),
|
||||||
Function(schema="d"),
|
Function(schema="d"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -598,9 +563,10 @@ on """,
|
||||||
def test_on_suggests_aliases_and_join_conditions(sql):
|
def test_on_suggests_aliases_and_join_conditions(sql):
|
||||||
suggestions = suggest_type(sql, sql)
|
suggestions = suggest_type(sql, sql)
|
||||||
tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
|
tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b")))
|
JoinCondition(table_refs=tables, parent=None),
|
||||||
)
|
Alias(aliases=("a", "b")),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -613,9 +579,10 @@ def test_on_suggests_aliases_and_join_conditions(sql):
|
||||||
def test_on_suggests_tables_and_join_conditions(sql):
|
def test_on_suggests_tables_and_join_conditions(sql):
|
||||||
suggestions = suggest_type(sql, sql)
|
suggestions = suggest_type(sql, sql)
|
||||||
tables = ((None, "abc", None, False), (None, "bcd", None, False))
|
tables = ((None, "abc", None, False), (None, "bcd", None, False))
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
|
JoinCondition(table_refs=tables, parent=None),
|
||||||
)
|
Alias(aliases=("abc", "bcd")),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -640,9 +607,10 @@ def test_on_suggests_aliases_right_side(sql):
|
||||||
def test_on_suggests_tables_and_join_conditions_right_side(sql):
|
def test_on_suggests_tables_and_join_conditions_right_side(sql):
|
||||||
suggestions = suggest_type(sql, sql)
|
suggestions = suggest_type(sql, sql)
|
||||||
tables = ((None, "abc", None, False), (None, "bcd", None, False))
|
tables = ((None, "abc", None, False), (None, "bcd", None, False))
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
|
JoinCondition(table_refs=tables, parent=None),
|
||||||
)
|
Alias(aliases=("abc", "bcd")),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -659,9 +627,9 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql):
|
||||||
)
|
)
|
||||||
def test_join_using_suggests_common_columns(text):
|
def test_join_using_suggests_common_columns(text):
|
||||||
tables = ((None, "abc", None, False), (None, "def", None, False))
|
tables = ((None, "abc", None, False), (None, "def", None, False))
|
||||||
assert set(suggest_type(text, text)) == set(
|
assert set(suggest_type(text, text)) == {
|
||||||
[Column(table_refs=tables, require_last_table=True)]
|
Column(table_refs=tables, require_last_table=True)
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_suggest_columns_after_multiple_joins():
|
def test_suggest_columns_after_multiple_joins():
|
||||||
|
@ -678,29 +646,27 @@ def test_2_statements_2nd_current():
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"select * from a; select * from ", "select * from a; select * from "
|
"select * from a; select * from ", "select * from a; select * from "
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"select * from a; select from b", "select * from a; select "
|
"select * from a; select from b", "select * from a; select "
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "b", None, False),), qualifiable=True),
|
Column(table_refs=((None, "b", None, False),), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("SELECT"),
|
Keyword("SELECT"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
# Should work even if first statement is invalid
|
# Should work even if first statement is invalid
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"select * from; select * from ", "select * from; select * from "
|
"select * from; select * from ", "select * from; select * from "
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
|
|
||||||
def test_2_statements_1st_current():
|
def test_2_statements_1st_current():
|
||||||
suggestions = suggest_type("select * from ; select * from b", "select * from ")
|
suggestions = suggest_type("select * from ; select * from b", "select * from ")
|
||||||
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
suggestions = suggest_type("select from a; select * from b", "select ")
|
suggestions = suggest_type("select from a; select * from b", "select ")
|
||||||
assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
|
assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
|
||||||
|
@ -711,7 +677,7 @@ def test_3_statements_2nd_current():
|
||||||
"select * from a; select * from ; select * from c",
|
"select * from a; select * from ; select * from c",
|
||||||
"select * from a; select * from ",
|
"select * from a; select * from ",
|
||||||
)
|
)
|
||||||
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
|
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
|
||||||
|
|
||||||
suggestions = suggest_type(
|
suggestions = suggest_type(
|
||||||
"select * from a; select from b; select * from c", "select * from a; select "
|
"select * from a; select from b; select * from c", "select * from a; select "
|
||||||
|
@ -768,13 +734,11 @@ SELECT * FROM qux;
|
||||||
)
|
)
|
||||||
def test_statements_in_function_body(text):
|
def test_statements_in_function_body(text):
|
||||||
suggestions = suggest_type(text, text[: text.find(" ") + 1])
|
suggestions = suggest_type(text, text[: text.find(" ") + 1])
|
||||||
assert set(suggestions) == set(
|
assert set(suggestions) == {
|
||||||
[
|
|
||||||
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
|
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
|
||||||
Function(schema=None),
|
Function(schema=None),
|
||||||
Keyword("SELECT"),
|
Keyword("SELECT"),
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
functions = [
|
functions = [
|
||||||
|
@ -799,13 +763,13 @@ SELECT 1 FROM foo;
|
||||||
@pytest.mark.parametrize("text", functions)
|
@pytest.mark.parametrize("text", functions)
|
||||||
def test_statements_with_cursor_after_function_body(text):
|
def test_statements_with_cursor_after_function_body(text):
|
||||||
suggestions = suggest_type(text, text[: text.find("; ") + 1])
|
suggestions = suggest_type(text, text[: text.find("; ") + 1])
|
||||||
assert set(suggestions) == set([Keyword(), Special()])
|
assert set(suggestions) == {Keyword(), Special()}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text", functions)
|
@pytest.mark.parametrize("text", functions)
|
||||||
def test_statements_with_cursor_before_function_body(text):
|
def test_statements_with_cursor_before_function_body(text):
|
||||||
suggestions = suggest_type(text, "")
|
suggestions = suggest_type(text, "")
|
||||||
assert set(suggestions) == set([Keyword(), Special()])
|
assert set(suggestions) == {Keyword(), Special()}
|
||||||
|
|
||||||
|
|
||||||
def test_create_db_with_template():
|
def test_create_db_with_template():
|
||||||
|
@ -813,14 +777,14 @@ def test_create_db_with_template():
|
||||||
"create database foo with template ", "create database foo with template "
|
"create database foo with template ", "create database foo with template "
|
||||||
)
|
)
|
||||||
|
|
||||||
assert set(suggestions) == set((Database(),))
|
assert set(suggestions) == {Database()}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
|
@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
|
||||||
def test_specials_included_for_initial_completion(initial_text):
|
def test_specials_included_for_initial_completion(initial_text):
|
||||||
suggestions = suggest_type(initial_text, initial_text)
|
suggestions = suggest_type(initial_text, initial_text)
|
||||||
|
|
||||||
assert set(suggestions) == set([Keyword(), Special()])
|
assert set(suggestions) == {Keyword(), Special()}
|
||||||
|
|
||||||
|
|
||||||
def test_drop_schema_qualified_table_suggests_only_tables():
|
def test_drop_schema_qualified_table_suggests_only_tables():
|
||||||
|
@ -843,25 +807,30 @@ def test_drop_schema_suggests_schemas():
|
||||||
|
|
||||||
@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
|
@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
|
||||||
def test_cast_operator_suggests_types(text):
|
def test_cast_operator_suggests_types(text):
|
||||||
assert set(suggest_type(text, text)) == set(
|
assert set(suggest_type(text, text)) == {
|
||||||
[Datatype(schema=None), Table(schema=None), Schema()]
|
Datatype(schema=None),
|
||||||
)
|
Table(schema=None),
|
||||||
|
Schema(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
|
"text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
|
||||||
)
|
)
|
||||||
def test_cast_operator_suggests_schema_qualified_types(text):
|
def test_cast_operator_suggests_schema_qualified_types(text):
|
||||||
assert set(suggest_type(text, text)) == set(
|
assert set(suggest_type(text, text)) == {
|
||||||
[Datatype(schema="bar"), Table(schema="bar")]
|
Datatype(schema="bar"),
|
||||||
)
|
Table(schema="bar"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_alter_column_type_suggests_types():
|
def test_alter_column_type_suggests_types():
|
||||||
q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
|
q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
|
||||||
assert set(suggest_type(q, q)) == set(
|
assert set(suggest_type(q, q)) == {
|
||||||
[Datatype(schema=None), Table(schema=None), Schema()]
|
Datatype(schema=None),
|
||||||
)
|
Table(schema=None),
|
||||||
|
Schema(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -880,9 +849,11 @@ def test_alter_column_type_suggests_types():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_identifier_suggests_types_in_parentheses(text):
|
def test_identifier_suggests_types_in_parentheses(text):
|
||||||
assert set(suggest_type(text, text)) == set(
|
assert set(suggest_type(text, text)) == {
|
||||||
[Datatype(schema=None), Table(schema=None), Schema()]
|
Datatype(schema=None),
|
||||||
)
|
Table(schema=None),
|
||||||
|
Schema(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -977,7 +948,7 @@ def test_ignore_leading_double_quotes(sql):
|
||||||
)
|
)
|
||||||
def test_column_keyword_suggests_columns(sql):
|
def test_column_keyword_suggests_columns(sql):
|
||||||
suggestions = suggest_type(sql, sql)
|
suggestions = suggest_type(sql, sql)
|
||||||
assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))])
|
assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))}
|
||||||
|
|
||||||
|
|
||||||
def test_handle_unrecognized_kw_generously():
|
def test_handle_unrecognized_kw_generously():
|
||||||
|
|
|
@ -8,7 +8,7 @@ from os import getenv
|
||||||
POSTGRES_USER = getenv("PGUSER", "postgres")
|
POSTGRES_USER = getenv("PGUSER", "postgres")
|
||||||
POSTGRES_HOST = getenv("PGHOST", "localhost")
|
POSTGRES_HOST = getenv("PGHOST", "localhost")
|
||||||
POSTGRES_PORT = getenv("PGPORT", 5432)
|
POSTGRES_PORT = getenv("PGPORT", 5432)
|
||||||
POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
|
POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres")
|
||||||
|
|
||||||
|
|
||||||
def db_connection(dbname=None):
|
def db_connection(dbname=None):
|
||||||
|
@ -73,7 +73,7 @@ def drop_tables(conn):
|
||||||
def run(
|
def run(
|
||||||
executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
|
executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
|
||||||
):
|
):
|
||||||
" Return string output for the sql to be run "
|
"Return string output for the sql to be run"
|
||||||
|
|
||||||
results = executor.run(sql, pgspecial, exception_formatter)
|
results = executor.run(sql, pgspecial, exception_formatter)
|
||||||
formatted = []
|
formatted = []
|
||||||
|
@ -89,7 +89,7 @@ def run(
|
||||||
|
|
||||||
|
|
||||||
def completions_to_set(completions):
|
def completions_to_set(completions):
|
||||||
return set(
|
return {
|
||||||
(completion.display_text, completion.display_meta_text)
|
(completion.display_text, completion.display_meta_text)
|
||||||
for completion in completions
|
for completion in completions
|
||||||
)
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue