1
0
Fork 0

Adding upstream version 3.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 19:48:22 +01:00
parent f2184ff4ed
commit ec5391b244
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
104 changed files with 15144 additions and 0 deletions

3
.coveragerc Normal file
View file

@ -0,0 +1,3 @@
[run]
parallel=True
source=pgcli

15
.editorconfig Normal file
View file

@ -0,0 +1,15 @@
# editorconfig.org
# Get your text editor plugin at:
# http://editorconfig.org/#download
root = true
[*]
charset = utf-8
end_of_line = lf
indent_size = 4
indent_style = space
insert_final_newline = true
trim_trailing_whitespace = true
[travis.yml]
indent_size = 2

0
.git-blame-ignore-revs Normal file
View file

9
.github/ISSUE_TEMPLATE.md vendored Normal file
View file

@ -0,0 +1,9 @@
## Description
<!--- Describe your problem as fully as you can. -->
## Your environment
<!-- This gives us some more context to work with. -->
- [ ] Please provide your OS and version information.
- [ ] Please provide your CLI version.
- [ ] What is the output of ``pip freeze`` command.

12
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View file

@ -0,0 +1,12 @@
## Description
<!--- Describe your changes in detail. -->
## Checklist
<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. -->
- [ ] I've added this contribution to the `changelog.rst`.
- [ ] I've added my name to the `AUTHORS` file (or it's already there).
<!-- We would appreciate if you comply with our code style guidelines. -->
- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code.
- [x] Please squash merge this pull request (uncheck if you'd like us to merge as multiple commits)

71
.gitignore vendored Normal file
View file

@ -0,0 +1,71 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.Python
env/
pyvenv/
build/
develop-eggs/
dist/
downloads/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
.pytest_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# PyCharm
.idea/
*.iml
# Vagrant
.vagrant/
# Generated Packages
*.deb
*.rpm
.vscode/
venv/

7
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,7 @@
repos:
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
language_version: python3.7

51
.travis.yml Normal file
View file

@ -0,0 +1,51 @@
dist: xenial
sudo: required
language: python
python:
- "3.6"
- "3.7"
- "3.8"
- "3.9-dev"
before_install:
- which python
- which pip
- pip install -U setuptools
install:
- pip install --no-cache-dir .
- pip install -r requirements-dev.txt
- pip install keyrings.alt>=3.1
script:
- set -e
- coverage run --source pgcli -m py.test
- cd tests
- behave --no-capture
- cd ..
# check for changelog ReST compliance
- rst2html.py --halt=warning changelog.rst >/dev/null
# check for black code compliance, 3.6 only
- if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi
- set +e
after_success:
- coverage combine
- codecov
notifications:
webhooks:
urls:
- YOUR_WEBHOOK_URL
on_success: change # options: [always|never|change] default: always
on_failure: always # options: [always|never|change] default: always
on_start: false # default: false
services:
- postgresql
addons:
postgresql: "9.6"

120
AUTHORS Normal file
View file

@ -0,0 +1,120 @@
Many thanks to the following contributors.
Project Lead:
-------------
* Irina Truong
Core Devs:
----------
* Amjith Ramanujam
* Darik Gamble
* Stuart Quin
* Joakim Koljonen
* Daniel Rocco
* Karl-Aksel Puulmann
* Dick Marinus
Contributors:
-------------
* Brett
* Étienne BERSAC (bersace)
* Daniel Schwarz
* inkn
* Jonathan Slenders
* xalley
* TamasNo1
* François Pietka
* Michael Kaminsky
* Alexander Kukushkin
* Ludovic Gasc (GMLudo)
* Marc Abramowitz
* Nick Hahner
* Jay Zeng
* Dimitar Roustchev
* Dhaivat Pandit
* Matheus Rosa
* Ali Kargın
* Nathan Jhaveri
* David Celis
* Sven-Hendrik Haase
* Çağatay Yüksel
* Tiago Ribeiro
* Vignesh Anand
* Charlie Arnold
* dwalmsley
* Artur Dryomov
* rrampage
* while0pass
* Eric Workman
* xa
* Hans Roman
* Guewen Baconnier
* Dionysis Grigoropoulos
* Jacob Magnusson
* Johannes Hoff
* vinotheassassin
* Jacek Wielemborek
* Fabien Meghazi
* Manuel Barkhau
* Sergii V
* Emanuele Gaifas
* Owen Stephens
* Russell Davies
* AlexTes
* Hraban Luyat
* Jackson Popkin
* Gustavo Castro
* Alexander Schmolck
* Donnell Muse
* Andrew Speed
* Dmitry B
* Isank
* Marcin Sztolcman
* Bojan Delić
* Chris Vaughn
* Frederic Aoustin
* Pierre Giraud
* Andrew Kuchling
* Dan Clark
* Catherine Devlin
* Jason Ribeiro
* Rishi Ramraj
* Matthieu Guilbert
* Alexandr Korsak
* Saif Hakim
* Artur Balabanov
* Kenny Do
* Max Rothman
* Daniel Egger
* Ignacio Campabadal
* Mikhail Elovskikh (wronglink)
* Marcin Cieślak (saper)
* easteregg (verfriemelt-dot-org)
* Scott Brenstuhl (808sAndBR)
* Nathan Verzemnieks
* raylu
* Zhaolong Zhu
* Zane C. Bowers-Hadley
* Telmo "Trooper" (telmotrooper)
* Alexander Zawadzki
* Pablo A. Bianchi (pabloab)
* Sebastian Janko (sebojanko)
* Pedro Ferrari (petobens)
* Martin Matejek (mmtj)
* Jonas Jelten
* BrownShibaDog
* George Thomas(thegeorgeous)
* Yoni Nakache(lazydba247)
* Gantsev Denis
* Stephano Paraskeva
* Panos Mavrogiorgos (pmav99)
* Igor Kim (igorkim)
* Anthony DeBarros (anthonydb)
* Seungyong Kwak (GUIEEN)
* Tom Caruso (tomplex)
* Jan Brun Rasmussen (janbrunrasmussen)
* Kevin Marsh (kevinmarsh)
Creator:
--------
Amjith Ramanujam

178
DEVELOP.rst Normal file
View file

@ -0,0 +1,178 @@
Development Guide
-----------------
This is a guide for developers who would like to contribute to this project.
GitHub Workflow
---------------
If you're interested in contributing to pgcli, first of all my heart felt
thanks. `Fork the project <https://github.com/dbcli/pgcli>`_ on github. Then
clone your fork into your computer (``git clone <url-for-your-fork>``). Make
the changes and create the commits in your local machine. Then push those
changes to your fork. Then click on the pull request icon on github and create
a new pull request. Add a description about the change and send it along. I
promise to review the pull request in a reasonable window of time and get back
to you.
In order to keep your fork up to date with any changes from mainline, add a new
git remote to your local copy called 'upstream' and point it to the main pgcli
repo.
::
$ git remote add upstream git@github.com:dbcli/pgcli.git
Once the 'upstream' end point is added you can then periodically do a ``git
pull upstream master`` to update your local copy and then do a ``git push
origin master`` to keep your own fork up to date.
Check Github's `Understanding the GitHub flow guide
<https://guides.github.com/introduction/flow/>`_ for a more detailed
explanation of this process.
Local Setup
-----------
The installation instructions in the README file are intended for users of
pgcli. If you're developing pgcli, you'll need to install it in a slightly
different way so you can see the effects of your changes right away without
having to go through the install cycle every time you change the code.
It is highly recommended to use virtualenv for development. If you don't know
what a virtualenv is, `this guide <http://docs.python-guide.org/en/latest/dev/virtualenvs/#virtual-environments>`_
will help you get started.
Create a virtualenv (let's call it pgcli-dev). Activate it:
::
source ./pgcli-dev/bin/activate
Once the virtualenv is activated, `cd` into the local clone of pgcli folder
and install pgcli using pip as follows:
::
$ pip install --editable .
or
$ pip install -e .
This will install the necessary dependencies as well as install pgcli from the
working folder into the virtualenv. By installing it using `pip install -e`
we've linked the pgcli installation with the working copy. Any changes made
to the code are immediately available in the installed version of pgcli. This
makes it easy to change something in the code, launch pgcli and check the
effects of your changes.
Adding PostgreSQL Special (Meta) Commands
-----------------------------------------
If you want to work on adding new meta-commands (such as `\dp`, `\ds`, `dy`),
you need to contribute to `pgspecial <https://github.com/dbcli/pgspecial/>`_
project.
Building RPM and DEB packages
-----------------------------
You will need Vagrant 1.7.2 or higher. In the project root there is a
Vagrantfile that is setup to do multi-vm provisioning. If you're setting things
up for the first time, then do:
::
$ version=x.y.z vagrant up debian
$ version=x.y.z vagrant up centos
If you already have those VMs setup and you're merely creating a new version of
DEB or RPM package, then you can do:
::
$ version=x.y.z vagrant provision
That will create a .deb file and a .rpm file.
The deb package can be installed as follows:
::
$ sudo dpkg -i pgcli*.deb # if dependencies are available.
or
$ sudo apt-get install -f pgcli*.deb # if dependencies are not available.
The rpm package can be installed as follows:
::
$ sudo yum install pgcli*.rpm
Running the integration tests
-----------------------------
Integration tests use `behave package <https://behave.readthedocs.io/>`_ and
pytest.
Configuration settings for this package are provided via a ``behave.ini`` file
in the ``tests`` directory. An example::
[behave]
stderr_capture = false
[behave.userdata]
pg_test_user = dbuser
pg_test_host = db.example.com
pg_test_port = 30000
First, install the requirements for testing:
::
$ pip install -r requirements-dev.txt
Ensure that the database user has permissions to create and drop test databases
by checking your ``pg_hba.conf`` file. The default user should be ``postgres``
at ``localhost``. Make sure the authentication method is set to ``trust``. If
you made any changes to your ``pg_hba.conf`` make sure to restart the postgres
service for the changes to take effect.
::
# ONLY IF YOU MADE CHANGES TO YOUR pg_hba.conf FILE
$ sudo service postgresql restart
After that, tests in the ``/pgcli/tests`` directory can be run with:
::
# on directory /pgcli/tests
$ behave
And on the ``/pgcli`` directory:
::
# on directory /pgcli
$ py.test
To see stdout/stderr, use the following command:
::
$ behave --no-capture
Troubleshooting the integration tests
-------------------------------------
- Make sure postgres instance on localhost is running
- Check your ``pg_hba.conf`` file to verify local connections are enabled
- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
- Contact us on `gitter <https://gitter.im/dbcli/pgcli/>`_ or `file an issue <https://github.com/dbcli/pgcli/issues/new>`_.
Coding Style
------------
``pgcli`` uses `black <https://github.com/ambv/black>`_ to format the source code. Make sure to install black.

6
Dockerfile Normal file
View file

@ -0,0 +1,6 @@
FROM python:3.8
COPY . /app
RUN cd /app && pip install -e .
CMD pgcli

26
LICENSE.txt Normal file
View file

@ -0,0 +1,26 @@
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.
* Neither the name of the {organization} nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

2
MANIFEST.in Normal file
View file

@ -0,0 +1,2 @@
include LICENSE.txt AUTHORS changelog.rst
recursive-include tests *.py *.txt *.feature *.ini

372
README.rst Normal file
View file

@ -0,0 +1,372 @@
A REPL for Postgres
-------------------
|Build Status| |CodeCov| |PyPI| |Landscape| |Gitter|
This is a postgres client that does auto-completion and syntax highlighting.
Home Page: http://pgcli.com
MySQL Equivalent: http://mycli.net
.. image:: screenshots/pgcli.gif
.. image:: screenshots/image01.png
Quick Start
-----------
If you already know how to install python packages, then you can simply do:
::
$ pip install -U pgcli
or
$ sudo apt-get install pgcli # Only on Debian based Linux (e.g. Ubuntu, Mint, etc)
$ brew install pgcli # Only on macOS
If you don't know how to install python packages, please check the
`detailed instructions`_.
If you are restricted to using psycopg2 2.7.x then pip will try to install it from a binary. There are some known issues with the psycopg2 2.7 binary - see the `psycopg docs`_ for more information about this and how to force installation from source. psycopg2 2.8 has fixed these problems, and will build from source.
.. _`detailed instructions`: https://github.com/dbcli/pgcli#detailed-installation-instructions
.. _`psycopg docs`: http://initd.org/psycopg/docs/install.html#change-in-binary-packages-between-psycopg-2-7-and-2-8
Usage
-----
::
$ pgcli [database_name]
or
$ pgcli postgresql://[user[:password]@][netloc][:port][/dbname][?extra=value[&other=other-value]]
Examples:
::
$ pgcli local_database
$ pgcli postgres://amjith:pa$$w0rd@example.com:5432/app_db?sslmode=verify-ca&sslrootcert=/myrootcert
For more details:
::
$ pgcli --help
Usage: pgcli [OPTIONS] [DBNAME] [USERNAME]
Options:
-h, --host TEXT Host address of the postgres database.
-p, --port INTEGER Port number at which the postgres instance is
listening.
-U, --username TEXT Username to connect to the postgres database.
-u, --user TEXT Username to connect to the postgres database.
-W, --password Force password prompt.
-w, --no-password Never prompt for password.
--single-connection Do not use a separate connection for completions.
-v, --version Version of pgcli.
-d, --dbname TEXT database name to connect to.
--pgclirc PATH Location of pgclirc file.
-D, --dsn TEXT Use DSN configured into the [alias_dsn] section of
pgclirc file.
--list-dsn list of DSN configured into the [alias_dsn] section
of pgclirc file.
--row-limit INTEGER Set threshold for row limit prompt. Use 0 to disable
prompt.
--less-chatty Skip intro on startup and goodbye on exit.
--prompt TEXT Prompt format (Default: "\u@\h:\d> ").
--prompt-dsn TEXT Prompt format for connections using DSN aliases
(Default: "\u@\h:\d> ").
-l, --list list available databases, then exit.
--auto-vertical-output Automatically switch to vertical output mode if the
result is wider than the terminal width.
--warn / --no-warn Warn before running a destructive query.
--help Show this message and exit.
``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``).
The SSL-related environment variables are also supported, so if you need to connect a postgres database via ssl connection, you can set set environment like this:
::
export PGSSLMODE="verify-full"
export PGSSLCERT="/your-path-to-certs/client.crt"
export PGSSLKEY="/your-path-to-keys/client.key"
export PGSSLROOTCERT="/your-path-to-ca/ca.crt"
pgcli -h localhost -p 5432 -U username postgres
.. _environment variables: https://www.postgresql.org/docs/current/libpq-envars.html
Features
--------
The `pgcli` is written using prompt_toolkit_.
* Auto-completes as you type for SQL keywords as well as tables and
columns in the database.
* Syntax highlighting using Pygments.
* Smart-completion (enabled by default) will suggest context-sensitive
completion.
- ``SELECT * FROM <tab>`` will only show table names.
- ``SELECT * FROM users WHERE <tab>`` will only show column names.
* Primitive support for ``psql`` back-slash commands.
* Pretty prints tabular data.
.. _prompt_toolkit: https://github.com/jonathanslenders/python-prompt-toolkit
.. _tabulate: https://pypi.python.org/pypi/tabulate
Config
------
A config file is automatically created at ``~/.config/pgcli/config`` at first launch.
See the file itself for a description of all available options.
Contributions:
--------------
If you're interested in contributing to this project, first of all I would like
to extend my heartfelt gratitude. I've written a small doc to describe how to
get this running in a development setup.
https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst
Please feel free to reach out to me if you need help.
My email: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_
Detailed Installation Instructions:
-----------------------------------
macOS:
======
The easiest way to install pgcli is using Homebrew.
::
$ brew install pgcli
Done!
Alternatively, you can install ``pgcli`` as a python package using a package
manager called called ``pip``. You will need postgres installed on your system
for this to work.
In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html.
::
$ which pip
If it is installed then you can do:
::
$ pip install pgcli
If that fails due to permission issues, you might need to run the command with
sudo permissions.
::
$ sudo pip install pgcli
If pip is not installed check if easy_install is available on the system.
::
$ which easy_install
$ sudo easy_install pgcli
Linux:
======
In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html.
Check if pip is already available in your system.
::
$ which pip
If it doesn't exist, use your linux package manager to install `pip`. This
might look something like:
::
$ sudo apt-get install python-pip # Debian, Ubuntu, Mint etc
or
$ sudo yum install python-pip # RHEL, Centos, Fedora etc
``pgcli`` requires python-dev, libpq-dev and libevent-dev packages. You can
install these via your operating system package manager.
::
$ sudo apt-get install python-dev libpq-dev libevent-dev
or
$ sudo yum install python-devel postgresql-devel
Then you can install pgcli:
::
$ sudo pip install pgcli
Docker
======
Pgcli can be run from within Docker. This can be useful to try pgcli without
installing it, or any dependencies, system-wide.
To build the image:
::
$ docker build -t pgcli .
To create a container from the image:
::
$ docker run --rm -ti pgcli pgcli <ARGS>
To access postgresql databases listening on localhost, make sure to run the
docker in "host net mode". E.g. to access a database called "foo" on the
postgresql server running on localhost:5432 (the standard port):
::
$ docker run --rm -ti --net host pgcli pgcli -h localhost foo
To connect to a locally running instance over a unix socket, bind the socket to
the docker container:
::
$ docker run --rm -ti -v /var/run/postgres:/var/run/postgres pgcli pgcli foo
IPython
=======
Pgcli can be run from within `IPython <https://ipython.org>`_ console. When working on a query,
it may be useful to drop into a pgcli session without leaving the IPython console, iterate on a
query, then quit pgcli to find the query results in your IPython workspace.
Assuming you have IPython installed:
::
$ pip install ipython-sql
After that, run ipython and load the ``pgcli.magic`` extension:
::
$ ipython
In [1]: %load_ext pgcli.magic
Connect to a database and construct a query:
::
In [2]: %pgcli postgres://someone@localhost:5432/world
Connected: someone@world
someone@localhost:world> select * from city c where countrycode = 'USA' and population > 1000000;
+------+--------------+---------------+--------------+--------------+
| id | name | countrycode | district | population |
|------+--------------+---------------+--------------+--------------|
| 3793 | New York | USA | New York | 8008278 |
| 3794 | Los Angeles | USA | California | 3694820 |
| 3795 | Chicago | USA | Illinois | 2896016 |
| 3796 | Houston | USA | Texas | 1953631 |
| 3797 | Philadelphia | USA | Pennsylvania | 1517550 |
| 3798 | Phoenix | USA | Arizona | 1321045 |
| 3799 | San Diego | USA | California | 1223400 |
| 3800 | Dallas | USA | Texas | 1188580 |
| 3801 | San Antonio | USA | Texas | 1144646 |
+------+--------------+---------------+--------------+--------------+
SELECT 9
Time: 0.003s
Exit out of pgcli session with ``Ctrl + D`` and find the query results:
::
someone@localhost:world>
Goodbye!
9 rows affected.
Out[2]:
[(3793, u'New York', u'USA', u'New York', 8008278),
(3794, u'Los Angeles', u'USA', u'California', 3694820),
(3795, u'Chicago', u'USA', u'Illinois', 2896016),
(3796, u'Houston', u'USA', u'Texas', 1953631),
(3797, u'Philadelphia', u'USA', u'Pennsylvania', 1517550),
(3798, u'Phoenix', u'USA', u'Arizona', 1321045),
(3799, u'San Diego', u'USA', u'California', 1223400),
(3800, u'Dallas', u'USA', u'Texas', 1188580),
(3801, u'San Antonio', u'USA', u'Texas', 1144646)]
The results are available in special local variable ``_``, and can be assigned to a variable of your
choice:
::
In [3]: my_result = _
Pgcli only runs on Python3.6+ since 2.2.0, if you use an old version of Python,
you should use install ``pgcli <= 2.2.0``.
Thanks:
-------
A special thanks to `Jonathan Slenders <https://twitter.com/jonathan_s>`_ for
creating `Python Prompt Toolkit <http://github.com/jonathanslenders/python-prompt-toolkit>`_,
which is quite literally the backbone library, that made this app possible.
Jonathan has also provided valuable feedback and support during the development
of this app.
`Click <http://click.pocoo.org/>`_ is used for command line option parsing
and printing error messages.
Thanks to `psycopg <http://initd.org/psycopg/>`_ for providing a rock solid
interface to Postgres database.
Thanks to all the beta testers and contributors for your time and patience. :)
.. |Build Status| image:: https://api.travis-ci.org/dbcli/pgcli.svg?branch=master
:target: https://travis-ci.org/dbcli/pgcli
.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
:target: https://codecov.io/gh/dbcli/pgcli
:alt: Code coverage report
.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/master/landscape.svg?style=flat
:target: https://landscape.io/github/dbcli/pgcli/master
:alt: Code Health
.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
:target: https://pypi.python.org/pypi/pgcli/
:alt: Latest Version
.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg
:target: https://gitter.im/dbcli/pgcli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
:alt: Gitter Chat

12
TODO Normal file
View file

@ -0,0 +1,12 @@
# vi: ft=vimwiki
* [ ] Add coverage.
* [ ] Refactor to sqlcompletion to consume the text from left to right and use a state machine to suggest cols or tables instead of relying on hacks.
* [ ] Add a few more special commands. (\l pattern, \dp, \ds, \dy, \z etc)
* [ ] Refactor pgspecial.py to a class.
* [ ] Show/hide docs for a statement using a keybinding.
* [ ] Check how to add the name of the table before printing the table.
* [ ] Add a new trigger for M-/ that does naive completion.
* [ ] New Feature List - Write the current version to config file. At launch if the version has changed, display the changelog between the two versions.
* [ ] Add a test for 'select * from custom.abc where custom.abc.' should suggest columns from abc.
* [ ] pgexecute columns(), tables() etc can be just cursors instead of fetchall()
* [ ] Add colorschemes in config file.

93
Vagrantfile vendored Normal file
View file

@ -0,0 +1,93 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure(2) do |config|
config.vm.synced_folder ".", "/pgcli"
pgcli_version = ENV['version']
pgcli_description = "Postgres CLI with autocompletion and syntax highlighting"
config.vm.define "debian" do |debian|
debian.vm.box = "chef/debian-7.8"
debian.vm.provision "shell", inline: <<-SHELL
echo "-> Building DEB on `lsb_release -s`"
sudo apt-get update
sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems
sudo easy_install pip
sudo pip install virtualenv virtualenv-tools
sudo gem install fpm
echo "-> Cleaning up old workspace"
rm -rf build
mkdir -p build/usr/share
virtualenv build/usr/share/pgcli
build/usr/share/pgcli/bin/pip install -U pip distribute
build/usr/share/pgcli/bin/pip uninstall -y distribute
build/usr/share/pgcli/bin/pip install /pgcli
echo "-> Cleaning Virtualenv"
cd build/usr/share/pgcli
virtualenv-tools --update-path /usr/share/pgcli > /dev/null
cd /home/vagrant/
echo "-> Removing compiled files"
find build -iname '*.pyc' -delete
find build -iname '*.pyo' -delete
echo "-> Creating PgCLI deb"
sudo fpm -t deb -s dir -C build -n pgcli -v #{pgcli_version} \
-a all \
-d libpq-dev \
-d python-dev \
-p /pgcli/ \
--after-install /pgcli/post-install \
--after-remove /pgcli/post-remove \
--url https://github.com/dbcli/pgcli \
--description "#{pgcli_description}" \
--license 'BSD'
SHELL
end
config.vm.define "centos" do |centos|
centos.vm.box = "chef/centos-7.0"
centos.vm.provision "shell", inline: <<-SHELL
#!/bin/bash
echo "-> Building RPM on `lsb_release -s`"
sudo yum install -y rpm-build gcc ruby-devel postgresql-devel python-devel rubygems
sudo easy_install pip
sudo pip install virtualenv virtualenv-tools
sudo gem install fpm
echo "-> Cleaning up old workspace"
rm -rf build
mkdir -p build/usr/share
virtualenv build/usr/share/pgcli
build/usr/share/pgcli/bin/pip install -U pip distribute
build/usr/share/pgcli/bin/pip uninstall -y distribute
build/usr/share/pgcli/bin/pip install /pgcli
echo "-> Cleaning Virtualenv"
cd build/usr/share/pgcli
virtualenv-tools --update-path /usr/share/pgcli > /dev/null
cd /home/vagrant/
echo "-> Removing compiled files"
find build -iname '*.pyc' -delete
find build -iname '*.pyo' -delete
echo "-> Creating PgCLI RPM"
echo $PATH
sudo /usr/local/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
-a all \
-d postgresql-devel \
-d python-devel \
-p /pgcli/ \
--after-install /pgcli/post-install \
--after-remove /pgcli/post-remove \
--url https://github.com/dbcli/pgcli \
--description "#{pgcli_description}" \
--license 'BSD'
SHELL
end
end

1064
changelog.rst Normal file

File diff suppressed because it is too large Load diff

61
pgcli-completion.bash Normal file
View file

@ -0,0 +1,61 @@
_pg_databases()
{
# -w was introduced in 8.4, https://launchpad.net/bugs/164772
# "Access privileges" in output may contain linefeeds, hence the NF > 1
COMPREPLY=( $( compgen -W "$( psql -AtqwlF $'\t' 2>/dev/null | \
awk 'NF > 1 { print $1 }' )" -- "$cur" ) )
}
_pg_users()
{
# -w was introduced in 8.4, https://launchpad.net/bugs/164772
COMPREPLY=( $( compgen -W "$( psql -Atqwc 'select usename from pg_user' \
template1 2>/dev/null )" -- "$cur" ) )
[[ ${#COMPREPLY[@]} -eq 0 ]] && COMPREPLY=( $( compgen -u -- "$cur" ) )
}
_pgcli()
{
local cur prev words cword
_init_completion -s || return
case $prev in
-h|--host)
_known_hosts_real "$cur"
return 0
;;
-U|--user)
_pg_users
return 0
;;
-d|--dbname)
_pg_databases
return 0
;;
--help|-v|--version|-p|--port|-R|--row-limit)
# all other arguments are noop with these
return 0
;;
esac
case "$cur" in
--*)
# return list of available options
COMPREPLY=( $( compgen -W '--host --port --user --password --no-password
--single-connection --version --dbname --pgclirc --dsn
--row-limit --help' -- "$cur" ) )
[[ $COMPREPLY == *= ]] && compopt -o nospace
return 0
;;
-)
# only complete long options
compopt -o nospace
COMPREPLY=( -- )
return 0
;;
*)
# return list of available databases
_pg_databases
esac
} &&
complete -F _pgcli pgcli

1
pgcli/__init__.py Normal file
View file

@ -0,0 +1 @@
__version__ = "3.1.0"

9
pgcli/__main__.py Normal file
View file

@ -0,0 +1,9 @@
"""
pgcli package main entry point
"""
from .main import cli
if __name__ == "__main__":
cli()

View file

@ -0,0 +1,150 @@
import threading
import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
from .pgexecute import PGExecute
class CompletionRefresher(object):
refreshers = OrderedDict()
def __init__(self):
self._completer_thread = None
self._restart_refresh = threading.Event()
def refresh(self, executor, special, callbacks, history=None, settings=None):
"""
Creates a PGCompleter object and populates it with the relevant
completion suggestions in a background thread.
executor - PGExecute object, used to extract the credentials to connect
to the database.
special - PGSpecial object used for creating a new completion object.
settings - dict of settings for completer object
callbacks - A function or a list of functions to call after the thread
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
else:
self._completer_thread = threading.Thread(
target=self._bg_refresh,
args=(executor, special, callbacks, history, settings),
name="completion_refresh",
)
self._completer_thread.setDaemon(True)
self._completer_thread.start()
return [
(None, None, None, "Auto-completion refresh started in the background.")
]
def is_refreshing(self):
return self._completer_thread and self._completer_thread.is_alive()
def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
settings = settings or {}
completer = PGCompleter(
smart_completion=True, pgspecial=special, settings=settings
)
if settings.get("single_connection"):
executor = pgexecute
else:
# Create a new pgexecute method to populate the completions.
executor = pgexecute.copy()
# If callbacks is a single function then push it into a list.
if callable(callbacks):
callbacks = [callbacks]
while 1:
for refresher in self.refreshers.values():
refresher(completer, executor)
if self._restart_refresh.is_set():
self._restart_refresh.clear()
break
else:
# Break out of while loop if the for loop finishes natually
# without hitting the break statement.
break
# Start over the refresh from the beginning if the for loop hit the
# break statement.
continue
# Load history into pgcompleter so it can learn user preferences
n_recent = 100
if history:
for recent in history.get_strings()[-n_recent:]:
completer.extend_query_history(recent, is_init=True)
for callback in callbacks:
callback(completer)
if not settings.get("single_connection") and executor.conn:
# close connection established with pgexecute.copy()
executor.conn.close()
def refresher(name, refreshers=CompletionRefresher.refreshers):
"""Decorator to populate the dictionary of refreshers with the current
function.
"""
def wrapper(wrapped):
refreshers[name] = wrapped
return wrapped
return wrapper
@refresher("schemata")
def refresh_schemata(completer, executor):
completer.set_search_path(executor.search_path())
completer.extend_schemata(executor.schemata())
@refresher("tables")
def refresh_tables(completer, executor):
completer.extend_relations(executor.tables(), kind="tables")
completer.extend_columns(executor.table_columns(), kind="tables")
completer.extend_foreignkeys(executor.foreignkeys())
@refresher("views")
def refresh_views(completer, executor):
completer.extend_relations(executor.views(), kind="views")
completer.extend_columns(executor.view_columns(), kind="views")
@refresher("types")
def refresh_types(completer, executor):
completer.extend_datatypes(executor.datatypes())
@refresher("databases")
def refresh_databases(completer, executor):
completer.extend_database_names(executor.databases())
@refresher("casing")
def refresh_casing(completer, executor):
casing_file = completer.casing_file
if not casing_file:
return
generate_casing_file = completer.generate_casing_file
if generate_casing_file and not os.path.isfile(casing_file):
casing_prefs = "\n".join(executor.casing())
with open(casing_file, "w") as f:
f.write(casing_prefs)
if os.path.isfile(casing_file):
with open(casing_file, "r") as f:
completer.extend_casing([line.strip() for line in f])
@refresher("functions")
def refresh_functions(completer, executor):
completer.extend_functions(executor.functions())

64
pgcli/config.py Normal file
View file

@ -0,0 +1,64 @@
import errno
import shutil
import os
import platform
from os.path import expanduser, exists, dirname
from configobj import ConfigObj
def config_location():
if "XDG_CONFIG_HOME" in os.environ:
return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
elif platform.system() == "Windows":
return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\"
else:
return expanduser("~/.config/pgcli/")
def load_config(usr_cfg, def_cfg=None):
cfg = ConfigObj()
cfg.merge(ConfigObj(def_cfg, interpolation=False))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
cfg.filename = expanduser(usr_cfg)
return cfg
def ensure_dir_exists(path):
parent_dir = expanduser(dirname(path))
os.makedirs(parent_dir, exist_ok=True)
def write_default_config(source, destination, overwrite=False):
destination = expanduser(destination)
if not overwrite and exists(destination):
return
ensure_dir_exists(destination)
shutil.copyfile(source, destination)
def upgrade_config(config, def_config):
cfg = load_config(config, def_config)
cfg.write()
def get_config(pgclirc_file=None):
from pgcli import __file__ as package_root
package_root = os.path.dirname(package_root)
pgclirc_file = pgclirc_file or "%sconfig" % config_location()
default_config = os.path.join(package_root, "pgclirc")
write_default_config(default_config, pgclirc_file)
return load_config(pgclirc_file, default_config)
def get_casing_file(config):
casing_file = config["main"]["casing_file"]
if casing_file == "default":
casing_file = config_location() + "casing"
return casing_file

127
pgcli/key_bindings.py Normal file
View file

@ -0,0 +1,127 @@
import logging
from prompt_toolkit.enums import EditingMode
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.filters import (
completion_is_selected,
is_searching,
has_completions,
has_selection,
vi_mode,
)
from .pgbuffer import buffer_should_be_handled
_logger = logging.getLogger(__name__)
def pgcli_bindings(pgcli):
"""Custom key bindings for pgcli."""
kb = KeyBindings()
tab_insert_text = " " * 4
@kb.add("f2")
def _(event):
"""Enable/Disable SmartCompletion Mode."""
_logger.debug("Detected F2 key.")
pgcli.completer.smart_completion = not pgcli.completer.smart_completion
@kb.add("f3")
def _(event):
"""Enable/Disable Multiline Mode."""
_logger.debug("Detected F3 key.")
pgcli.multi_line = not pgcli.multi_line
@kb.add("f4")
def _(event):
"""Toggle between Vi and Emacs mode."""
_logger.debug("Detected F4 key.")
pgcli.vi_mode = not pgcli.vi_mode
event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
@kb.add("tab")
def _(event):
"""Force autocompletion at cursor on non-empty lines."""
_logger.debug("Detected <Tab> key.")
buff = event.app.current_buffer
doc = buff.document
if doc.on_first_line or doc.current_line.strip():
if buff.complete_state:
buff.complete_next()
else:
buff.start_completion(select_first=True)
else:
buff.insert_text(tab_insert_text, fire_event=False)
@kb.add("escape", filter=has_completions)
def _(event):
"""Force closing of autocompletion."""
_logger.debug("Detected <Esc> key.")
event.current_buffer.complete_state = None
event.app.current_buffer.complete_state = None
@kb.add("c-space")
def _(event):
"""
Initialize autocompletion at cursor.
If the autocompletion menu is not showing, display it with the
appropriate completions for the context.
If the menu is showing, select the next completion.
"""
_logger.debug("Detected <C-Space> key.")
b = event.app.current_buffer
if b.complete_state:
b.complete_next()
else:
b.start_completion(select_first=False)
@kb.add("enter", filter=completion_is_selected)
def _(event):
"""Makes the enter key work as the tab key only when showing the menu.
In other words, don't execute query when enter is pressed in
the completion dropdown menu, instead close the dropdown menu
(accept current selection).
"""
_logger.debug("Detected enter key during completion selection.")
event.current_buffer.complete_state = None
event.app.current_buffer.complete_state = None
# When using multi_line input mode the buffer is not handled on Enter (a new line is
# inserted instead), so we force the handling if we're not in a completion or
# history search, and one of several conditions are True
@kb.add(
"enter",
filter=~(completion_is_selected | is_searching)
& buffer_should_be_handled(pgcli),
)
def _(event):
_logger.debug("Detected enter key.")
event.current_buffer.validate_and_handle()
@kb.add("escape", "enter", filter=~vi_mode)
def _(event):
"""Introduces a line break regardless of multi-line mode or not."""
_logger.debug("Detected alt-enter key.")
event.app.current_buffer.insert_text("\n")
@kb.add("c-p", filter=~has_selection)
def _(event):
"""Move up in history."""
event.current_buffer.history_backward(count=event.arg)
@kb.add("c-n", filter=~has_selection)
def _(event):
"""Move down in history."""
event.current_buffer.history_forward(count=event.arg)
return kb

67
pgcli/magic.py Normal file
View file

@ -0,0 +1,67 @@
from .main import PGCli
import sql.parse
import sql.connection
import logging
_logger = logging.getLogger(__name__)
def load_ipython_extension(ipython):
"""This is called via the ipython command '%load_ext pgcli.magic'"""
# first, load the sql magic if it isn't already loaded
if not ipython.find_line_magic("sql"):
ipython.run_line_magic("load_ext", "sql")
# register our own magic
ipython.register_magic_function(pgcli_line_magic, "line", "pgcli")
def pgcli_line_magic(line):
_logger.debug("pgcli magic called: %r", line)
parsed = sql.parse.parse(line, {})
# "get" was renamed to "set" in ipython-sql:
# https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
if hasattr(sql.connection.Connection, "get"):
conn = sql.connection.Connection.get(parsed["connection"])
else:
conn = sql.connection.Connection.set(parsed["connection"])
try:
# A corresponding pgcli object already exists
pgcli = conn._pgcli
_logger.debug("Reusing existing pgcli")
except AttributeError:
# I can't figure out how to get the underylying psycopg2 connection
# from the sqlalchemy connection, so just grab the url and make a
# new connection
pgcli = PGCli()
u = conn.session.engine.url
_logger.debug("New pgcli: %r", str(u))
pgcli.connect(u.database, u.host, u.username, u.port, u.password)
conn._pgcli = pgcli
# For convenience, print the connection alias
print("Connected: {}".format(conn.name))
try:
pgcli.run_cli()
except SystemExit:
pass
if not pgcli.query_history:
return
q = pgcli.query_history[-1]
if not q.successful:
_logger.debug("Unsuccessful query - ignoring")
return
if q.meta_changed or q.db_changed or q.path_changed:
_logger.debug("Dangerous query detected -- ignoring")
return
ipython = get_ipython()
return ipython.run_cell_magic("sql", line, q.query)

1516
pgcli/main.py Normal file

File diff suppressed because it is too large Load diff

View file

View file

@ -0,0 +1,22 @@
import sqlparse
def query_starts_with(query, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
def queries_start_with(queries, prefixes):
"""Check if any queries start with any item from *prefixes*."""
for query in sqlparse.split(queries):
if query and query_starts_with(query, prefixes) is True:
return True
return False
def is_destructive(queries):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
return queries_start_with(queries, keywords)

View file

@ -0,0 +1,141 @@
from sqlparse import parse
from sqlparse.tokens import Keyword, CTE, DML
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
from collections import namedtuple
from .meta import TableMetadata, ColumnMetadata
# TableExpression is a namedtuple representing a CTE, used internally
# name: cte alias assigned in the query
# columns: list of column names
# start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE
TableExpression = namedtuple("TableExpression", "name columns start stop")
def isolate_query_ctes(full_text, text_before_cursor):
"""Simplify a query by converting CTEs into table metadata objects"""
if not full_text or not full_text.strip():
return full_text, text_before_cursor, tuple()
ctes, remainder = extract_ctes(full_text)
if not ctes:
return full_text, text_before_cursor, ()
current_position = len(text_before_cursor)
meta = []
for cte in ctes:
if cte.start < current_position < cte.stop:
# Currently editing a cte - treat its body as the current full_text
text_before_cursor = full_text[cte.start : current_position]
full_text = full_text[cte.start : cte.stop]
return full_text, text_before_cursor, meta
# Append this cte to the list of available table metadata
cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
meta.append(TableMetadata(cte.name, cols))
# Editing past the last cte (ie the main body of the query)
full_text = full_text[ctes[-1].stop :]
text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]
return full_text, text_before_cursor, tuple(meta)
def extract_ctes(sql):
"""Extract constant table expresseions from a query
Returns tuple (ctes, remainder_sql)
ctes is a list of TableExpression namedtuples
remainder_sql is the text from the original query after the CTEs have
been stripped.
"""
p = parse(sql)[0]
# Make sure the first meaningful token is "WITH" which is necessary to
# define CTEs
idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
if not (tok and tok.ttype == CTE):
return [], sql
# Get the next (meaningful) token, which should be the first CTE
idx, tok = p.token_next(idx)
if not tok:
return ([], "")
start_pos = token_start_pos(p.tokens, idx)
ctes = []
if isinstance(tok, IdentifierList):
# Multiple ctes
for t in tok.get_identifiers():
cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
cte = get_cte_from_token(t, start_pos + cte_start_offset)
if not cte:
continue
ctes.append(cte)
elif isinstance(tok, Identifier):
# A single CTE
cte = get_cte_from_token(tok, start_pos)
if cte:
ctes.append(cte)
idx = p.token_index(tok) + 1
# Collapse everything after the ctes into a remainder query
remainder = "".join(str(tok) for tok in p.tokens[idx:])
return ctes, remainder
def get_cte_from_token(tok, pos0):
cte_name = tok.get_real_name()
if not cte_name:
return None
# Find the start position of the opening parens enclosing the cte body
idx, parens = tok.token_next_by(Parenthesis)
if not parens:
return None
start_pos = pos0 + token_start_pos(tok.tokens, idx)
cte_len = len(str(parens)) # includes parens
stop_pos = start_pos + cte_len
column_names = extract_column_names(parens)
return TableExpression(cte_name, column_names, start_pos, stop_pos)
def extract_column_names(parsed):
# Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
idx, tok = parsed.token_next_by(t=DML)
tok_val = tok and tok.value.lower()
if tok_val in ("insert", "update", "delete"):
# Jump ahead to the RETURNING clause where the list of column names is
idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
elif not tok_val == "select":
# Must be invalid CTE
return ()
# The next token should be either a column name, or a list of column names
idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
return tuple(t.get_name() for t in _identifiers(tok))
def token_start_pos(tokens, idx):
return sum(len(str(t)) for t in tokens[:idx])
def _identifiers(tok):
if isinstance(tok, IdentifierList):
for t in tok.get_identifiers():
# NB: IdentifierList.get_identifiers() can return non-identifiers!
if isinstance(t, Identifier):
yield t
elif isinstance(tok, Identifier):
yield tok

View file

@ -0,0 +1,170 @@
from collections import namedtuple
_ColumnMetadata = namedtuple(
"ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
)
def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default)
ForeignKey = namedtuple(
"ForeignKey",
[
"parentschema",
"parenttable",
"parentcolumn",
"childschema",
"childtable",
"childcolumn",
],
)
TableMetadata = namedtuple("TableMetadata", "name columns")
def parse_defaults(defaults_string):
"""Yields default values for a function, given the string provided by
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
if not defaults_string:
return
current = ""
in_quote = None
for char in defaults_string:
if current == "" and char == " ":
# Skip space after comma separating default expressions
continue
if char == '"' or char == "'":
if in_quote and char == in_quote:
# End quote
in_quote = None
elif not in_quote:
# Begin quote
in_quote = char
elif char == "," and not in_quote:
# End of expression
yield current
current = ""
continue
current += char
yield current
class FunctionMetadata(object):
def __init__(
self,
schema_name,
func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
):
"""Class for describing a postgresql function"""
self.schema_name = schema_name
self.func_name = func_name
self.arg_modes = tuple(arg_modes) if arg_modes else None
self.arg_names = tuple(arg_names) if arg_names else None
# Be flexible in not requiring arg_types -- use None as a placeholder
# for each arg. (Used for compatibility with old versions of postgresql
# where such info is hard to get.
if arg_types:
self.arg_types = tuple(arg_types)
elif arg_modes:
self.arg_types = tuple([None] * len(arg_modes))
elif arg_names:
self.arg_types = tuple([None] * len(arg_names))
else:
self.arg_types = None
self.arg_defaults = tuple(parse_defaults(arg_defaults))
self.return_type = return_type.strip()
self.is_aggregate = is_aggregate
self.is_window = is_window
self.is_set_returning = is_set_returning
self.is_extension = bool(is_extension)
self.is_public = self.schema_name and self.schema_name == "public"
def __eq__(self, other):
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
def _signature(self):
return (
self.schema_name,
self.func_name,
self.arg_names,
self.arg_types,
self.arg_modes,
self.return_type,
self.is_aggregate,
self.is_window,
self.is_set_returning,
self.is_extension,
self.arg_defaults,
)
def __hash__(self):
return hash(self._signature())
def __repr__(self):
return (
"%s(schema_name=%r, func_name=%r, arg_names=%r, "
"arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
"is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
) % ((self.__class__.__name__,) + self._signature())
def has_variadic(self):
return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)
def args(self):
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
if not self.arg_names:
return []
modes = self.arg_modes or ["i"] * len(self.arg_names)
args = [
(name, typ)
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
]
def arg(name, typ, num):
num_args = len(args)
num_defaults = len(self.arg_defaults)
has_default = num + num_defaults >= num_args
default = (
self.arg_defaults[num - num_args + num_defaults]
if has_default
else None
)
return ColumnMetadata(name, typ, [], default, has_default)
return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
def fields(self):
"""Returns a list of output-field ColumnMetadata namedtuples"""
if self.return_type.lower() == "void":
return []
elif not self.arg_modes:
# For functions without output parameters, the function name
# is used as the name of the output column.
# E.g. 'SELECT unnest FROM unnest(...);'
return [ColumnMetadata(self.func_name, self.return_type, [])]
return [
ColumnMetadata(name, typ, [])
for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
if mode in ("o", "b", "t")
] # OUT, INOUT, TABLE

View file

@ -0,0 +1,170 @@
import sqlparse
from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
TableReference = namedtuple(
"TableReference", ["schema", "name", "alias", "is_function"]
)
TableReference.ref = property(
lambda self: self.alias
or (
self.name
if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'
)
)
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
if not parsed.is_group:
return False
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in (
"SELECT",
"INSERT",
"UPDATE",
"CREATE",
"DELETE",
):
return True
return False
def _identifier_is_function(identifier):
return any(isinstance(t, Function) for t in identifier.tokens)
def extract_from_part(parsed, stop_at_punctuation=True):
tbl_prefix_seen = False
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
for x in extract_from_part(item, stop_at_punctuation):
yield x
elif stop_at_punctuation and item.ttype is Punctuation:
return
# An incomplete nested select won't be recognized correctly as a
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
# the second FROM to trigger this elif condition resulting in a
# `return`. So we need to ignore the keyword if the keyword
# FROM.
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
elif (
item.ttype is Keyword
and (not item.value.upper() == "FROM")
and (not item.value.upper().endswith("JOIN"))
):
tbl_prefix_seen = False
else:
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
if (
item_val
in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
)
or item_val.endswith("JOIN")
):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
tbl_prefix_seen = True
break
def extract_table_identifiers(token_stream, allow_functions=True):
"""yields tuples of TableReference namedtuples"""
# We need to do some massaging of the names because postgres is case-
# insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
def parse_identifier(item):
name = item.get_real_name()
schema_name = item.get_parent_name()
alias = item.get_alias()
if not name:
schema_name = None
name = item.get_name()
alias = alias or name
schema_quoted = schema_name and item.value[0] == '"'
if schema_name and not schema_quoted:
schema_name = schema_name.lower()
quote_count = item.value.count('"')
name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
alias_quoted = alias and item.value[-1] == '"'
if alias_quoted or name_quoted and not alias and name.islower():
alias = '"' + (alias or name) + '"'
if name and not name_quoted and not name.islower():
if not alias:
alias = name
name = name.lower()
return schema_name, name, alias
try:
for item in token_stream:
if isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
# Sometimes Keywords (such as FROM ) are classified as
# identifiers which don't have the get_real_name() method.
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
is_function = allow_functions and _identifier_is_function(
identifier
)
except AttributeError:
continue
if real_name:
yield TableReference(
schema_name, real_name, identifier.get_alias(), is_function
)
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
yield TableReference(schema_name, real_name, alias, is_function)
elif isinstance(item, Function):
schema_name, real_name, alias = parse_identifier(item)
yield TableReference(None, real_name, alias, allow_functions)
except StopIteration:
return
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
"""Extract the table names from an SQL statment.
Returns a list of TableReference namedtuples
"""
parsed = sqlparse.parse(sql)
if not parsed:
return ()
# INSERT statements must stop looking for tables at the sign of first
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == "insert"
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
# Kludge: sqlparse mistakenly identifies insert statements as
# function calls due to the parenthesized column list, e.g. interprets
# "insert into foo (bar, baz)" as a function call to foo with arguments
# (bar, baz). So don't allow any identifiers in insert statements
# to have is_function=True
identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
return tuple(i for i in identifiers if i.name)

View file

@ -0,0 +1,140 @@
import re
import sqlparse
from sqlparse.sql import Identifier
from sqlparse.tokens import Token, Error
cleanup_regex = {
# This matches only alphanumerics and underscores.
"alphanum_underscore": re.compile(r"(\w+)$"),
# This matches everything except spaces, parens, colon, and comma
"many_punctuations": re.compile(r"([^():,\s]+)$"),
# This matches everything except spaces, parens, colon, comma, and period
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
# This matches everything except a space.
"all_punctuations": re.compile(r"([^\s]+)$"),
}
def last_word(text, include="alphanum_underscore"):
r"""
Find the last word in a sentence.
>>> last_word('abc')
'abc'
>>> last_word(' abc')
'abc'
>>> last_word('')
''
>>> last_word(' ')
''
>>> last_word('abc ')
''
>>> last_word('abc def')
'def'
>>> last_word('abc def ')
''
>>> last_word('abc def;')
''
>>> last_word('bac $def')
'def'
>>> last_word('bac $def', include='most_punctuations')
'$def'
>>> last_word('bac \def', include='most_punctuations')
'\\\\def'
>>> last_word('bac \def;', include='most_punctuations')
'\\\\def;'
>>> last_word('bac::def', include='most_punctuations')
'def'
>>> last_word('"foo*bar', include='most_punctuations')
'"foo*bar'
"""
if not text: # Empty string
return ""
if text[-1].isspace():
return ""
else:
regex = cleanup_regex[include]
matches = regex.search(text)
if matches:
return matches.group(0)
else:
return ""
def find_prev_keyword(sql, n_skip=0):
"""Find the last sql keyword in an SQL statement
Returns the value of the last keyword, and the text of the query with
everything after the last keyword stripped
"""
if not sql.strip():
return None, ""
parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten())
flattened = flattened[: len(flattened) - n_skip]
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
for t in reversed(flattened):
if t.value == "(" or (
t.is_keyword and (t.value.upper() not in logical_operators)
):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index throws an error
# Minimal example:
# p = sqlparse.parse('select * from foo where bar')
# t = list(p.flatten())[-3] # The "Where" token
# p.token_index(t) # Throws ValueError: not in list
idx = flattened.index(t)
# Combine the string values of all tokens in the original list
# up to and including the target keyword token t, to produce a
# query string with everything after the keyword token removed
text = "".join(tok.value for tok in flattened[: idx + 1])
return t, text
return None, ""
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
def is_open_quote(sql):
"""Returns true if the query contains an unclosed quote"""
# parsed can contain one or more semi-colon separated commands
parsed = sqlparse.parse(sql)
return any(_parsed_is_open_quote(p) for p in parsed)
def _parsed_is_open_quote(parsed):
# Look for unmatched single quotes, or unmatched dollar sign quotes
return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
def parse_partial_identifier(word):
"""Attempt to parse a (partially typed) word as an identifier
word may include a schema qualification, like `schema_name.partial_name`
or `schema_name.` There may also be unclosed quotation marks, like
`"schema`, or `schema."partial_name`
:param word: string representing a (partially complete) identifier
:return: sqlparse.sql.Identifier, or None
"""
p = sqlparse.parse(word)[0]
n_tok = len(p.tokens)
if n_tok == 1 and isinstance(p.tokens[0], Identifier):
return p.tokens[0]
elif p.token_next_by(m=(Error, '"'))[1]:
# An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
# Close the double quote, then reparse
return parse_partial_identifier(word + '"')
else:
return None

View file

View file

@ -0,0 +1,15 @@
import os
import json
root = os.path.dirname(__file__)
literal_file = os.path.join(root, "pgliterals.json")
with open(literal_file) as f:
literals = json.load(f)
def get_literals(literal_type, type_=tuple):
# Where `literal_type` is one of 'keywords', 'functions', 'datatypes',
# returns a tuple of literal values of that type.
return type_(literals[literal_type])

View file

@ -0,0 +1,629 @@
{
"keywords": {
"ACCESS": [],
"ADD": [],
"ALL": [],
"ALTER": [
"AGGREGATE",
"COLLATION",
"COLUMN",
"CONVERSION",
"DATABASE",
"DEFAULT",
"DOMAIN",
"EVENT TRIGGER",
"EXTENSION",
"FOREIGN",
"FUNCTION",
"GROUP",
"INDEX",
"LANGUAGE",
"LARGE OBJECT",
"MATERIALIZED VIEW",
"OPERATOR",
"POLICY",
"ROLE",
"RULE",
"SCHEMA",
"SEQUENCE",
"SERVER",
"SYSTEM",
"TABLE",
"TABLESPACE",
"TEXT SEARCH",
"TRIGGER",
"TYPE",
"USER",
"VIEW"
],
"AND": [],
"ANY": [],
"AS": [],
"ASC": [],
"AUDIT": [],
"BEGIN": [],
"BETWEEN": [],
"BY": [],
"CASE": [],
"CHAR": [],
"CHECK": [],
"CLUSTER": [],
"COLUMN": [],
"COMMENT": [],
"COMMIT": [],
"COMPRESS": [],
"CONCURRENTLY": [],
"CONNECT": [],
"COPY": [],
"CREATE": [
"ACCESS METHOD",
"AGGREGATE",
"CAST",
"COLLATION",
"CONVERSION",
"DATABASE",
"DOMAIN",
"EVENT TRIGGER",
"EXTENSION",
"FOREIGN DATA WRAPPER",
"FOREIGN EXTENSION",
"FUNCTION",
"GLOBAL",
"GROUP",
"IF NOT EXISTS",
"INDEX",
"LANGUAGE",
"LOCAL",
"MATERIALIZED VIEW",
"OPERATOR",
"OR REPLACE",
"POLICY",
"ROLE",
"RULE",
"SCHEMA",
"SEQUENCE",
"SERVER",
"TABLE",
"TABLESPACE",
"TEMPORARY",
"TEXT SEARCH",
"TRIGGER",
"TYPE",
"UNIQUE",
"UNLOGGED",
"USER",
"USER MAPPING",
"VIEW"
],
"CURRENT": [],
"DATABASE": [],
"DATE": [],
"DECIMAL": [],
"DEFAULT": [],
"DELETE FROM": [],
"DELIMITER": [],
"DESC": [],
"DESCRIBE": [],
"DISTINCT": [],
"DROP": [
"ACCESS METHOD",
"AGGREGATE",
"CAST",
"COLLATION",
"COLUMN",
"CONVERSION",
"DATABASE",
"DOMAIN",
"EVENT TRIGGER",
"EXTENSION",
"FOREIGN DATA WRAPPER",
"FOREIGN TABLE",
"FUNCTION",
"GROUP",
"INDEX",
"LANGUAGE",
"MATERIALIZED VIEW",
"OPERATOR",
"OWNED",
"POLICY",
"ROLE",
"RULE",
"SCHEMA",
"SEQUENCE",
"SERVER",
"TABLE",
"TABLESPACE",
"TEXT SEARCH",
"TRANSFORM",
"TRIGGER",
"TYPE",
"USER",
"USER MAPPING",
"VIEW"
],
"EXPLAIN": [],
"ELSE": [],
"ENCODING": [],
"ESCAPE": [],
"EXCLUSIVE": [],
"EXISTS": [],
"EXTENSION": [],
"FILE": [],
"FLOAT": [],
"FOR": [],
"FORMAT": [],
"FORCE_QUOTE": [],
"FORCE_NOT_NULL": [],
"FREEZE": [],
"FROM": [],
"FULL": [],
"FUNCTION": [],
"GRANT": [],
"GROUP BY": [],
"HAVING": [],
"HEADER": [],
"IDENTIFIED": [],
"IMMEDIATE": [],
"IN": [],
"INCREMENT": [],
"INDEX": [],
"INITIAL": [],
"INSERT INTO": [],
"INTEGER": [],
"INTERSECT": [],
"INTERVAL": [],
"INTO": [],
"IS": [],
"JOIN": [],
"LANGUAGE": [],
"LEFT": [],
"LEVEL": [],
"LIKE": [],
"LIMIT": [],
"LOCK": [],
"LONG": [],
"MATERIALIZED VIEW": [],
"MAXEXTENTS": [],
"MINUS": [],
"MLSLABEL": [],
"MODE": [],
"MODIFY": [],
"NOT": [],
"NOAUDIT": [],
"NOTICE": [],
"NOCOMPRESS": [],
"NOWAIT": [],
"NULL": [],
"NUMBER": [],
"OIDS": [],
"OF": [],
"OFFLINE": [],
"ON": [],
"ONLINE": [],
"OPTION": [],
"OR": [],
"ORDER BY": [],
"OUTER": [],
"OWNER": [],
"PCTFREE": [],
"PRIMARY": [],
"PRIOR": [],
"PRIVILEGES": [],
"QUOTE": [],
"RAISE": [],
"RENAME": [],
"REPLACE": [],
"RESET": ["ALL"],
"RAW": [],
"REFRESH MATERIALIZED VIEW": [],
"RESOURCE": [],
"RETURNS": [],
"REVOKE": [],
"RIGHT": [],
"ROLLBACK": [],
"ROW": [],
"ROWID": [],
"ROWNUM": [],
"ROWS": [],
"SELECT": [],
"SESSION": [],
"SET": [],
"SHARE": [],
"SHOW": [],
"SIZE": [],
"SMALLINT": [],
"START": [],
"SUCCESSFUL": [],
"SYNONYM": [],
"SYSDATE": [],
"TABLE": [],
"TEMPLATE": [],
"THEN": [],
"TO": [],
"TRIGGER": [],
"TRUNCATE": [],
"UID": [],
"UNION": [],
"UNIQUE": [],
"UPDATE": [],
"USE": [],
"USER": [],
"USING": [],
"VALIDATE": [],
"VALUES": [],
"VARCHAR": [],
"VARCHAR2": [],
"VIEW": [],
"WHEN": [],
"WHENEVER": [],
"WHERE": [],
"WITH": []
},
"functions": [
"ABBREV",
"ABS",
"AGE",
"AREA",
"ARRAY_AGG",
"ARRAY_APPEND",
"ARRAY_CAT",
"ARRAY_DIMS",
"ARRAY_FILL",
"ARRAY_LENGTH",
"ARRAY_LOWER",
"ARRAY_NDIMS",
"ARRAY_POSITION",
"ARRAY_POSITIONS",
"ARRAY_PREPEND",
"ARRAY_REMOVE",
"ARRAY_REPLACE",
"ARRAY_TO_STRING",
"ARRAY_UPPER",
"ASCII",
"AVG",
"BIT_AND",
"BIT_LENGTH",
"BIT_OR",
"BOOL_AND",
"BOOL_OR",
"BOUND_BOX",
"BOX",
"BROADCAST",
"BTRIM",
"CARDINALITY",
"CBRT",
"CEIL",
"CEILING",
"CENTER",
"CHAR_LENGTH",
"CHR",
"CIRCLE",
"CLOCK_TIMESTAMP",
"CONCAT",
"CONCAT_WS",
"CONVERT",
"CONVERT_FROM",
"CONVERT_TO",
"COUNT",
"CUME_DIST",
"CURRENT_DATE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"DATE_PART",
"DATE_TRUNC",
"DECODE",
"DEGREES",
"DENSE_RANK",
"DIAMETER",
"DIV",
"ENCODE",
"ENUM_FIRST",
"ENUM_LAST",
"ENUM_RANGE",
"EVERY",
"EXP",
"EXTRACT",
"FAMILY",
"FIRST_VALUE",
"FLOOR",
"FORMAT",
"GET_BIT",
"GET_BYTE",
"HEIGHT",
"HOST",
"HOSTMASK",
"INET_MERGE",
"INET_SAME_FAMILY",
"INITCAP",
"ISCLOSED",
"ISFINITE",
"ISOPEN",
"JUSTIFY_DAYS",
"JUSTIFY_HOURS",
"JUSTIFY_INTERVAL",
"LAG",
"LAST_VALUE",
"LEAD",
"LEFT",
"LENGTH",
"LINE",
"LN",
"LOCALTIME",
"LOCALTIMESTAMP",
"LOG",
"LOG10",
"LOWER",
"LPAD",
"LSEG",
"LTRIM",
"MAKE_DATE",
"MAKE_INTERVAL",
"MAKE_TIME",
"MAKE_TIMESTAMP",
"MAKE_TIMESTAMPTZ",
"MASKLEN",
"MAX",
"MD5",
"MIN",
"MOD",
"NETMASK",
"NETWORK",
"NOW",
"NPOINTS",
"NTH_VALUE",
"NTILE",
"NUM_NONNULLS",
"NUM_NULLS",
"OCTET_LENGTH",
"OVERLAY",
"PARSE_IDENT",
"PATH",
"PCLOSE",
"PERCENT_RANK",
"PG_CLIENT_ENCODING",
"PI",
"POINT",
"POLYGON",
"POPEN",
"POSITION",
"POWER",
"QUOTE_IDENT",
"QUOTE_LITERAL",
"QUOTE_NULLABLE",
"RADIANS",
"RADIUS",
"RANK",
"REGEXP_MATCH",
"REGEXP_MATCHES",
"REGEXP_REPLACE",
"REGEXP_SPLIT_TO_ARRAY",
"REGEXP_SPLIT_TO_TABLE",
"REPEAT",
"REPLACE",
"REVERSE",
"RIGHT",
"ROUND",
"ROW_NUMBER",
"RPAD",
"RTRIM",
"SCALE",
"SET_BIT",
"SET_BYTE",
"SET_MASKLEN",
"SHA224",
"SHA256",
"SHA384",
"SHA512",
"SIGN",
"SPLIT_PART",
"SQRT",
"STARTS_WITH",
"STATEMENT_TIMESTAMP",
"STRING_TO_ARRAY",
"STRPOS",
"SUBSTR",
"SUBSTRING",
"SUM",
"TEXT",
"TIMEOFDAY",
"TO_ASCII",
"TO_CHAR",
"TO_DATE",
"TO_HEX",
"TO_NUMBER",
"TO_TIMESTAMP",
"TRANSACTION_TIMESTAMP",
"TRANSLATE",
"TRIM",
"TRUNC",
"UNNEST",
"UPPER",
"WIDTH",
"WIDTH_BUCKET",
"XMLAGG"
],
"datatypes": [
"ANY",
"ANYARRAY",
"ANYELEMENT",
"ANYENUM",
"ANYNONARRAY",
"ANYRANGE",
"BIGINT",
"BIGSERIAL",
"BIT",
"BIT VARYING",
"BOOL",
"BOOLEAN",
"BOX",
"BYTEA",
"CHAR",
"CHARACTER",
"CHARACTER VARYING",
"CIDR",
"CIRCLE",
"CSTRING",
"DATE",
"DECIMAL",
"DOUBLE PRECISION",
"EVENT_TRIGGER",
"FDW_HANDLER",
"FLOAT4",
"FLOAT8",
"INET",
"INT",
"INT2",
"INT4",
"INT8",
"INTEGER",
"INTERNAL",
"INTERVAL",
"JSON",
"JSONB",
"LANGUAGE_HANDLER",
"LINE",
"LSEG",
"MACADDR",
"MACADDR8",
"MONEY",
"NUMERIC",
"OID",
"OPAQUE",
"PATH",
"PG_LSN",
"POINT",
"POLYGON",
"REAL",
"RECORD",
"REGCLASS",
"REGCONFIG",
"REGDICTIONARY",
"REGNAMESPACE",
"REGOPER",
"REGOPERATOR",
"REGPROC",
"REGPROCEDURE",
"REGROLE",
"REGTYPE",
"SERIAL",
"SERIAL2",
"SERIAL4",
"SERIAL8",
"SMALLINT",
"SMALLSERIAL",
"TEXT",
"TIME",
"TIMESTAMP",
"TRIGGER",
"TSQUERY",
"TSVECTOR",
"TXID_SNAPSHOT",
"UUID",
"VARBIT",
"VARCHAR",
"VOID",
"XML"
],
"reserved": [
"ALL",
"ANALYSE",
"ANALYZE",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASYMMETRIC",
"BOTH",
"CASE",
"CAST",
"CHECK",
"COLLATE",
"COLUMN",
"CONSTRAINT",
"CREATE",
"CURRENT_CATALOG",
"CURRENT_DATE",
"CURRENT_ROLE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"DEFAULT",
"DEFERRABLE",
"DESC",
"DISTINCT",
"DO",
"ELSE",
"END",
"EXCEPT",
"FALSE",
"FETCH",
"FOR",
"FOREIGN",
"FROM",
"GRANT",
"GROUP",
"HAVING",
"IN",
"INITIALLY",
"INTERSECT",
"INTO",
"LATERAL",
"LEADING",
"LIMIT",
"LOCALTIME",
"LOCALTIMESTAMP",
"NOT",
"NULL",
"OFFSET",
"ON",
"ONLY",
"OR",
"ORDER",
"PLACING",
"PRIMARY",
"REFERENCES",
"RETURNING",
"SELECT",
"SESSION_USER",
"SOME",
"SYMMETRIC",
"TABLE",
"THEN",
"TO",
"TRAILING",
"TRUE",
"UNION",
"UNIQUE",
"USER",
"USING",
"VARIADIC",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"AUTHORIZATION",
"BINARY",
"COLLATION",
"CONCURRENTLY",
"CROSS",
"CURRENT_SCHEMA",
"FREEZE",
"FULL",
"ILIKE",
"INNER",
"IS",
"ISNULL",
"JOIN",
"LEFT",
"LIKE",
"NATURAL",
"NOTNULL",
"OUTER",
"OVERLAPS",
"RIGHT",
"SIMILAR",
"TABLESAMPLE",
"VERBOSE"
]
}

View file

@ -0,0 +1,51 @@
import re
import sqlparse
from sqlparse.tokens import Name
from collections import defaultdict
from .pgliterals.main import get_literals
white_space_regex = re.compile("\\s+", re.MULTILINE)
def _compile_regex(keyword):
# Surround the keyword with word boundaries and replace interior whitespace
# with whitespace wildcards
pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
keywords = get_literals("keywords")
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
class PrevalenceCounter(object):
def __init__(self):
self.keyword_counts = defaultdict(int)
self.name_counts = defaultdict(int)
def update(self, text):
self.update_keywords(text)
self.update_names(text)
def update_names(self, text):
for parsed in sqlparse.parse(text):
for token in parsed.flatten():
if token.ttype in Name:
self.name_counts[token.value] += 1
def clear_names(self):
self.name_counts = defaultdict(int)
def update_keywords(self, text):
# Count keywords. Can't rely for sqlparse for this, because it's
# database agnostic
for keyword, regex in keyword_regexs.items():
for _ in regex.finditer(text):
self.keyword_counts[keyword] += 1
def keyword_count(self, keyword):
return self.keyword_counts[keyword]
def name_count(self, name):
return self.name_counts[name]

View file

@ -0,0 +1,35 @@
import sys
import click
from .parseutils import is_destructive
def confirm_destructive_query(queries):
"""Check if the query is destructive and prompts the user to confirm.
Returns:
* None if the query is non-destructive or we can't prompt the user.
* True if the query is destructive and the user wants to proceed.
* False if the query is destructive and the user doesn't want to proceed.
"""
prompt_text = (
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
)
if is_destructive(queries) and sys.stdin.isatty():
return prompt(prompt_text, type=bool)
def confirm(*args, **kwargs):
"""Prompt for confirmation (yes/no) and handle any abort exceptions."""
try:
return click.confirm(*args, **kwargs)
except click.Abort:
return False
def prompt(*args, **kwargs):
"""Prompt the user for input and handle any abort exceptions."""
try:
return click.prompt(*args, **kwargs)
except click.Abort:
return False

View file

@ -0,0 +1,608 @@
import sys
import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes
from pgspecial.main import parse_special_command
Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple("Join", ["table_refs", "schema"])
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
Column = namedtuple(
"Column",
["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", [])
class SqlStatement(object):
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(
text_before_cursor, include="many_punctuations"
)
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
full_text, text_before_cursor
)
self.text_before_cursor_including_last_word = text_before_cursor
# If we've partially typed a word then word_before_cursor won't be an
# empty string. In that case we want to remove the partially typed
# string before sending it to the sqlparser. Otherwise the last token
# will always be the partially typed string which renders the smart
# completion useless because it will always return the list of
# keywords as completion.
if self.word_before_cursor:
if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
parsed = sqlparse.parse(text_before_cursor)
else:
text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
parsed = sqlparse.parse(text_before_cursor)
self.identifier = parse_partial_identifier(word_before_cursor)
else:
parsed = sqlparse.parse(text_before_cursor)
full_text, text_before_cursor, parsed = _split_multiple_statements(
full_text, text_before_cursor, parsed
)
self.full_text = full_text
self.text_before_cursor = text_before_cursor
self.parsed = parsed
self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or ""
def is_insert(self):
return self.parsed.token_first().value.lower() == "insert"
def get_tables(self, scope="full"):
"""Gets the tables available in the statement.
param `scope:` possible values: 'full', 'insert', 'before'
If 'insert', only the first table is returned.
If 'before', only tables before the cursor are returned.
If not 'insert' and the stmt is an insert, the first table is skipped.
"""
tables = extract_tables(
self.full_text if scope == "full" else self.text_before_cursor
)
if scope == "insert":
tables = tables[:1]
elif self.is_insert():
tables = tables[1:]
return tables
def get_previous_token(self, token):
return self.parsed.token_prev(self.parsed.token_index(token))[1]
def get_identifier_schema(self):
schema = (self.identifier and self.identifier.get_parent_name()) or None
# If schema name is unquoted, lower-case it
if schema and self.identifier.value[0] != '"':
schema = schema.lower()
return schema
def reduce_to_prev_keyword(self, n_skip=0):
prev_keyword, self.text_before_cursor = find_prev_keyword(
self.text_before_cursor, n_skip=n_skip
)
return prev_keyword
def suggest_type(full_text, text_before_cursor):
"""Takes the full_text that is typed so far and also the text before the
cursor to suggest completion type and scope.
Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
A scope for a column category will be a list of tables.
"""
if full_text.startswith("\\i "):
return (Path(),)
# This is a temporary hack; the exception handling
# here should be removed once sqlparse has been fixed
try:
stmt = SqlStatement(full_text, text_before_cursor)
except (TypeError, AttributeError):
return []
# Check for special commands and handle those separately
if stmt.parsed:
# Be careful here because trivial whitespace is parsed as a
# statement, but the statement won't have a first token
tok1 = stmt.parsed.token_first()
if tok1 and tok1.value.startswith("\\"):
text = stmt.text_before_cursor + stmt.word_before_cursor
return suggest_special(text)
return suggest_based_on_last_token(stmt.last_token, stmt)
named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
def _strip_named_query(txt):
"""
This will strip "save named query" command in the beginning of the line:
'\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
"""
if named_query_regex.match(txt):
txt = named_query_regex.sub("", txt)
return txt
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
def _find_function_body(text):
split = function_body_pattern.search(text)
return (split.start(2), split.end(2)) if split else (None, None)
def _statement_from_function(full_text, text_before_cursor, statement):
current_pos = len(text_before_cursor)
body_start, body_end = _find_function_body(full_text)
if body_start is None:
return full_text, text_before_cursor, statement
if not body_start <= current_pos < body_end:
return full_text, text_before_cursor, statement
full_text = full_text[body_start:body_end]
text_before_cursor = text_before_cursor[body_start:]
parsed = sqlparse.parse(text_before_cursor)
return _split_multiple_statements(full_text, text_before_cursor, parsed)
def _split_multiple_statements(full_text, text_before_cursor, parsed):
if len(parsed) > 1:
# Multiple statements being edited -- isolate the current one by
# cumulatively summing statement lengths to find the one that bounds
# the current position
current_pos = len(text_before_cursor)
stmt_start, stmt_end = 0, 0
for statement in parsed:
stmt_len = len(str(statement))
stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
if stmt_end >= current_pos:
text_before_cursor = full_text[stmt_start:current_pos]
full_text = full_text[stmt_start:]
break
elif parsed:
# A single statement
statement = parsed[0]
else:
# The empty string
return full_text, text_before_cursor, None
token2 = None
if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
token1 = statement.token_first()
if token1:
token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == "FUNCTION":
full_text, text_before_cursor, statement = _statement_from_function(
full_text, text_before_cursor, statement
)
return full_text, text_before_cursor, statement
SPECIALS_SUGGESTION = {
"dT": Datatype,
"df": Function,
"dt": Table,
"dv": View,
"sf": Function,
}
def suggest_special(text):
text = text.lstrip()
cmd, _, arg = parse_special_command(text)
if cmd == text:
# Trying to complete the special command itself
return (Special(),)
if cmd in ("\\c", "\\connect"):
return (Database(),)
if cmd == "\\T":
return (TableFormat(),)
if cmd == "\\dn":
return (Schema(),)
if arg:
# Try to distinguish "\d name" from "\d schema.name"
# Note that this will fail to obtain a schema name if wildcards are
# used, e.g. "\d schema???.name"
parsed = sqlparse.parse(arg)[0].tokens[0]
try:
schema = parsed.get_parent_name()
except AttributeError:
schema = None
else:
schema = None
if cmd[1:] == "d":
# \d can describe tables or views
if schema:
return (Table(schema=schema), View(schema=schema))
else:
return (Schema(), Table(schema=None), View(schema=None))
elif cmd[1:] in SPECIALS_SUGGESTION:
rel_type = SPECIALS_SUGGESTION[cmd[1:]]
if schema:
if rel_type == Function:
return (Function(schema=schema, usage="special"),)
return (rel_type(schema=schema),)
else:
if rel_type == Function:
return (Schema(), Function(schema=None, usage="special"))
return (Schema(), rel_type(schema=None))
if cmd in ["\\n", "\\ns", "\\nd"]:
return (NamedQuery(),)
return (Keyword(), Special())
def suggest_based_on_last_token(token, stmt):
if isinstance(token, str):
token_v = token.lower()
elif isinstance(token, Comparison):
# If 'token' is a Comparison type such as
# 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
# token.value on the comparison type will only return the lhs of the
# comparison. In this case a.id. So we need to do token.tokens to get
# both sides of the comparison and pick the last token out of that
# list.
token_v = token.tokens[-1].value.lower()
elif isinstance(token, Where):
# sqlparse groups all tokens from the where clause into a single token
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
prev_keyword = stmt.reduce_to_prev_keyword()
return suggest_based_on_last_token(prev_keyword, stmt)
elif isinstance(token, Identifier):
# If the previous token is an identifier, we can suggest datatypes if
# we're in a parenthesized column/field list, e.g.:
# CREATE TABLE foo (Identifier <CURSOR>
# CREATE FUNCTION foo (Identifier <CURSOR>
# If we're not in a parenthesized list, the most likely scenario is the
# user is about to specify an alias, e.g.:
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == "(":
# Suggest datatypes
return suggest_based_on_last_token("type", stmt)
else:
return (Keyword(),)
else:
token_v = token.value.lower()
if not token:
return (Keyword(), Special())
elif token_v.endswith("("):
p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
# Four possibilities:
# 1 - Parenthesized clause like "WHERE foo AND ("
# Suggest columns/functions
# 2 - Function call like "WHERE foo("
# Suggest columns/functions
# 3 - Subquery expression like "WHERE EXISTS ("
# Suggest keywords, in order to do a subquery
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token("where", stmt)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
prev_tok = where.token_prev(len(where.tokens) - 1)[1]
if isinstance(prev_tok, Comparison):
# e.g. "SELECT foo FROM bar WHERE foo = ANY("
prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower()
if prev_tok == "exists":
return (Keyword(),)
else:
return column_suggestions
# Get the token before the parens
prev_tok = p.token_prev(len(p.tokens) - 1)[1]
if (
prev_tok
and prev_tok.value
and prev_tok.value.lower().split(" ")[-1] == "using"
):
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables("before")
# suggest columns that are present in more than one table
return (
Column(
table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables,
),
)
elif p.token_first().value.lower() == "select":
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
return (Column(table_refs=stmt.get_tables("insert"), context="insert"),)
# We're probably in a function argument list
return _suggest_expression(token_v, stmt)
elif token_v == "set":
return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),)
elif token_v in ("select", "where", "having", "order by", "distinct"):
return _suggest_expression(token_v, stmt)
elif token_v == "as":
# Don't suggest anything for aliases
return ()
elif (token_v.endswith("join") and token.is_keyword) or (
token_v in ("copy", "from", "update", "into", "describe", "truncate")
):
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith("join") and token.is_keyword
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
if token_v == "from" or is_join:
suggest.append(
FromClauseItem(
schema=schema, table_refs=tables, local_tables=stmt.local_tables
)
)
elif token_v == "truncate":
suggest.append(Table(schema))
else:
suggest.extend((Table(schema), View(schema)))
if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables("before")
suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest)
elif token_v == "function":
schema = stmt.get_identifier_schema()
# stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:`
try:
prev = stmt.get_previous_token(token).value.lower()
if prev in ("drop", "alter", "create", "create or replace"):
# Suggest functions from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
suggest.append(Function(schema=schema, usage="signature"))
return tuple(suggest)
except ValueError:
pass
return tuple()
elif token_v in ("table", "view"):
# E.g. 'ALTER TABLE <tablname>'
rel_type = {"table": Table, "view": View, "function": Function}[token_v]
schema = stmt.get_identifier_schema()
if schema:
return (rel_type(schema=schema),)
else:
return (Schema(), rel_type(schema=schema))
elif token_v == "column":
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
return (Column(table_refs=stmt.get_tables()),)
elif token_v == "on":
tables = stmt.get_tables("before")
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
sugs = [
Column(table_refs=filteredtables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),
]
if filteredtables and _allow_join_condition(stmt.parsed):
sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1]))
return tuple(sugs)
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed):
return (
Alias(aliases=aliases),
JoinCondition(table_refs=tables, parent=None),
)
else:
return (Alias(aliases=aliases),)
elif token_v in ("c", "use", "database", "template"):
# "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return (Database(),)
elif token_v == "schema":
# DROP SCHEMA schema_name, SET SCHEMA schema name
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
quoted = prev_keyword and prev_keyword.value.lower() == "set"
return (Schema(quoted),)
elif token_v.endswith(",") or token_v in ("=", "and", "or"):
prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
else:
return ()
elif token_v in ("type", "::"):
# ALTER TABLE foo SET DATA TYPE bar
# SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so
# they're suggested here as well
schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema), Table(schema=schema)]
if not schema:
suggestions.append(Schema())
return tuple(suggestions)
elif token_v in {"alter", "create", "drop"}:
return (Keyword(token_v.upper()),)
elif token.is_keyword:
# token is a keyword we haven't implemented any special handling for
# go backwards in the query until we find one we do recognize
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
else:
return (Keyword(token_v.upper()),)
else:
return (Keyword(),)
def _suggest_expression(token_v, stmt):
"""
Return suggestions for an expression, taking account of any partially-typed
identifier's parent, which may be a table alias or schema name.
"""
parent = stmt.identifier.get_parent_name() if stmt.identifier else []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
return (
Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),
)
return (
Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True),
Function(schema=None),
Keyword(token_v.upper()),
)
def identifies(id, ref):
"""Returns true if string `id` matches TableReference `ref`"""
return (
id == ref.alias
or id == ref.name
or (ref.schema and (id == ref.schema + "." + ref.name))
)
def _allow_join_condition(statement):
"""
Tests if a join condition should be suggested
We need this to avoid bad suggestions when entering e.g.
select * from tbl1 a join tbl2 b on a.id = <cursor>
So check that the preceding token is a ON, AND, or OR keyword, instead of
e.g. an equals sign.
:param statement: an sqlparse.sql.Statement
:return: boolean
"""
if not statement or not statement.tokens:
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower() in ("on", "and", "or")
def _allow_join(statement):
"""
Tests if a join should be suggested
We need this to avoid bad suggestions when entering e.g.
select * from tbl1 a join tbl2 b <cursor>
So check that the preceding token is a JOIN keyword
:param statement: an sqlparse.sql.Statement
:return: boolean
"""
if not statement or not statement.tokens:
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in (
"cross join",
"natural join",
)

50
pgcli/pgbuffer.py Normal file
View file

@ -0,0 +1,50 @@
import logging
from prompt_toolkit.enums import DEFAULT_BUFFER
from prompt_toolkit.filters import Condition
from prompt_toolkit.application import get_app
from .packages.parseutils.utils import is_open_quote
_logger = logging.getLogger(__name__)
def _is_complete(sql):
# A complete command is an sql statement that ends with a semicolon, unless
# there's an open quote surrounding it, as is common when writing a
# CREATE FUNCTION command
return sql.endswith(";") and not is_open_quote(sql)
"""
Returns True if the buffer contents should be handled (i.e. the query/command
executed) immediately. This is necessary as we use prompt_toolkit in multiline
mode, which by default will insert new lines on Enter.
"""
def buffer_should_be_handled(pgcli):
@Condition
def cond():
if not pgcli.multi_line:
_logger.debug("Not in multi-line mode. Handle the buffer.")
return True
if pgcli.multiline_mode == "safe":
_logger.debug("Multi-line mode is set to 'safe'. Do NOT handle the buffer.")
return False
doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
text = doc.text.strip()
return (
text.startswith("\\") # Special Command
or text.endswith(r"\e") # Special Command
or text.endswith(r"\G") # Ended with \e which should launch the editor
or _is_complete(text) # A complete SQL command
or (text == "exit") # Exit doesn't need semi-colon
or (text == "quit") # Quit doesn't need semi-colon
or (text == ":q") # To all the vim fans out there
or (text == "") # Just a plain enter without any text
)
return cond

195
pgcli/pgclirc Normal file
View file

@ -0,0 +1,195 @@
# vi: ft=dosini
[main]
# Enables context sensitive auto-completion. If this is disabled, all
# possible completions will be listed.
smart_completion = True
# Display the completions in several columns. (More completions will be
# visible.)
wider_completion_menu = False
# Multi-line mode allows breaking up the sql statements into multiple lines. If
# this is set to True, then the end of the statements must have a semi-colon.
# If this is set to False then sql statements can't be split into multiple
# lines. End of line (return) is considered as the end of the statement.
multi_line = False
# If multi_line_mode is set to "psql", in multi-line mode, [Enter] will execute
# the current input if the input ends in a semicolon.
# If multi_line_mode is set to "safe", in multi-line mode, [Enter] will always
# insert a newline, and [Esc] [Enter] or [Alt]-[Enter] must be used to execute
# a command.
multi_line_mode = psql
# Destructive warning mode will alert you before executing a sql statement
# that may cause harm to the database such as "drop table", "drop database"
# or "shutdown".
destructive_warning = True
# Enables expand mode, which is similar to `\x` in psql.
expand = False
# Enables auto expand mode, which is similar to `\x auto` in psql.
auto_expand = False
# If set to True, table suggestions will include a table alias
generate_aliases = False
# log_file location.
# In Unix/Linux: ~/.config/pgcli/log
# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log
# %USERPROFILE% is typically C:\Users\{username}
log_file = default
# keyword casing preference. Possible values: "lower", "upper", "auto"
keyword_casing = auto
# casing_file location.
# In Unix/Linux: ~/.config/pgcli/casing
# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\casing
# %USERPROFILE% is typically C:\Users\{username}
casing_file = default
# If generate_casing_file is set to True and there is no file in the above
# location, one will be generated based on usage in SQL/PLPGSQL functions.
generate_casing_file = False
# Casing of column headers based on the casing_file described above
case_column_headers = True
# history_file location.
# In Unix/Linux: ~/.config/pgcli/history
# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\history
# %USERPROFILE% is typically C:\Users\{username}
history_file = default
# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
# and "DEBUG". "NONE" disables logging.
log_level = INFO
# Order of columns when expanding * to column list
# Possible values: "table_order" and "alphabetic"
asterisk_column_order = table_order
# Whether to qualify with table alias/name when suggesting columns
# Possible values: "always", "never" and "if_more_than_one_table"
qualify_columns = if_more_than_one_table
# When no schema is entered, only suggest objects in search_path
search_path_filter = False
# Default pager.
# By default 'PAGER' environment variable is used
# pager = less -SRXF
# Timing of sql statements and table rendering.
timing = True
# Show/hide the informational toolbar with function keymap at the footer.
show_bottom_toolbar = True
# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe,
# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs,
# textile, moinmoin, jira, vertical, tsv, csv.
# Recommended: psql, fancy_grid and grid.
table_format = psql
# Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt,
# native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark,
# colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity
syntax_style = default
# Keybindings:
# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
# When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E
# for end are available in the REPL.
vi = False
# Error handling
# When one of multiple SQL statements causes an error, choose to either
# continue executing the remaining statements, or stopping
# Possible values "STOP" or "RESUME"
on_error = STOP
# Set threshold for row limit. Use 0 to disable limiting.
row_limit = 1000
# Skip intro on startup and goodbye on exit
less_chatty = False
# Postgres prompt
# \t - Current date and time
# \u - Username
# \h - Short hostname of the server (up to first '.')
# \H - Hostname of the server
# \d - Database name
# \p - Database port
# \i - Postgres PID
# \# - "@" sign if logged in as superuser, '>' in other case
# \n - Newline
# \dsn_alias - name of dsn alias if -D option is used (empty otherwise)
# \x1b[...m - insert ANSI escape sequence
# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>'
prompt = '\u@\h:\d> '
# Number of lines to reserve for the suggestion menu
min_num_menu_lines = 4
# Character used to left pad multi-line queries to match the prompt size.
multiline_continuation_char = ''
# The string used in place of a null value.
null_string = '<null>'
# manage pager on startup
enable_pager = True
# Use keyring to automatically save and load password in a secure manner
keyring = True
# Custom colors for the completion menu, toolbar, etc.
[colors]
completion-menu.completion.current = 'bg:#ffffff #000000'
completion-menu.completion = 'bg:#008888 #ffffff'
completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
completion-menu.meta.completion = 'bg:#448888 #ffffff'
completion-menu.multi-column-meta = 'bg:#aaffff #000000'
scrollbar.arrow = 'bg:#003333'
scrollbar = 'bg:#00aaaa'
selected = '#ffffff bg:#6666aa'
search = '#ffffff bg:#4444aa'
search.current = '#ffffff bg:#44aa44'
bottom-toolbar = 'bg:#222222 #aaaaaa'
bottom-toolbar.off = 'bg:#222222 #888888'
bottom-toolbar.on = 'bg:#222222 #ffffff'
search-toolbar = 'noinherit bold'
search-toolbar.text = 'nobold'
system-toolbar = 'noinherit bold'
arg-toolbar = 'noinherit bold'
arg-toolbar.text = 'nobold'
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
literal.string = '#ba2121'
literal.number = '#666666'
keyword = 'bold #008000'
# style classes for colored table output
output.header = "#00ff5f bold"
output.odd-row = ""
output.even-row = ""
output.null = "#808080"
# Named queries are queries you can execute by name.
[named queries]
# DSN to call by -D option
[alias_dsn]
# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname]
# Format for number representation
# for decimal "d" - 12345678, ",d" - 12,345,678
# for float "g" - 123456.78, ",g" - 123,456.78
[data_formats]
decimal = ""
float = ""

1046
pgcli/pgcompleter.py Normal file

File diff suppressed because it is too large Load diff

857
pgcli/pgexecute.py Normal file
View file

@ -0,0 +1,857 @@
import traceback
import logging
import psycopg2
import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
import sqlparse
import pgspecial as special
import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
# Cast all database input to unicode automatically.
# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
ext.register_type(ext.UNICODE)
ext.register_type(ext.UNICODEARRAY)
ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE))
# See https://github.com/dbcli/pgcli/issues/426 for more details.
# This registers a unicode type caster for datatype 'RECORD'.
ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))
# Cast bytea fields to text. By default, this will render as hex strings with
# Postgres 9+ and as escaped binary in earlier versions.
ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
def _wait_select(conn):
"""
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except select.error as e:
errno = e.args[0]
if errno != 4:
raise
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
"""
Casts date and timestamp values to string, resolves issues with out of
range dates (e.g. BC) which psycopg2 can't handle
"""
def cast_date(value, cursor):
return value
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp with time zone")
timestamptz_oid = cursor.description[0][1]
oids = (date_oid, timestamp_oid, timestamptz_oid)
new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
psycopg2.extensions.register_type(new_type)
def register_json_typecasters(conn, loads_fn):
"""Set the function for converting JSON data for a connection.
Use the supplied function to decode JSON data returned from the database
via the given connection. The function should accept a single argument of
the data as a string encoded in the database's character encoding.
psycopg2's default handler for JSON data is json.loads.
http://initd.org/psycopg/docs/extras.html#json-adaptation
This function attempts to register the typecaster for both JSON and JSONB
types.
Returns a set that is a subset of {'json', 'jsonb'} indicating which types
(if any) were successfully registered.
"""
available = set()
for name in ["json", "jsonb"]:
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
except psycopg2.ProgrammingError:
pass
return available
def register_hstore_typecaster(conn):
"""
Instead of using register_hstore() which converts hstore into a python
dict, we query the 'oid' of hstore which will be different for each
database and register a type caster that converts it to unicode.
http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore
"""
with conn.cursor() as cur:
try:
cur.execute(
"select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
)
oid = cur.fetchone()[0]
ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
except Exception:
pass
class PGExecute(object):
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
search_path_query = """
SELECT * FROM unnest(current_schemas(true))"""
schemata_query = """
SELECT nspname
FROM pg_catalog.pg_namespace
ORDER BY 1 """
tables_query = """
SELECT n.nspname schema_name,
c.relname table_name
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
WHERE c.relkind = ANY(%s)
ORDER BY 1,2;"""
databases_query = """
SELECT d.datname
FROM pg_catalog.pg_database d
ORDER BY 1"""
full_databases_query = """
SELECT d.datname as "Name",
pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
d.datcollate as "Collate",
d.datctype as "Ctype",
pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges"
FROM pg_catalog.pg_database d
ORDER BY 1"""
socket_directory_query = """
SELECT setting
FROM pg_settings
WHERE name = 'unix_socket_directories'
"""
view_definition_query = """
WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid)
SELECT nspname, relname, relkind,
pg_catalog.pg_get_viewdef(c.oid, true),
array_remove(array_remove(c.reloptions,'check_option=local'),
'check_option=cascaded') AS reloptions,
CASE
WHEN 'check_option=local' = ANY (c.reloptions) THEN 'LOCAL'::text
WHEN 'check_option=cascaded' = ANY (c.reloptions) THEN 'CASCADED'::text
ELSE NULL
END AS checkoption
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid)
JOIN v ON (c.oid = v.v_oid)"""
function_definition_query = """
WITH f AS
(SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid)
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
version_query = "SELECT version();"
def __init__(
self,
database=None,
user=None,
password=None,
host=None,
port=None,
dsn=None,
**kwargs,
):
self._conn_params = {}
self.conn = None
self.dbname = None
self.user = None
self.password = None
self.host = None
self.port = None
self.server_version = None
self.extra_args = None
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
def connect(
self,
database=None,
user=None,
password=None,
host=None,
port=None,
dsn=None,
**kwargs,
):
conn_params = self._conn_params.copy()
new_params = {
"database": database,
"user": user,
"password": password,
"host": host,
"port": port,
"dsn": dsn,
}
new_params.update(kwargs)
if new_params["dsn"]:
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
if new_params["password"]:
new_params["dsn"] = make_dsn(
new_params["dsn"], password=new_params.pop("password")
)
conn_params.update({k: v for k, v in new_params.items() if v})
conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
if self.conn:
self.conn.close()
self.conn = conn
self.conn.autocommit = True
# When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664.
libpq_version = psycopg2.__libpq_version__
dsn_parameters = {}
if libpq_version >= 93000:
# use actual connection info from psycopg2.extensions.Connection.info
# as libpq_version > 9.3 is available and required dependency
dsn_parameters = conn.info.dsn_parameters
else:
try:
dsn_parameters = conn.get_dsn_parameters()
except Exception as x:
# https://github.com/dbcli/pgcli/issues/1110
# PQconninfo not available in libpq < 9.3
_logger.info("Exception in get_dsn_parameters: %r", x)
if dsn_parameters:
self.dbname = dsn_parameters.get("dbname")
self.user = dsn_parameters.get("user")
self.host = dsn_parameters.get("host")
self.port = dsn_parameters.get("port")
else:
self.dbname = conn_params.get("database")
self.user = conn_params.get("user")
self.host = conn_params.get("host")
self.port = conn_params.get("port")
self.password = password
self.extra_args = kwargs
if not self.host:
self.host = self.get_socket_directory()
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
self.pid = pid
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
self.server_version = conn.get_parameter_status("server_version")
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
@property
def short_host(self):
if "," in self.host:
host, _, _ = self.host.partition(",")
else:
host = self.host
short_host, _, _ = host.partition(".")
return short_host
def _select_one(self, cur, sql):
"""
Helper method to run a select and retrieve a single field value
:param cur: cursor
:param sql: string
:return: string
"""
cur.execute(sql)
return cur.fetchone()
def _json_typecaster(self, json_data):
"""Interpret incoming JSON data as a string.
The raw data is decoded using the connection's encoding, which defaults
to the database's encoding.
See http://initd.org/psycopg/docs/connection.html#connection.encoding
"""
return json_data
def failed_transaction(self):
status = self.conn.get_transaction_status()
return status == ext.TRANSACTION_STATUS_INERROR
def valid_transaction(self):
status = self.conn.get_transaction_status()
return (
status == ext.TRANSACTION_STATUS_ACTIVE
or status == ext.TRANSACTION_STATUS_INTRANS
)
def run(
self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
):
"""Execute the sql in the database and return the results.
:param statement: A string containing one or more sql statements
:param pgspecial: PGSpecial object
:param exception_formatter: A callable that accepts an Exception and
returns a formatted (title, rows, headers, status) tuple that can
act as a query result. If an exception_formatter is not supplied,
psycopg2 exceptions are always raised.
:param on_error_resume: Bool. If true, queries following an exception
(assuming exception_formatter has been supplied) continue to
execute.
:return: Generator yielding tuples containing
(title, rows, headers, status, query, success, is_special)
"""
# Remove spaces and EOL
statement = statement.strip()
if not statement: # Empty string
yield (None, None, None, None, statement, False, False)
# Split the sql into separate queries and run each one.
for sql in sqlparse.split(statement):
# Remove spaces, eol and semi-colons.
sql = sql.rstrip(";")
sql = sqlparse.format(sql, strip_comments=True).strip()
if not sql:
continue
try:
if pgspecial:
# \G is treated specially since we have to set the expanded output.
if sql.endswith("\\G"):
if not pgspecial.expanded_output:
pgspecial.expanded_output = True
self.reset_expanded = True
sql = sql[:-2].strip()
# First try to run each query as special
_logger.debug("Trying a pgspecial command. sql: %r", sql)
try:
cur = self.conn.cursor()
except psycopg2.InterfaceError:
# edge case when connection is already closed, but we
# don't need cursor for special_cmd.arg_type == NO_QUERY.
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
for result in pgspecial.execute(cur, sql):
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
else:
yield result
continue
except special.CommandNotFound:
pass
# Not a special command, so execute as normal sql
yield self.execute_normal_sql(sql) + (sql, True, False)
except psycopg2.DatabaseError as e:
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
if self._must_raise(e) or not exception_formatter:
raise
yield None, None, None, exception_formatter(e), sql, False, False
if not on_error_resume:
break
finally:
if self.reset_expanded:
pgspecial.expanded_output = False
self.reset_expanded = None
def _must_raise(self, e):
"""Return true if e is an error that should not be caught in ``run``.
An uncaught error will prompt the user to reconnect; as long as we
detect that the connection is stil open, we catch the error, as
reconnecting won't solve that problem.
:param e: DatabaseError. An exception raised while executing a query.
:return: Bool. True if ``run`` must raise this exception.
"""
return self.conn.closed != 0
def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
_logger.debug("Regular sql statement. sql: %r", split_sql)
cur = self.conn.cursor()
cur.execute(split_sql)
# conn.notices persist between queies, we use pop to clear out the list
title = ""
while len(self.conn.notices) > 0:
title = self.conn.notices.pop() + title
# cur.description will be None for operations that do not return
# rows.
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
def search_path(self):
"""Returns the current search path as a list of schema names"""
try:
with self.conn.cursor() as cur:
_logger.debug("Search path query. sql: %r", self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
except psycopg2.ProgrammingError:
fallback = "SELECT * FROM current_schemas(true)"
with self.conn.cursor() as cur:
_logger.debug("Search path query. sql: %r", fallback)
cur.execute(fallback)
return cur.fetchone()[0]
def view_definition(self, spec):
"""Returns the SQL defining views described by `spec`"""
template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
# 2: relkind, v or m (materialized)
# 4: reloptions, null
# 5: checkoption: local or cascaded
with self.conn.cursor() as cur:
sql = self.view_definition_query
_logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
raise RuntimeError("View {} does not exist.".format(spec))
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
def function_definition(self, spec):
"""Returns the SQL defining functions described by `spec`"""
with self.conn.cursor() as cur:
sql = self.function_definition_query
_logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
raise RuntimeError("Function {} does not exist.".format(spec))
def schemata(self):
"""Returns a list of schema names in the database"""
with self.conn.cursor() as cur:
_logger.debug("Schemata Query. sql: %r", self.schemata_query)
cur.execute(self.schemata_query)
return [x[0] for x in cur.fetchall()]
def _relations(self, kinds=("r", "p", "f", "v", "m")):
"""Get table or view name metadata
:param kinds: list of postgres relkind filters:
'r' - table
'p' - partitioned table
'f' - foreign table
'v' - view
'm' - materialized view
:return: (schema_name, rel_name) tuples
"""
with self.conn.cursor() as cur:
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
def tables(self):
"""Yields (schema_name, table_name) tuples"""
for row in self._relations(kinds=["r", "p", "f"]):
yield row
def views(self):
"""Yields (schema_name, view_name) tuples.
Includes both views and and materialized views
"""
for row in self._relations(kinds=["v", "m"]):
yield row
def _columns(self, kinds=("r", "p", "f", "v", "m")):
"""Get column metadata for tables and views
:param kinds: kinds: list of postgres relkind filters:
'r' - table
'p' - partitioned table
'f' - foreign table
'v' - view
'm' - materialized view
:return: list of (schema_name, relation_name, column_name, column_type) tuples
"""
if self.conn.server_version >= 80400:
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
att.atttypid::regtype::text type_name,
att.atthasdef AS has_default,
pg_catalog.pg_get_expr(def.adbin, def.adrelid, true) as default
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
LEFT OUTER JOIN pg_attrdef def
ON def.adrelid = att.attrelid
AND def.adnum = att.attnum
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, att.attnum"""
else:
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
typ.typname type_name,
NULL AS has_default,
NULL AS default
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
INNER JOIN pg_catalog.pg_type typ
ON typ.oid = att.atttypid
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, att.attnum"""
with self.conn.cursor() as cur:
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
def table_columns(self):
for row in self._columns(kinds=["r", "p", "f"]):
yield row
def view_columns(self):
for row in self._columns(kinds=["v", "m"]):
yield row
def databases(self):
with self.conn.cursor() as cur:
_logger.debug("Databases Query. sql: %r", self.databases_query)
cur.execute(self.databases_query)
return [x[0] for x in cur.fetchall()]
def full_databases(self):
with self.conn.cursor() as cur:
_logger.debug("Databases Query. sql: %r", self.full_databases_query)
cur.execute(self.full_databases_query)
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
"Socket directory Query. sql: %r", self.socket_directory_query
)
cur.execute(self.socket_directory_query)
result = cur.fetchone()
return result[0] if result else ""
def foreignkeys(self):
"""Yields ForeignKey named tuples"""
if self.conn.server_version < 90000:
return
with self.conn.cursor() as cur:
query = """
SELECT s_p.nspname AS parentschema,
t_p.relname AS parenttable,
unnest((
select
array_agg(attname ORDER BY i)
from
(select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
JOIN pg_catalog.pg_attribute c USING(attnum)
WHERE c.attrelid = fk.confrelid
)) AS parentcolumn,
s_c.nspname AS childschema,
t_c.relname AS childtable,
unnest((
select
array_agg(attname ORDER BY i)
from
(select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
JOIN pg_catalog.pg_attribute c USING(attnum)
WHERE c.attrelid = fk.conrelid
)) AS childcolumn
FROM pg_catalog.pg_constraint fk
JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
WHERE fk.contype = 'f';
"""
_logger.debug("Functions Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield ForeignKey(*row)
def functions(self):
"""Yields FunctionMetadata named tuples"""
if self.conn.server_version >= 110000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
p.proargmodes,
prorettype::regtype::text return_type,
p.prokind = 'a' is_aggregate,
p.prokind = 'w' is_window,
p.proretset is_set_returning,
d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
elif self.conn.server_version > 90000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
p.proargmodes,
prorettype::regtype::text return_type,
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning,
d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
elif self.conn.server_version >= 80400:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
p.proargmodes,
prorettype::regtype::text,
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning,
d.deptype = 'e' is_extension,
NULL AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
else:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
NULL arg_types,
NULL arg_modes,
'' ret_type,
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning,
d.deptype = 'e' is_extension,
NULL AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
with self.conn.cursor() as cur:
_logger.debug("Functions Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield FunctionMetadata(*row)
def datatypes(self):
"""Yields tuples of (schema_name, type_name)"""
with self.conn.cursor() as cur:
if self.conn.server_version > 90000:
query = """
SELECT n.nspname schema_name,
t.typname type_name
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = t.typnamespace
WHERE ( t.typrelid = 0 -- non-composite types
OR ( -- composite type, but not a table
SELECT c.relkind = 'c'
FROM pg_catalog.pg_class c
WHERE c.oid = t.typrelid
)
)
AND NOT EXISTS( -- ignore array types
SELECT 1
FROM pg_catalog.pg_type el
WHERE el.oid = t.typelem AND el.typarray = t.oid
)
AND n.nspname <> 'pg_catalog'
AND n.nspname <> 'information_schema'
ORDER BY 1, 2;
"""
else:
query = """
SELECT n.nspname schema_name,
pg_catalog.format_type(t.oid, NULL) type_name
FROM pg_catalog.pg_type t
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
AND t.typname !~ '^_'
AND n.nspname <> 'pg_catalog'
AND n.nspname <> 'information_schema'
AND pg_catalog.pg_type_is_visible(t.oid)
ORDER BY 1, 2;
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row
def casing(self):
"""Yields the most common casing for names used in db functions"""
with self.conn.cursor() as cur:
query = r"""
WITH Words AS (
SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1)
FROM pg_catalog.pg_proc P
JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace
JOIN pg_catalog.pg_language L ON L.oid = P.prolang
WHERE L.lanname IN ('sql', 'plpgsql')
AND N.nspname NOT IN ('pg_catalog', 'information_schema')
GROUP BY Word
),
OrderWords AS (
SELECT Word,
ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC)
FROM Words
WHERE Word ~* '.*[a-z].*'
),
Names AS (
--Column names
SELECT attname AS Name
FROM pg_catalog.pg_attribute
UNION -- Table/view names
SELECT relname
FROM pg_catalog.pg_class
UNION -- Function names
SELECT proname
FROM pg_catalog.pg_proc
UNION -- Type names
SELECT typname
FROM pg_catalog.pg_type
UNION -- Schema names
SELECT nspname
FROM pg_catalog.pg_namespace
UNION -- Parameter names
SELECT unnest(proargnames)
FROM pg_proc
)
SELECT Word
FROM OrderWords
WHERE LOWER(Word) IN (SELECT Name FROM Names)
AND Row_Number = 1;
"""
_logger.debug("Casing Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row[0]

116
pgcli/pgstyle.py Normal file
View file

@ -0,0 +1,116 @@
import logging
import pygments.styles
from pygments.token import string_to_tokentype, Token
from pygments.style import Style as PygmentsStyle
from pygments.util import ClassNotFound
from prompt_toolkit.styles.pygments import style_from_pygments_cls
from prompt_toolkit.styles import merge_styles, Style
logger = logging.getLogger(__name__)
# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
TOKEN_TO_PROMPT_STYLE = {
Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
Token.Menu.Completions.Completion: "completion-menu.completion",
Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
Token.Menu.Completions.Meta: "completion-menu.meta.completion",
Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
Token.SelectedText: "selected",
Token.SearchMatch: "search",
Token.SearchMatch.Current: "search.current",
Token.Toolbar: "bottom-toolbar",
Token.Toolbar.Off: "bottom-toolbar.off",
Token.Toolbar.On: "bottom-toolbar.on",
Token.Toolbar.Search: "search-toolbar",
Token.Toolbar.Search.Text: "search-toolbar.text",
Token.Toolbar.System: "system-toolbar",
Token.Toolbar.Arg: "arg-toolbar",
Token.Toolbar.Arg.Text: "arg-toolbar.text",
Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
Token.Output.Header: "output.header",
Token.Output.OddRow: "output.odd-row",
Token.Output.EvenRow: "output.even-row",
Token.Output.Null: "output.null",
Token.Literal.String: "literal.string",
Token.Literal.Number: "literal.number",
Token.Keyword: "keyword",
Token.Prompt: "prompt",
Token.Continuation: "continuation",
}
# reverse dict for cli_helpers, because they still expect Pygments tokens.
PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
def parse_pygments_style(token_name, style_object, style_dict):
"""Parse token type and style string.
:param token_name: str name of Pygments token. Example: "Token.String"
:param style_object: pygments.style.Style instance to use as base
:param style_dict: dict of token names and their styles, customized to this cli
"""
token_type = string_to_tokentype(token_name)
try:
other_token_type = string_to_tokentype(style_dict[token_name])
return token_type, style_object.styles[other_token_type]
except AttributeError:
return token_type, style_dict[token_name]
def style_factory(name, cli_style):
try:
style = pygments.styles.get_style_by_name(name)
except ClassNotFound:
style = pygments.styles.get_style_by_name("native")
prompt_styles = []
# prompt-toolkit used pygments tokens for styling before, switched to style
# names in 2.0. Convert old token types to new style names, for backwards compatibility.
for token in cli_style:
if token.startswith("Token."):
# treat as pygments token (1.0)
token_type, style_value = parse_pygments_style(token, style, cli_style)
if token_type in TOKEN_TO_PROMPT_STYLE:
prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
prompt_styles.append((prompt_style, style_value))
else:
# we don't want to support tokens anymore
logger.error("Unhandled style / class name: %s", token)
else:
# treat as prompt style name (2.0). See default style names here:
# https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
prompt_styles.append((token, cli_style[token]))
override_style = Style([("bottom-toolbar", "noreverse")])
return merge_styles(
[style_from_pygments_cls(style), override_style, Style(prompt_styles)]
)
def style_factory_output(name, cli_style):
try:
style = pygments.styles.get_style_by_name(name).styles
except ClassNotFound:
style = pygments.styles.get_style_by_name("native").styles
for token in cli_style:
if token.startswith("Token."):
token_type, style_value = parse_pygments_style(token, style, cli_style)
style.update({token_type: style_value})
elif token in PROMPT_STYLE_TO_TOKEN:
token_type = PROMPT_STYLE_TO_TOKEN[token]
style.update({token_type: cli_style[token]})
else:
# TODO: cli helpers will have to switch to ptk.Style
logger.error("Unhandled style / class name: %s", token)
class OutputStyle(PygmentsStyle):
default_style = ""
styles = style
return OutputStyle

62
pgcli/pgtoolbar.py Normal file
View file

@ -0,0 +1,62 @@
from prompt_toolkit.key_binding.vi_state import InputMode
from prompt_toolkit.application import get_app
def _get_vi_mode():
return {
InputMode.INSERT: "I",
InputMode.NAVIGATION: "N",
InputMode.REPLACE: "R",
InputMode.REPLACE_SINGLE: "R",
InputMode.INSERT_MULTIPLE: "M",
}[get_app().vi_state.input_mode]
def create_toolbar_tokens_func(pgcli):
"""Return a function that generates the toolbar tokens."""
def get_toolbar_tokens():
result = []
result.append(("class:bottom-toolbar", " "))
if pgcli.completer.smart_completion:
result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON "))
else:
result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF "))
if pgcli.multi_line:
result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
else:
result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
if pgcli.multi_line:
if pgcli.multiline_mode == "safe":
result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) "))
else:
result.append(
("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
)
if pgcli.vi_mode:
result.append(
("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")")
)
else:
result.append(("class:bottom-toolbar", "[F4] Emacs-mode"))
if pgcli.pgexecute.failed_transaction():
result.append(
("class:bottom-toolbar.transaction.failed", " Failed transaction")
)
if pgcli.pgexecute.valid_transaction():
result.append(
("class:bottom-toolbar.transaction.valid", " Transaction")
)
if pgcli.completion_refresher.is_refreshing():
result.append(("class:bottom-toolbar", " Refreshing completions..."))
return result
return get_toolbar_tokens

4
post-install Normal file
View file

@ -0,0 +1,4 @@
#!/bin/bash
echo "Setting up symlink to pgcli"
ln -sf /usr/share/pgcli/bin/pgcli /usr/local/bin/pgcli

4
post-remove Normal file
View file

@ -0,0 +1,4 @@
#!/bin/bash
echo "Removing symlink to pgcli"
rm /usr/local/bin/pgcli

2
pylintrc Normal file
View file

@ -0,0 +1,2 @@
[MESSAGES CONTROL]
disable=missing-docstring,invalid-name

22
pyproject.toml Normal file
View file

@ -0,0 +1,22 @@
[tool.black]
line-length = 88
target-version = ['py36']
include = '\.pyi?$'
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| \.cache
| \.pytest_cache
| _build
| buck-out
| build
| dist
| tests/data
)/
'''

135
release.py Normal file
View file

@ -0,0 +1,135 @@
#!/usr/bin/env python
"""A script to publish a release of pgcli to PyPI."""
import io
from optparse import OptionParser
import re
import subprocess
import sys
import click
DEBUG = False
CONFIRM_STEPS = False
DRY_RUN = False
def skip_step():
"""
Asks for user's response whether to run a step. Default is yes.
:return: boolean
"""
global CONFIRM_STEPS
if CONFIRM_STEPS:
return not click.confirm("--- Run this step?", default=True)
return False
def run_step(*args):
"""
Prints out the command and asks if it should be run.
If yes (default), runs it.
:param args: list of strings (command and args)
"""
global DRY_RUN
cmd = args
print(" ".join(cmd))
if skip_step():
print("--- Skipping...")
elif DRY_RUN:
print("--- Pretending to run...")
else:
subprocess.check_output(cmd)
def version(version_file):
_version_re = re.compile(
r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)'
)
with io.open(version_file, encoding="utf-8") as f:
ver = _version_re.search(f.read()).group("version")
return ver
def commit_for_release(version_file, ver):
run_step("git", "reset")
run_step("git", "add", version_file)
run_step("git", "commit", "--message", "Releasing version {}".format(ver))
def create_git_tag(tag_name):
run_step("git", "tag", tag_name)
def create_distribution_files():
run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel")
def upload_distribution_files():
run_step("twine", "upload", "dist/*")
def push_to_github():
run_step("git", "push", "origin", "master")
def push_tags_to_github():
run_step("git", "push", "--tags", "origin")
def checklist(questions):
for question in questions:
if not click.confirm("--- {}".format(question), default=False):
sys.exit(1)
if __name__ == "__main__":
if DEBUG:
subprocess.check_output = lambda x: x
checks = [
"Have you updated the AUTHORS file?",
"Have you updated the `Usage` section of the README?",
]
checklist(checks)
ver = version("pgcli/__init__.py")
print("Releasing Version:", ver)
parser = OptionParser()
parser.add_option(
"-c",
"--confirm-steps",
action="store_true",
dest="confirm_steps",
default=False,
help=(
"Confirm every step. If the step is not " "confirmed, it will be skipped."
),
)
parser.add_option(
"-d",
"--dry-run",
action="store_true",
dest="dry_run",
default=False,
help="Print out, but not actually run any steps.",
)
popts, pargs = parser.parse_args()
CONFIRM_STEPS = popts.confirm_steps
DRY_RUN = popts.dry_run
if not click.confirm("Are you sure?", default=False):
sys.exit(1)
commit_for_release("pgcli/__init__.py", ver)
create_git_tag("v{}".format(ver))
create_distribution_files()
push_to_github()
push_tags_to_github()
upload_distribution_files()

13
release_procedure.txt Normal file
View file

@ -0,0 +1,13 @@
# vi: ft=vimwiki
* Bump the version number in pgcli/__init__.py
* Commit with message: 'Releasing version X.X.X.'
* Create a tag: git tag vX.X.X
* Fix the image url in PyPI to point to github raw content. https://raw.githubusercontent.com/dbcli/pgcli/master/screenshots/image01.png
* Create source dist tar ball: python setup.py sdist
* Test this by installing it in a fresh new virtualenv. Run SanityChecks [./sanity_checks.txt].
* Upload the source dist to PyPI: https://pypi.python.org/pypi/pgcli
* pip install pgcli
* Run SanityChecks.
* Push the version back to github: git push --tags origin master
* Done!

14
requirements-dev.txt Normal file
View file

@ -0,0 +1,14 @@
pytest>=2.7.0
mock>=1.0.1
tox>=1.9.2
behave>=1.2.4
pexpect==3.3
pre-commit>=1.16.0
coverage==5.0.4
codecov>=1.5.1
docutils>=0.13.1
autopep8==1.3.3
click==6.7
twine==1.11.0
wheel==0.33.6
prompt_toolkit==3.0.5

37
sanity_checks.txt Normal file
View file

@ -0,0 +1,37 @@
# vi: ft=vimwiki
* Launch pgcli with different inputs.
* pgcli test_db
* pgcli postgres://localhost/test_db
* pgcli postgres://localhost:5432/test_db
* pgcli postgres://amjith@localhost:5432/test_db
* pgcli postgres://amjith:password@localhost:5432/test_db
* pgcli non-existent-db
* Test special command
* \d
* \d table_name
* \dt
* \l
* \c amjith
* \q
* Simple execution:
1 Execute a simple 'select * from users;' test that will pass.
2 Execute a syntax error: 'insert into users ( ;'
3 Execute a simple test from step 1 again to see if it still passes.
* Change the database and try steps 1 - 3.
* Test smart-completion
* Sele - Must auto-complete to SELECT
* SELECT * FROM - Must list the table names.
* INSERT INTO - Must list table names.
* \d <tab> - Must list table names.
* \c <tab> - Database names.
* SELECT * FROM table_name WHERE <tab> - column names (all of it).
* Test naive-completion - turn off smart completion (using F2 key after launch)
* Sele - autocomplete to select.
* SELECT * FROM - autocomplete list should have everything.

BIN
screenshots/image01.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

BIN
screenshots/image02.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

BIN
screenshots/pgcli.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 233 KiB

64
setup.py Normal file
View file

@ -0,0 +1,64 @@
import platform
from setuptools import setup, find_packages
from pgcli import __version__
description = "CLI for Postgres Database. With auto-completion and syntax highlighting."
install_requirements = [
"pgspecial>=1.11.8",
"click >= 4.1",
"Pygments >= 2.0", # Pygments has to be Capitalcased. WTF?
# We still need to use pt-2 unless pt-3 released on Fedora32
# see: https://github.com/dbcli/pgcli/pull/1197
"prompt_toolkit>=2.0.6,<4.0.0",
"psycopg2 >= 2.8",
"sqlparse >=0.3.0,<0.5",
"configobj >= 5.0.6",
"pendulum>=2.1.0",
"cli_helpers[styles] >= 2.0.0",
]
# setproctitle is used to mask the password when running `ps` in command line.
# But this is not necessary in Windows since the password is never shown in the
# task manager. Also setproctitle is a hard dependency to install in Windows,
# so we'll only install it if we're not in Windows.
if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"):
install_requirements.append("setproctitle >= 1.1.9")
setup(
name="pgcli",
author="Pgcli Core Team",
author_email="pgcli-dev@googlegroups.com",
version=__version__,
license="BSD",
url="http://pgcli.com",
packages=find_packages(),
package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]},
description=description,
long_description=open("README.rst").read(),
install_requires=install_requirements,
extras_require={"keyring": ["keyring >= 12.2.0"]},
python_requires=">=3.6",
entry_points="""
[console_scripts]
pgcli=pgcli.main:cli
""",
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: SQL",
"Topic :: Database",
"Topic :: Database :: Front-Ends",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)

52
tests/conftest.py Normal file
View file

@ -0,0 +1,52 @@
import os
import pytest
from utils import (
POSTGRES_HOST,
POSTGRES_PORT,
POSTGRES_USER,
POSTGRES_PASSWORD,
create_db,
db_connection,
drop_tables,
)
import pgcli.pgexecute
@pytest.yield_fixture(scope="function")
def connection():
create_db("_test_db")
connection = db_connection("_test_db")
yield connection
drop_tables(connection)
connection.close()
@pytest.fixture
def cursor(connection):
with connection.cursor() as cur:
return cur
@pytest.fixture
def executor(connection):
return pgcli.pgexecute.PGExecute(
database="_test_db",
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
dsn=None,
)
@pytest.fixture
def exception_formatter():
return lambda e: str(e)
@pytest.fixture(scope="session", autouse=True)
def temp_config(tmpdir_factory):
# this function runs on start of test session.
# use temporary directory for config home so user config will not be used
os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))

View file

View file

@ -0,0 +1,12 @@
Feature: auto_vertical mode:
on, off
Scenario: auto_vertical on with small query
When we run dbcli with --auto-vertical-output
and we execute a small query
then we see small results in horizontal format
Scenario: auto_vertical on with large query
When we run dbcli with --auto-vertical-output
and we execute a large query
then we see large results in vertical format

View file

@ -0,0 +1,58 @@
Feature: run the cli,
call the help command,
exit the cli
Scenario: run "\?" command
When we send "\?" command
then we see help output
Scenario: run source command
When we send source command
then we see help output
Scenario: run partial select command
When we send partial select command
then we see error message
then we see dbcli prompt
Scenario: check our application_name
When we run query to check application_name
then we see found
Scenario: run the cli and exit
When we send "ctrl + d"
then dbcli exits
Scenario: list databases
When we list databases
then we see list of databases
Scenario: run the cli with --username
When we launch dbcli using --username
and we send "\?" command
then we see help output
Scenario: run the cli with --user
When we launch dbcli using --user
and we send "\?" command
then we see help output
Scenario: run the cli with --port
When we launch dbcli using --port
and we send "\?" command
then we see help output
Scenario: run the cli with --password
When we launch dbcli using --password
then we send password
and we see dbcli prompt
when we send "\?" command
then we see help output
@wip
Scenario: run the cli with dsn and password
When we launch dbcli using dsn_password
then we send password
and we see dbcli prompt
when we send "\?" command
then we see help output

View file

@ -0,0 +1,17 @@
Feature: manipulate databases:
create, drop, connect, disconnect
Scenario: create and drop temporary database
When we create database
then we see database created
when we drop database
then we confirm the destructive warning
then we see database dropped
when we connect to dbserver
then we see database connected
Scenario: connect and disconnect from test database
When we connect to test database
then we see database connected
when we connect to dbserver
then we see database connected

View file

@ -0,0 +1,22 @@
Feature: manipulate tables:
create, insert, update, select, delete from, drop
Scenario: create, insert, select from, update, drop table
When we connect to test database
then we see database connected
when we create table
then we see table created
when we insert into table
then we see record inserted
when we update table
then we see record updated
when we select from table
then we see data selected
when we delete from table
then we confirm the destructive warning
then we see record deleted
when we drop table
then we confirm the destructive warning
then we see table dropped
when we connect to dbserver
then we see database connected

View file

@ -0,0 +1,78 @@
from psycopg2 import connect
from psycopg2.extensions import AsIs
def create_db(
hostname="localhost", username=None, password=None, dbname=None, port=None
):
"""Create test database.
:param hostname: string
:param username: string
:param password: string
:param dbname: string
:param port: int
:return:
"""
cn = create_cn(hostname, password, username, "postgres", port)
# ISOLATION_LEVEL_AUTOCOMMIT = 0
# Needed for DB creation.
cn.set_isolation_level(0)
with cn.cursor() as cr:
cr.execute("drop database if exists %s", (AsIs(dbname),))
cr.execute("create database %s", (AsIs(dbname),))
cn.close()
cn = create_cn(hostname, password, username, dbname, port)
return cn
def create_cn(hostname, password, username, dbname, port):
"""
Open connection to database.
:param hostname:
:param password:
:param username:
:param dbname: string
:return: psycopg2.connection
"""
cn = connect(
host=hostname, user=username, database=dbname, password=password, port=port
)
print("Created connection: {0}.".format(cn.dsn))
return cn
def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None):
"""
Drop database.
:param hostname: string
:param username: string
:param password: string
:param dbname: string
"""
cn = create_cn(hostname, password, username, "postgres", port)
# ISOLATION_LEVEL_AUTOCOMMIT = 0
# Needed for DB drop.
cn.set_isolation_level(0)
with cn.cursor() as cr:
cr.execute("drop database if exists %s", (AsIs(dbname),))
close_cn(cn)
def close_cn(cn=None):
"""
Close connection.
:param connection: psycopg2.connection
"""
if cn:
cn.close()
print("Closed connection: {0}.".format(cn.dsn))

View file

@ -0,0 +1,192 @@
import copy
import os
import sys
import db_utils as dbutils
import fixture_utils as fixutils
import pexpect
import tempfile
import shutil
import signal
from steps import wrappers
def before_all(context):
"""Set env parameters."""
env_old = copy.deepcopy(dict(os.environ))
os.environ["LINES"] = "100"
os.environ["COLUMNS"] = "100"
os.environ["PAGER"] = "cat"
os.environ["EDITOR"] = "ex"
os.environ["VISUAL"] = "ex"
os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
)
fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data")
print("package root:", context.package_root)
print("fixture dir:", fixture_dir)
os.environ["COVERAGE_PROCESS_START"] = os.path.join(
context.package_root, ".coveragerc"
)
context.exit_sent = False
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
db_name_full = "{0}_{1}".format(db_name, vi)
# Store get params from config.
context.conf = {
"host": context.config.userdata.get(
"pg_test_host", os.getenv("PGHOST", "localhost")
),
"user": context.config.userdata.get(
"pg_test_user", os.getenv("PGUSER", "postgres")
),
"pass": context.config.userdata.get(
"pg_test_pass", os.getenv("PGPASSWORD", None)
),
"port": context.config.userdata.get(
"pg_test_port", os.getenv("PGPORT", "5432")
),
"cli_command": (
context.config.userdata.get("pg_cli_command", None)
or '{python} -c "{startup}"'.format(
python=sys.executable,
startup="; ".join(
[
"import coverage",
"coverage.process_startup()",
"import pgcli.main",
"pgcli.main.cli()",
]
),
)
),
"dbname": db_name_full,
"dbname_tmp": db_name_full + "_tmp",
"vi": vi,
"pager_boundary": "---boundary---",
}
os.environ["PAGER"] = "{0} {1} {2}".format(
sys.executable,
os.path.join(context.package_root, "tests/features/wrappager.py"),
context.conf["pager_boundary"],
)
# Store old env vars.
context.pgenv = {
"PGDATABASE": os.environ.get("PGDATABASE", None),
"PGUSER": os.environ.get("PGUSER", None),
"PGHOST": os.environ.get("PGHOST", None),
"PGPASSWORD": os.environ.get("PGPASSWORD", None),
"PGPORT": os.environ.get("PGPORT", None),
"XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None),
"PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None),
}
# Set new env vars.
os.environ["PGDATABASE"] = context.conf["dbname"]
os.environ["PGUSER"] = context.conf["user"]
os.environ["PGHOST"] = context.conf["host"]
os.environ["PGPORT"] = context.conf["port"]
os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf")
if context.conf["pass"]:
os.environ["PGPASSWORD"] = context.conf["pass"]
else:
if "PGPASSWORD" in os.environ:
del os.environ["PGPASSWORD"]
context.cn = dbutils.create_db(
context.conf["host"],
context.conf["user"],
context.conf["pass"],
context.conf["dbname"],
context.conf["port"],
)
context.fixture_data = fixutils.read_fixture_files()
# use temporary directory as config home
context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_")
os.environ["XDG_CONFIG_HOME"] = context.env_config_home
show_env_changes(env_old, dict(os.environ))
def show_env_changes(env_old, env_new):
"""Print out all test-specific env values."""
print("--- os.environ changed values: ---")
all_keys = set(list(env_old.keys()) + list(env_new.keys()))
for k in sorted(all_keys):
old_value = env_old.get(k, "")
new_value = env_new.get(k, "")
if new_value and old_value != new_value:
print('{}="{}"'.format(k, new_value))
print("-" * 20)
def after_all(context):
"""
Unset env parameters.
"""
dbutils.close_cn(context.cn)
dbutils.drop_db(
context.conf["host"],
context.conf["user"],
context.conf["pass"],
context.conf["dbname"],
context.conf["port"],
)
# Remove temp config direcotry
shutil.rmtree(context.env_config_home)
# Restore env vars.
for k, v in context.pgenv.items():
if k in os.environ and v is None:
del os.environ[k]
elif v:
os.environ[k] = v
def before_step(context, _):
context.atprompt = False
def before_scenario(context, scenario):
if scenario.name == "list databases":
# not using the cli for that
return
wrappers.run_cli(context)
wrappers.wait_prompt(context)
def after_scenario(context, scenario):
"""Cleans up after each scenario completes."""
if hasattr(context, "cli") and context.cli and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
dbname = context.currentdb
context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
try:
context.cli.expect_exact(pexpect.EOF, timeout=15)
except pexpect.TIMEOUT:
print("--- after_scenario {}: kill cli".format(scenario.name))
context.cli.kill(signal.SIGKILL)
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
context.tmpfile_sql_help.close()
context.tmpfile_sql_help = None
# # TODO: uncomment to debug a failure
# def after_step(context, step):
# if step.status == "failed":
# import pdb; pdb.set_trace()

View file

@ -0,0 +1,29 @@
Feature: expanded mode:
on, off, auto
Scenario: expanded on
When we prepare the test data
and we set expanded on
and we select from table
then we see expanded data selected
when we drop table
then we confirm the destructive warning
then we see table dropped
Scenario: expanded off
When we prepare the test data
and we set expanded off
and we select from table
then we see nonexpanded data selected
when we drop table
then we confirm the destructive warning
then we see table dropped
Scenario: expanded auto
When we prepare the test data
and we set expanded auto
and we select from table
then we see auto data selected
when we drop table
then we confirm the destructive warning
then we see table dropped

View file

@ -0,0 +1,25 @@
+--------------------------+------------------------------------------------+
| Command | Description |
|--------------------------+------------------------------------------------|
| \# | Refresh auto-completions. |
| \? | Show Help. |
| \T [format] | Change the table format used to output results |
| \c[onnect] database_name | Change to a new database. |
| \d [pattern] | List or describe tables, views and sequences. |
| \dT[S+] [pattern] | List data types |
| \df[+] [pattern] | List functions. |
| \di[+] [pattern] | List indexes. |
| \dn[+] [pattern] | List schemas. |
| \ds[+] [pattern] | List sequences. |
| \dt[+] [pattern] | List tables. |
| \du[+] [pattern] | List roles. |
| \dv[+] [pattern] | List views. |
| \e [file] | Edit the query with external editor. |
| \l | List databases. |
| \n[+] [name] | List or execute named queries. |
| \nd [name [query]] | Delete a named query. |
| \ns name query | Save a named query. |
| \refresh | Refresh auto-completions. |
| \timing | Toggle timing of commands. |
| \x | Toggle expanded output. |
+--------------------------+------------------------------------------------+

View file

@ -0,0 +1,64 @@
Command
Description
\#
Refresh auto-completions.
\?
Show Commands.
\T [format]
Change the table format used to output results
\c[onnect] database_name
Change to a new database.
\copy [tablename] to/from [filename]
Copy data between a file and a table.
\d[+] [pattern]
List or describe tables, views and sequences.
\dT[S+] [pattern]
List data types
\db[+] [pattern]
List tablespaces.
\df[+] [pattern]
List functions.
\di[+] [pattern]
List indexes.
\dm[+] [pattern]
List materialized views.
\dn[+] [pattern]
List schemas.
\ds[+] [pattern]
List sequences.
\dt[+] [pattern]
List tables.
\du[+] [pattern]
List roles.
\dv[+] [pattern]
List views.
\dx[+] [pattern]
List extensions.
\e [file]
Edit the query with external editor.
\h
Show SQL syntax and help.
\i filename
Execute commands from file.
\l
List databases.
\n[+] [name] [param1 param2 ...]
List or execute named queries.
\nd [name]
Delete a named query.
\ns name query
Save a named query.
\o [filename]
Send all query results to file.
\pager [command]
Set PAGER. Print the query results via PAGER.
\pset [key] [value]
A limited version of traditional \pset
\refresh
Refresh auto-completions.
\sf[+] FUNCNAME
Show a function's definition.
\timing
Toggle timing of commands.
\x
Toggle expanded output.

View file

@ -0,0 +1,4 @@
[mock_postgres]
dbname=postgres
host=localhost
user=postgres

View file

@ -0,0 +1,28 @@
import os
import codecs
def read_fixture_lines(filename):
"""
Read lines of text from file.
:param filename: string name
:return: list of strings
"""
lines = []
for line in codecs.open(filename, "rb", encoding="utf-8"):
lines.append(line.strip())
return lines
def read_fixture_files():
"""Read all files inside fixture_data directory."""
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, "fixture_data/")
print("reading fixture data: {}".format(fixture_dir))
fixture_dict = {}
for filename in os.listdir(fixture_dir):
if filename not in [".", ".."]:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)
return fixture_dict

View file

@ -0,0 +1,17 @@
Feature: I/O commands
Scenario: edit sql in file with external editor
When we start external editor providing a file name
and we type sql in the editor
and we exit the editor
then we see dbcli prompt
and we see the sql in prompt
Scenario: tee output from query
When we tee output
and we wait for prompt
and we query "select 123456"
and we wait for prompt
and we stop teeing output
and we wait for prompt
then we see 123456 in tee output

View file

@ -0,0 +1,10 @@
Feature: named queries:
save, use and delete named queries
Scenario: save, use and delete named queries
When we connect to test database
then we see database connected
when we save a named query
then we see the named query saved
when we delete a named query
then we see the named query deleted

View file

@ -0,0 +1,6 @@
Feature: Special commands
Scenario: run refresh command
When we refresh completions
and we wait for prompt
then we see completions refresh started

View file

View file

@ -0,0 +1,99 @@
from textwrap import dedent
from behave import then, when
import wrappers
@when("we run dbcli with {arg}")
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=arg.split("="))
@when("we execute a small query")
def step_execute_small_query(context):
context.cli.sendline("select 1")
@when("we execute a large query")
def step_execute_large_query(context):
context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
@then("we see small results in horizontal format")
def step_see_small_results(context):
wrappers.expect_pager(
context,
dedent(
"""\
+------------+\r
| ?column? |\r
|------------|\r
| 1 |\r
+------------+\r
SELECT 1\r
"""
),
timeout=5,
)
@then("we see large results in vertical format")
def step_see_large_results(context):
wrappers.expect_pager(
context,
dedent(
"""\
-[ RECORD 1 ]-------------------------\r
?column? | 1\r
?column? | 2\r
?column? | 3\r
?column? | 4\r
?column? | 5\r
?column? | 6\r
?column? | 7\r
?column? | 8\r
?column? | 9\r
?column? | 10\r
?column? | 11\r
?column? | 12\r
?column? | 13\r
?column? | 14\r
?column? | 15\r
?column? | 16\r
?column? | 17\r
?column? | 18\r
?column? | 19\r
?column? | 20\r
?column? | 21\r
?column? | 22\r
?column? | 23\r
?column? | 24\r
?column? | 25\r
?column? | 26\r
?column? | 27\r
?column? | 28\r
?column? | 29\r
?column? | 30\r
?column? | 31\r
?column? | 32\r
?column? | 33\r
?column? | 34\r
?column? | 35\r
?column? | 36\r
?column? | 37\r
?column? | 38\r
?column? | 39\r
?column? | 40\r
?column? | 41\r
?column? | 42\r
?column? | 43\r
?column? | 44\r
?column? | 45\r
?column? | 46\r
?column? | 47\r
?column? | 48\r
?column? | 49\r
SELECT 1\r
"""
),
timeout=5,
)

View file

@ -0,0 +1,147 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
import pexpect
import subprocess
import tempfile
from behave import when, then
from textwrap import dedent
import wrappers
@when("we list databases")
def step_list_databases(context):
cmd = ["pgcli", "--list"]
context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root)
@then("we see list of databases")
def step_see_list_databases(context):
assert b"List of databases" in context.cmd_output
assert b"postgres" in context.cmd_output
context.cmd_output = None
@when("we run dbcli")
def step_run_cli(context):
wrappers.run_cli(context)
@when("we launch dbcli using {arg}")
def step_run_cli_using_arg(context, arg):
prompt_check = False
currentdb = None
if arg == "--username":
arg = "--username={}".format(context.conf["user"])
if arg == "--user":
arg = "--user={}".format(context.conf["user"])
if arg == "--port":
arg = "--port={}".format(context.conf["port"])
if arg == "--password":
arg = "--password"
prompt_check = False
# This uses the mock_pg_service.conf file in fixtures folder.
if arg == "dsn_password":
arg = "service=mock_postgres --password"
prompt_check = False
currentdb = "postgres"
wrappers.run_cli(
context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb
)
@when("we wait for prompt")
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@when('we send "ctrl + d"')
def step_ctrl_d(context):
"""
Send Ctrl + D to hopefully exit.
"""
# turn off pager before exiting
context.cli.sendline("\pset pager off")
wrappers.wait_prompt(context)
context.cli.sendcontrol("d")
context.cli.expect(pexpect.EOF, timeout=15)
context.exit_sent = True
@when('we send "\?" command')
def step_send_help(context):
"""
Send \? to see help.
"""
context.cli.sendline("\?")
@when("we send partial select command")
def step_send_partial_select_command(context):
"""
Send `SELECT a` to see completion.
"""
context.cli.sendline("SELECT a")
@then("we see error message")
def step_see_error_message(context):
wrappers.expect_exact(context, 'column "a" does not exist', timeout=2)
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
context.tmpfile_sql_help.write(b"\?")
context.tmpfile_sql_help.flush()
context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@when("we run query to check application_name")
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;"
)
@then("we see found")
def step_see_found(context):
wrappers.expect_exact(
context,
context.conf["pager_boundary"]
+ "\r"
+ dedent(
"""
+------------+\r
| ?column? |\r
|------------|\r
| found |\r
+------------+\r
SELECT 1\r
"""
)
+ context.conf["pager_boundary"],
timeout=5,
)
@then("we confirm the destructive warning")
def step_confirm_destructive_command(context):
"""Confirm destructive command."""
wrappers.expect_exact(
context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):",
timeout=2,
)
context.cli.sendline("y")
@then("we send password")
def step_send_password(context):
wrappers.expect_exact(context, "Password for", timeout=5)
context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER")

View file

@ -0,0 +1,93 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
import pexpect
from behave import when, then
import wrappers
@when("we create database")
def step_db_create(context):
"""
Send create database.
"""
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.response = {"database_name": context.conf["dbname_tmp"]}
@when("we drop database")
def step_db_drop(context):
"""
Send drop database.
"""
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
@when("we connect to test database")
def step_db_connect_test(context):
"""
Send connect to database.
"""
db_name = context.conf["dbname"]
context.cli.sendline("\\connect {0}".format(db_name))
@when("we connect to dbserver")
def step_db_connect_dbserver(context):
"""
Send connect to database.
"""
context.cli.sendline("\\connect postgres")
context.currentdb = "postgres"
@then("dbcli exits")
def step_wait_exit(context):
"""
Make sure the cli exits.
"""
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
@then("we see dbcli prompt")
def step_see_prompt(context):
"""
Wait to see the prompt.
"""
db_name = getattr(context, "currentdb", context.conf["dbname"])
wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
context.atprompt = True
@then("we see help output")
def step_see_help(context):
for expected_line in context.fixture_data["help_commands.txt"]:
wrappers.expect_exact(context, expected_line, timeout=2)
@then("we see database created")
def step_see_db_created(context):
"""
Wait to see create database output.
"""
wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5)
@then("we see database dropped")
def step_see_db_dropped(context):
"""
Wait to see drop database output.
"""
wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2)
@then("we see database connected")
def step_see_db_connected(context):
"""
Wait to see drop database output.
"""
wrappers.expect_exact(context, "You are now connected to database", timeout=2)

View file

@ -0,0 +1,118 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
from behave import when, then
from textwrap import dedent
import wrappers
@when("we create table")
def step_create_table(context):
"""
Send create table.
"""
context.cli.sendline("create table a(x text);")
@when("we insert into table")
def step_insert_into_table(context):
"""
Send insert into table.
"""
context.cli.sendline("""insert into a(x) values('xxx');""")
@when("we update table")
def step_update_table(context):
"""
Send insert into table.
"""
context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
@when("we select from table")
def step_select_from_table(context):
"""
Send select from table.
"""
context.cli.sendline("select * from a;")
@when("we delete from table")
def step_delete_from_table(context):
"""
Send deete from table.
"""
context.cli.sendline("""delete from a where x = 'yyy';""")
@when("we drop table")
def step_drop_table(context):
"""
Send drop table.
"""
context.cli.sendline("drop table a;")
@then("we see table created")
def step_see_table_created(context):
"""
Wait to see create table output.
"""
wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
@then("we see record inserted")
def step_see_record_inserted(context):
"""
Wait to see insert output.
"""
wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
@then("we see record updated")
def step_see_record_updated(context):
"""
Wait to see update output.
"""
wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2)
@then("we see data selected")
def step_see_data_selected(context):
"""
Wait to see select output.
"""
wrappers.expect_pager(
context,
dedent(
"""\
+-----+\r
| x |\r
|-----|\r
| yyy |\r
+-----+\r
SELECT 1\r
"""
),
timeout=1,
)
@then("we see record deleted")
def step_see_data_deleted(context):
"""
Wait to see delete output.
"""
wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2)
@then("we see table dropped")
def step_see_table_dropped(context):
"""
Wait to see drop output.
"""
wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2)

View file

@ -0,0 +1,70 @@
"""Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it. This string is used
to call the step in "*.feature" file.
"""
from behave import when, then
from textwrap import dedent
import wrappers
@when("we prepare the test data")
def step_prepare_data(context):
"""Create table, insert a record."""
context.cli.sendline("drop table if exists a;")
wrappers.expect_exact(
context,
"You're about to run a destructive command.\r\nDo you want to proceed? (y/n):",
timeout=2,
)
context.cli.sendline("y")
wrappers.wait_prompt(context)
context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));")
wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""")
wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
@when("we set expanded {mode}")
def step_set_expanded(context, mode):
"""Set expanded to mode."""
context.cli.sendline("\\" + "x {}".format(mode))
wrappers.expect_exact(context, "Expanded display is", timeout=2)
wrappers.wait_prompt(context)
@then("we see {which} data selected")
def step_see_data(context, which):
"""Select data from expanded test table."""
if which == "expanded":
wrappers.expect_pager(
context,
dedent(
"""\
-[ RECORD 1 ]-------------------------\r
x | 1\r
y | 1.0\r
z | 1.0000\r
SELECT 1\r
"""
),
timeout=1,
)
else:
wrappers.expect_pager(
context,
dedent(
"""\
+-----+-----+--------+\r
| x | y | z |\r
|-----+-----+--------|\r
| 1 | 1.0 | 1.0000 |\r
+-----+-----+--------+\r
SELECT 1\r
"""
),
timeout=1,
)

View file

@ -0,0 +1,80 @@
import os
import os.path
from behave import when, then
import wrappers
@when("we start external editor providing a file name")
def step_edit_file(context):
"""Edit file with external editor."""
context.editor_file_name = os.path.join(
context.package_root, "test_file_{0}.sql".format(context.conf["vi"])
)
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
)
wrappers.expect_exact(context, ":", timeout=2)
@when("we type sql in the editor")
def step_edit_type_sql(context):
context.cli.sendline("i")
context.cli.sendline("select * from abc")
context.cli.sendline(".")
wrappers.expect_exact(context, ":", timeout=2)
@when("we exit the editor")
def step_edit_quit(context):
context.cli.sendline("x")
wrappers.expect_exact(context, "written", timeout=2)
@then("we see the sql in prompt")
def step_edit_done_sql(context):
for match in "select * from abc".split(" "):
wrappers.expect_exact(context, match, timeout=1)
# Cleanup the command line.
context.cli.sendcontrol("c")
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.atprompt = True
@when("we tee output")
def step_tee_ouptut(context):
context.tee_file_name = os.path.join(
context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])
)
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Writing to file", timeout=5)
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Time", timeout=5)
@when('we query "select 123456"')
def step_query_select_123456(context):
context.cli.sendline("select 123456")
@when("we stop teeing output")
def step_notee_output(context):
context.cli.sendline("\o")
wrappers.expect_exact(context, "Time", timeout=5)
@then("we see 123456 in tee output")
def step_see_123456_in_ouput(context):
with open(context.tee_file_name) as f:
assert "123456" in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.atprompt = True

View file

@ -0,0 +1,57 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
from behave import when, then
import wrappers
@when("we save a named query")
def step_save_named_query(context):
"""
Send \ns command
"""
context.cli.sendline("\\ns foo SELECT 12345")
@when("we use a named query")
def step_use_named_query(context):
"""
Send \n command
"""
context.cli.sendline("\\n foo")
@when("we delete a named query")
def step_delete_named_query(context):
"""
Send \nd command
"""
context.cli.sendline("\\nd foo")
@then("we see the named query saved")
def step_see_named_query_saved(context):
"""
Wait to see query saved.
"""
wrappers.expect_exact(context, "Saved.", timeout=2)
@then("we see the named query executed")
def step_see_named_query_executed(context):
"""
Wait to see select output.
"""
wrappers.expect_exact(context, "12345", timeout=1)
wrappers.expect_exact(context, "SELECT 1", timeout=1)
@then("we see the named query deleted")
def step_see_named_query_deleted(context):
"""
Wait to see query deleted.
"""
wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1)

View file

@ -0,0 +1,26 @@
"""
Steps for behavioral style tests are defined in this module.
Each step is defined by the string decorating it.
This string is used to call the step in "*.feature" file.
"""
from behave import when, then
import wrappers
@when("we refresh completions")
def step_refresh_completions(context):
"""
Send refresh command.
"""
context.cli.sendline("\\refresh")
@then("we see completions refresh started")
def step_see_refresh_started(context):
"""
Wait to see refresh output.
"""
wrappers.expect_pager(
context, "Auto-completion refresh started in the background.\r\n", timeout=2
)

View file

@ -0,0 +1,67 @@
import re
import pexpect
from pgcli.main import COLOR_CODE_REGEX
import textwrap
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def expect_exact(context, expected, timeout):
timedout = False
try:
context.cli.expect_exact(expected, timeout=timeout)
except pexpect.TIMEOUT:
timedout = True
if timedout:
# Strip color codes out of the output.
actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
raise Exception(
textwrap.dedent(
"""\
Expected:
---
{0!r}
---
Actual:
---
{1!r}
---
Full log:
---
{2!r}
---
"""
).format(expected, actual, context.logfile.getvalue())
)
def expect_pager(context, expected, timeout):
expect_exact(
context,
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
timeout=timeout,
)
def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
"""Run the process using pexpect."""
run_args = run_args or []
cli_cmd = context.conf.get("cli_command")
cmd_parts = [cli_cmd] + run_args
cmd = " ".join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = currentdb or context.conf["dbname"]
context.cli.sendline("\pset pager always")
if prompt_check:
wait_prompt(context)
def wait_prompt(context):
"""Make sure prompt is displayed."""
expect_exact(context, "{0}> ".format(context.conf["dbname"]), timeout=5)

16
tests/features/wrappager.py Executable file
View file

@ -0,0 +1,16 @@
#!/usr/bin/env python
import sys
def wrappager(boundary):
print(boundary)
while 1:
buf = sys.stdin.read(2048)
if not buf:
break
sys.stdout.write(buf)
print(boundary)
if __name__ == "__main__":
wrappager(sys.argv[1])

255
tests/metadata.py Normal file
View file

@ -0,0 +1,255 @@
from functools import partial
from itertools import product
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from mock import Mock
import pytest
parametrize = pytest.mark.parametrize
qual = ["if_more_than_one_table", "always"]
no_qual = ["if_more_than_one_table", "never"]
def escape(name):
if not name.islower() or name in ("select", "localtimestamp"):
return '"' + name + '"'
return name
def completion(display_meta, text, pos=0):
return Completion(text, start_position=pos, display_meta=display_meta)
def function(text, pos=0, display=None):
return Completion(
text, display=display or text, start_position=pos, display_meta="function"
)
def get_result(completer, text, position=None):
position = len(text) if position is None else position
return completer.get_completions(
Document(text=text, cursor_position=position), Mock()
)
def result_set(completer, text, position=None):
return set(get_result(completer, text, position))
# The code below is quivalent to
# def schema(text, pos=0):
# return completion('schema', text, pos)
# and so on
schema = partial(completion, "schema")
table = partial(completion, "table")
view = partial(completion, "view")
column = partial(completion, "column")
keyword = partial(completion, "keyword")
datatype = partial(completion, "datatype")
alias = partial(completion, "table alias")
name_join = partial(completion, "name join")
fk_join = partial(completion, "fk join")
join = partial(completion, "join")
def wildcard_expansion(cols, pos=-1):
return Completion(cols, start_position=pos, display_meta="columns", display="*")
class MetaData(object):
def __init__(self, metadata):
self.metadata = metadata
def builtin_functions(self, pos=0):
return [function(f, pos) for f in self.completer.functions]
def builtin_datatypes(self, pos=0):
return [datatype(dt, pos) for dt in self.completer.datatypes]
def keywords(self, pos=0):
return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]
def specials(self, pos=0):
return [
Completion(text=k, start_position=pos, display_meta=v.description)
for k, v in self.completer.pgspecial.commands.items()
]
def columns(self, tbl, parent="public", typ="tables", pos=0):
if typ == "functions":
fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
cols = fun[1]
else:
cols = self.metadata[typ][parent][tbl]
return [column(escape(col), pos) for col in cols]
def datatypes(self, parent="public", pos=0):
return [
datatype(escape(x), pos)
for x in self.metadata.get("datatypes", {}).get(parent, [])
]
def tables(self, parent="public", pos=0):
return [
table(escape(x), pos)
for x in self.metadata.get("tables", {}).get(parent, [])
]
def views(self, parent="public", pos=0):
return [
view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])
]
def functions(self, parent="public", pos=0):
return [
function(
escape(x[0])
+ "("
+ ", ".join(
arg_name + " := "
for (arg_name, arg_mode) in zip(x[1], x[3])
if arg_mode in ("b", "i")
)
+ ")",
pos,
escape(x[0])
+ "("
+ ", ".join(
arg_name
for (arg_name, arg_mode) in zip(x[1], x[3])
if arg_mode in ("b", "i")
)
+ ")",
)
for x in self.metadata.get("functions", {}).get(parent, [])
]
def schemas(self, pos=0):
schemas = set(sch for schs in self.metadata.values() for sch in schs)
return [schema(escape(s), pos=pos) for s in schemas]
def functions_and_keywords(self, parent="public", pos=0):
return (
self.functions(parent, pos)
+ self.builtin_functions(pos)
+ self.keywords(pos)
)
# Note that the filtering parameters here only apply to the columns
def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0):
return self.functions_and_keywords(pos=pos) + self.columns(
tbl, parent, typ, pos
)
def from_clause_items(self, parent="public", pos=0):
return (
self.functions(parent, pos)
+ self.views(parent, pos)
+ self.tables(parent, pos)
)
def schemas_and_from_clause_items(self, parent="public", pos=0):
return self.from_clause_items(parent, pos) + self.schemas(pos)
def types(self, parent="public", pos=0):
return self.datatypes(parent, pos) + self.tables(parent, pos)
@property
def completer(self):
return self.get_completer()
def get_completers(self, casing):
"""
Returns a function taking three bools `casing`, `filtr`, `aliasing` and
the list `qualify`, all defaulting to None.
Returns a list of completers.
These parameters specify the allowed values for the corresponding
completer parameters, `None` meaning any, i.e. (None, None, None, None)
results in all 24 possible completers, whereas e.g.
(True, False, True, ['never']) results in the one completer with
casing, without `search_path` filtering of objects, with table
aliasing, and without column qualification.
"""
def _cfg(_casing, filtr, aliasing, qualify):
cfg = {"settings": {}}
if _casing:
cfg["casing"] = casing
cfg["settings"]["search_path_filter"] = filtr
cfg["settings"]["generate_aliases"] = aliasing
cfg["settings"]["qualify_columns"] = qualify
return cfg
def _cfgs(casing, filtr, aliasing, qualify):
casings = [True, False] if casing is None else [casing]
filtrs = [True, False] if filtr is None else [filtr]
aliases = [True, False] if aliasing is None else [aliasing]
qualifys = qualify or ["always", "if_more_than_one_table", "never"]
return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)]
def completers(casing=None, filtr=None, aliasing=None, qualify=None):
get_comp = self.get_completer
return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)]
return completers
def _make_col(self, sch, tbl, col):
defaults = self.metadata.get("defaults", {}).get(sch, {})
return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col)))
def get_completer(self, settings=None, casing=None):
metadata = self.metadata
from pgcli.pgcompleter import PGCompleter
from pgspecial import PGSpecial
comp = PGCompleter(
smart_completion=True, settings=settings, pgspecial=PGSpecial()
)
schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []
for sch, tbls in metadata["tables"].items():
schemata.append(sch)
for tbl, cols in tbls.items():
tables.append((sch, tbl))
# Let all columns be text columns
tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols])
for sch, tbls in metadata.get("views", {}).items():
for tbl, cols in tbls.items():
views.append((sch, tbl))
# Let all columns be text columns
view_cols.extend([self._make_col(sch, tbl, col) for col in cols])
functions = [
FunctionMetadata(sch, *func_meta, arg_defaults=None)
for sch, funcs in metadata["functions"].items()
for func_meta in funcs
]
datatypes = [
(sch, typ)
for sch, datatypes in metadata["datatypes"].items()
for typ in datatypes
]
foreignkeys = [
ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks
]
comp.extend_schemata(schemata)
comp.extend_relations(tables, kind="tables")
comp.extend_relations(views, kind="views")
comp.extend_columns(tbl_cols, kind="tables")
comp.extend_columns(view_cols, kind="views")
comp.extend_functions(functions)
comp.extend_datatypes(datatypes)
comp.extend_foreignkeys(foreignkeys)
comp.set_search_path(["public"])
comp.extend_casing(casing or [])
return comp

View file

@ -0,0 +1,137 @@
import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
token_start_pos,
extract_ctes,
extract_column_names as _extract_column_names,
)
def extract_column_names(sql):
p = parse(sql)[0]
return _extract_column_names(p)
def test_token_str_pos():
sql = "SELECT * FROM xxx"
p = parse(sql)[0]
idx = p.token_index(p.tokens[-1])
assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ")
sql = "SELECT * FROM \nxxx"
p = parse(sql)[0]
idx = p.token_index(p.tokens[-1])
assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n")
def test_single_column_name_extraction():
sql = "SELECT abc FROM xxx"
assert extract_column_names(sql) == ("abc",)
def test_aliased_single_column_name_extraction():
sql = "SELECT abc def FROM xxx"
assert extract_column_names(sql) == ("def",)
def test_aliased_expression_name_extraction():
sql = "SELECT 99 abc FROM xxx"
assert extract_column_names(sql) == ("abc",)
def test_multiple_column_name_extraction():
sql = "SELECT abc, def FROM xxx"
assert extract_column_names(sql) == ("abc", "def")
def test_missing_column_name_handled_gracefully():
sql = "SELECT abc, 99 FROM xxx"
assert extract_column_names(sql) == ("abc",)
sql = "SELECT abc, 99, def FROM xxx"
assert extract_column_names(sql) == ("abc", "def")
def test_aliased_multiple_column_name_extraction():
sql = "SELECT abc def, ghi jkl FROM xxx"
assert extract_column_names(sql) == ("def", "jkl")
def test_table_qualified_column_name_extraction():
sql = "SELECT abc.def, ghi.jkl FROM xxx"
assert extract_column_names(sql) == ("def", "jkl")
@pytest.mark.parametrize(
"sql",
[
"INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
"DELETE FROM foo WHERE x > y RETURNING x, y",
"UPDATE foo SET x = 9 RETURNING x, y",
],
)
def test_extract_column_names_from_returning_clause(sql):
assert extract_column_names(sql) == ("x", "y")
def test_simple_cte_extraction():
sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
start_pos = len("WITH a AS ")
stop_pos = len("WITH a AS (SELECT abc FROM xxx)")
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),)
assert remainder.strip() == "SELECT * FROM a"
def test_cte_extraction_around_comments():
sql = """--blah blah blah
WITH a AS (SELECT abc def FROM x)
SELECT * FROM a"""
start_pos = len(
"""--blah blah blah
WITH a AS """
)
stop_pos = len(
"""--blah blah blah
WITH a AS (SELECT abc def FROM x)"""
)
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),)
assert remainder.strip() == "SELECT * FROM a"
def test_multiple_cte_extraction():
sql = """WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)
SELECT * FROM a, b"""
start1 = len(
"""WITH
x AS """
)
stop1 = len(
"""WITH
x AS (SELECT abc, def FROM x)"""
)
start2 = len(
"""WITH
x AS (SELECT abc, def FROM x),
y AS """
)
stop2 = len(
"""WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)"""
)
ctes, remainder = extract_ctes(sql)
assert tuple(ctes) == (
("x", ("abc", "def"), start1, stop1),
("y", ("ghi", "jkl"), start2, stop2),
)

View file

@ -0,0 +1,19 @@
from pgcli.packages.parseutils.meta import FunctionMetadata
def test_function_metadata_eq():
f1 = FunctionMetadata(
"s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
)
f2 = FunctionMetadata(
"s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
)
f3 = FunctionMetadata(
"s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None
)
assert f1 == f2
assert f1 != f3
assert not (f1 != f2)
assert not (f1 == f3)
assert hash(f1) == hash(f2)
assert hash(f1) != hash(f3)

View file

@ -0,0 +1,269 @@
import pytest
from pgcli.packages.parseutils.tables import extract_tables
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
def test_empty_string():
tables = extract_tables("")
assert tables == ()
def test_simple_select_single_table():
tables = extract_tables("select * from abc")
assert tables == ((None, "abc", None, False),)
@pytest.mark.parametrize(
"sql", ['select * from "abc"."def"', 'select * from abc."def"']
)
def test_simple_select_single_table_schema_qualified_quoted_table(sql):
tables = extract_tables(sql)
assert tables == (("abc", "def", '"def"', False),)
@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def'])
def test_simple_select_single_table_schema_qualified(sql):
tables = extract_tables(sql)
assert tables == (("abc", "def", None, False),)
def test_simple_select_single_table_double_quoted():
tables = extract_tables('select * from "Abc"')
assert tables == ((None, "Abc", None, False),)
def test_simple_select_multiple_tables():
tables = extract_tables("select * from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_simple_select_multiple_tables_double_quoted():
tables = extract_tables('select * from "Abc", "Def"')
assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
def test_simple_select_single_table_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a')
assert tables == ((None, "Abc", "a", False),)
def test_simple_select_multiple_tables_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a, "Def" d')
assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables("select * from abc.def, ghi.jkl")
assert set(tables) == set(
[("abc", "def", None, False), ("ghi", "jkl", None, False)]
)
def test_simple_select_with_cols_single_table():
tables = extract_tables("select a,b from abc")
assert tables == ((None, "abc", None, False),)
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables("select a,b from abc.def")
assert tables == (("abc", "def", None, False),)
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables("select a,b from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_simple_select_with_cols_multiple_qualified_tables():
tables = extract_tables("select a,b from abc.def, def.ghi")
assert set(tables) == set(
[("abc", "def", None, False), ("def", "ghi", None, False)]
)
def test_select_with_hanging_comma_single_table():
tables = extract_tables("select a, from abc")
assert tables == ((None, "abc", None, False),)
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables("select a, from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert set(tables) == set(
[(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
)
def test_simple_insert_single_table():
tables = extract_tables('insert into abc (id, name) values (1, "def")')
# sqlparse mistakenly assigns an alias to the table
# AND mistakenly identifies the field list as
# assert tables == ((None, 'abc', 'abc', False),)
assert tables == ((None, "abc", "abc", False),)
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == (("abc", "def", None, False),)
def test_simple_update_table_no_schema():
tables = extract_tables("update abc set id = 1")
assert tables == ((None, "abc", None, False),)
def test_simple_update_table_with_schema():
tables = extract_tables("update abc.def set id = 1")
assert tables == (("abc", "def", None, False),)
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
def test_join_table(join_type):
sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
tables = extract_tables(sql)
assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
def test_join_table_schema_qualified():
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
def test_incomplete_join_clause():
sql = """select a.x, b.y
from abc a join bcd b
on a.id = """
tables = extract_tables(sql)
assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False))
def test_join_as_table():
tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
assert tables == ((None, "my_table", "m", False),)
def test_multiple_joins():
sql = """select * from t1
inner join t2 ON
t1.id = t2.t1_id
inner join t3 ON
t2.id = t3."""
tables = extract_tables(sql)
assert tables == (
(None, "t1", None, False),
(None, "t2", None, False),
(None, "t3", None, False),
)
def test_subselect_tables():
sql = "SELECT * FROM (SELECT FROM abc"
tables = extract_tables(sql)
assert tables == ((None, "abc", None, False),)
@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"])
def test_extract_no_tables(text):
tables = extract_tables(text)
assert tables == tuple()
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
assert tables == ((None, "foo", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_schema_qualified_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
assert tables == (("foo", "bar", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_aliased_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
assert tables == ((None, "foo", "bar", True),)
def test_simple_table_and_function():
tables = extract_tables("SELECT * FROM foo JOIN bar()")
assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
def test_complex_table_and_function():
tables = extract_tables(
"""SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux"""
)
assert set(tables) == set(
[("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
)
def test_find_prev_keyword_using():
q = "select * from tbl1 inner join tbl2 using (col1, "
kw, q2 = find_prev_keyword(q)
assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using ("
@pytest.mark.parametrize(
"sql",
[
"select * from foo where bar",
"select * from foo where bar = 1 and baz or ",
"select * from foo where bar = 1 and baz between qux and ",
],
)
def test_find_prev_keyword_where(sql):
kw, stripped = find_prev_keyword(sql)
assert kw.value == "where" and stripped == "select * from foo where"
@pytest.mark.parametrize(
"sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "]
)
def test_find_prev_keyword_open_parens(sql):
kw, _ = find_prev_keyword(sql)
assert kw.value == "("
@pytest.mark.parametrize(
"sql",
[
"",
"$$ foo $$",
"$$ 'foo' $$",
'$$ "foo" $$',
"$$ $a$ $$",
"$a$ $$ $a$",
"foo bar $$ baz $$",
],
)
def test_is_open_quote__closed(sql):
assert not is_open_quote(sql)
@pytest.mark.parametrize(
"sql",
[
"$$",
";;;$$",
"foo $$ bar $$; foo $$",
"$$ foo $a$",
"foo 'bar baz",
"$a$ foo ",
'$$ "foo" ',
"$$ $a$ ",
"foo bar $$ baz",
],
)
def test_is_open_quote__open(sql):
assert is_open_quote(sql)

2
tests/pytest.ini Normal file
View file

@ -0,0 +1,2 @@
[pytest]
addopts=--capture=sys --showlocals

View file

@ -0,0 +1,97 @@
import time
import pytest
from mock import Mock, patch
@pytest.fixture
def refresher():
from pgcli.completion_refresher import CompletionRefresher
return CompletionRefresher()
def test_ctor(refresher):
"""
Refresher object should contain a few handlers
:param refresher:
:return:
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = [
"schemata",
"tables",
"views",
"types",
"databases",
"casing",
"functions",
]
assert expected_handlers == actual_handlers
def test_refresh_called_once(refresher):
"""
:param refresher:
:return:
"""
callbacks = Mock()
pgexecute = Mock()
special = Mock()
with patch.object(refresher, "_bg_refresh") as bg_refresh:
actual = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
assert actual[0][3] == "Auto-completion refresh started in the background."
bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None)
def test_refresh_called_twice(refresher):
"""
If refresh is called a second time, it should be restarted
:param refresher:
:return:
"""
callbacks = Mock()
pgexecute = Mock()
special = Mock()
def dummy_bg_refresh(*args):
time.sleep(3) # seconds
refresher._bg_refresh = dummy_bg_refresh
actual1 = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
assert actual1[0][3] == "Auto-completion refresh started in the background."
actual2 = refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
assert actual2[0][3] == "Auto-completion refresh restarted."
def test_refresh_with_callbacks(refresher):
"""
Callbacks must be called
:param refresher:
"""
callbacks = [Mock()]
pgexecute_class = Mock()
pgexecute = Mock()
pgexecute.extra_args = {}
special = Mock()
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert callbacks[0].call_count == 1

30
tests/test_config.py Normal file
View file

@ -0,0 +1,30 @@
import os
import stat
import pytest
from pgcli.config import ensure_dir_exists
def test_ensure_file_parent(tmpdir):
subdir = tmpdir.join("subdir")
rcfile = subdir.join("rcfile")
ensure_dir_exists(str(rcfile))
def test_ensure_existing_dir(tmpdir):
rcfile = str(tmpdir.mkdir("subdir").join("rcfile"))
# should just not raise
ensure_dir_exists(rcfile)
def test_ensure_other_create_error(tmpdir):
subdir = tmpdir.join("subdir")
rcfile = subdir.join("rcfile")
# trigger an oserror that isn't "directory already exists"
os.chmod(str(tmpdir), stat.S_IREAD)
with pytest.raises(OSError):
ensure_dir_exists(str(rcfile))

View file

View file

@ -0,0 +1,87 @@
import pytest
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
return pgcompleter.PGCompleter()
def test_ranking_ignores_identifier_quotes(completer):
"""When calculating result rank, identifier quotes should be ignored.
The result ranking algorithm ignores identifier quotes. Without this
correction, the match "user", which Postgres requires to be quoted
since it is also a reserved word, would incorrectly fall below the
match user_action because the literal quotation marks in "user"
alter the position of the match.
This test checks that the fuzzy ranking algorithm correctly ignores
quotation marks when computing match ranks.
"""
text = "user"
collection = ["user_action", '"user"']
matches = completer.find_matches(text, collection)
assert len(matches) == 2
def test_ranking_based_on_shortest_match(completer):
"""Fuzzy result rank should be based on shortest match.
Result ranking in fuzzy searching is partially based on the length
of matches: shorter matches are considered more relevant than
longer ones. When searching for the text 'user', the length
component of the match 'user_group' could be either 4 ('user') or
7 ('user_gr').
This test checks that the fuzzy ranking algorithm uses the shorter
match when calculating result rank.
"""
text = "user"
collection = ["api_user", "user_group"]
matches = completer.find_matches(text, collection)
assert matches[1].priority > matches[0].priority
@pytest.mark.parametrize(
"collection",
[["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]],
)
def test_should_break_ties_using_lexical_order(completer, collection):
"""Fuzzy result rank should use lexical order to break ties.
When fuzzy matching, if multiple matches have the same match length and
start position, present them in lexical (rather than arbitrary) order. For
example, if we have tables 'user', 'user_action', and 'user_group', a
search for the text 'user' should present these tables in this order.
The input collections to this test are out of order; each run checks that
the search text 'user' results in the input tables being reordered
lexically.
"""
text = "user"
matches = completer.find_matches(text, collection)
assert matches[1].priority > matches[0].priority
def test_matching_should_be_case_insensitive(completer):
"""Fuzzy matching should keep matches even if letter casing doesn't match.
This test checks that variations of the text which have different casing
are still matched.
"""
text = "foo"
collection = ["Foo", "FOO", "fOO"]
matches = completer.find_matches(text, collection)
assert len(matches) == 3

383
tests/test_main.py Normal file
View file

@ -0,0 +1,383 @@
import os
import platform
import mock
import pytest
try:
import setproctitle
except ImportError:
setproctitle = None
from pgcli.main import (
obfuscate_process_password,
format_output,
PGCli,
OutputSettings,
COLOR_CODE_REGEX,
)
from pgcli.pgexecute import PGExecute
from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS
from utils import dbtest, run
from collections import namedtuple
@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows")
@pytest.mark.skipif(not setproctitle, reason="setproctitle not available")
def test_obfuscate_process_password():
original_title = setproctitle.getproctitle()
setproctitle.setproctitle("pgcli user=root password=secret host=localhost")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli user=root password=xxxx host=localhost"
assert title == expected
setproctitle.setproctitle("pgcli user=root password=top secret host=localhost")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli user=root password=xxxx host=localhost"
assert title == expected
setproctitle.setproctitle("pgcli user=root password=top secret")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli user=root password=xxxx"
assert title == expected
setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db")
obfuscate_process_password()
title = setproctitle.getproctitle()
expected = "pgcli postgres://root:xxxx@localhost/db"
assert title == expected
setproctitle.setproctitle(original_title)
def test_format_output():
settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g")
results = format_output(
"Title", [("abc", "def")], ["head1", "head2"], "test status", settings
)
expected = [
"Title",
"+---------+---------+",
"| head1 | head2 |",
"|---------+---------|",
"| abc | def |",
"+---------+---------+",
"test status",
]
assert list(results) == expected
@dbtest
def test_format_array_output(executor):
statement = """
SELECT
array[1, 2, 3]::bigint[] as bigint_array,
'{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
'{å,魚,текст}'::text[] as 配列
UNION ALL
SELECT '{}', NULL, array[NULL]
"""
results = run(executor, statement)
expected = [
"+----------------+------------------------+--------------+",
"| bigint_array | nested_numeric_array | 配列 |",
"|----------------+------------------------+--------------|",
"| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |",
"| {} | <null> | {<null>} |",
"+----------------+------------------------+--------------+",
"SELECT 2",
]
assert list(results) == expected
@dbtest
def test_format_array_output_expanded(executor):
statement = """
SELECT
array[1, 2, 3]::bigint[] as bigint_array,
'{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
'{å,魚,текст}'::text[] as 配列
UNION ALL
SELECT '{}', NULL, array[NULL]
"""
results = run(executor, statement, expanded=True)
expected = [
"-[ RECORD 1 ]-------------------------",
"bigint_array | {1,2,3}",
"nested_numeric_array | {{1,2},{3,4}}",
"配列 | {å,魚,текст}",
"-[ RECORD 2 ]-------------------------",
"bigint_array | {}",
"nested_numeric_array | <null>",
"配列 | {<null>}",
"SELECT 2",
]
assert "\n".join(results) == "\n".join(expected)
def test_format_output_auto_expand():
settings = OutputSettings(
table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100
)
table_results = format_output(
"Title", [("abc", "def")], ["head1", "head2"], "test status", settings
)
table = [
"Title",
"+---------+---------+",
"| head1 | head2 |",
"|---------+---------|",
"| abc | def |",
"+---------+---------+",
"test status",
]
assert list(table_results) == table
expanded_results = format_output(
"Title",
[("abc", "def")],
["head1", "head2"],
"test status",
settings._replace(max_width=1),
)
expanded = [
"Title",
"-[ RECORD 1 ]-------------------------",
"head1 | abc",
"head2 | def",
"test status",
]
assert "\n".join(expanded_results) == "\n".join(expanded)
termsize = namedtuple("termsize", ["rows", "columns"])
test_line = "-" * 10
test_data = [
(10, 10, "\n".join([test_line] * 7)),
(10, 10, "\n".join([test_line] * 6)),
(10, 10, "\n".join([test_line] * 5)),
(10, 10, "-" * 11),
(10, 10, "-" * 10),
(10, 10, "-" * 9),
]
# 4 lines are reserved at the bottom of the terminal for pgcli's prompt
use_pager_when_on = [True, True, False, True, False, False]
# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL
test_ids = [
"Output longer than terminal height",
"Output equal to terminal height",
"Output shorter than terminal height",
"Output longer than terminal width",
"Output equal to terminal width",
"Output shorter than terminal width",
]
@pytest.fixture
def pset_pager_mocks():
cli = PGCli()
cli.watch_command = None
with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
"pgcli.main.click.echo_via_pager"
) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
yield cli, mock_echo, mock_echo_via_pager, mock_app
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF):
cli.echo_via_pager(text)
mock_echo.assert_called()
mock_echo_via_pager.assert_not_called()
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS):
cli.echo_via_pager(text)
mock_echo.assert_not_called()
mock_echo_via_pager.assert_called()
pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)]
@pytest.mark.parametrize(
"term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids
)
def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
mock_cli.output.get_size.return_value = termsize(
rows=term_height, columns=term_width
)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT):
cli.echo_via_pager(text)
if use_pager:
mock_echo.assert_not_called()
mock_echo_via_pager.assert_called()
else:
mock_echo_via_pager.assert_not_called()
mock_echo.assert_called()
@pytest.mark.parametrize(
"text,expected_length",
[
(
"22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s",
78,
),
("=\u001b[m=", 2),
("-\u001b]23\u0007-", 2),
],
)
def test_color_pattern(text, expected_length, pset_pager_mocks):
cli = pset_pager_mocks[0]
assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length
@dbtest
def test_i_works(tmpdir, executor):
sqlfile = tmpdir.join("test.sql")
sqlfile.write("SELECT NOW()")
rcfile = str(tmpdir.join("rcfile"))
cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
statement = r"\i {0}".format(sqlfile)
run(executor, statement, pgspecial=cli.pgspecial)
def test_missing_rc_dir(tmpdir):
rcfile = str(tmpdir.join("subdir").join("rcfile"))
PGCli(pgclirc_file=rcfile)
assert os.path.exists(rcfile)
def test_quoted_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B")
mock_connect.assert_called_with(
database="testdb[", host="baz.com", user="bar^", passwd="]foo"
)
def test_pg_service_file(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
service_conf.write(
"""[myservice]
host=a_host
user=a_user
port=5433
password=much_secure
dbname=a_dbname
[my_other_service]
host=b_host
user=b_user
port=5435
dbname=b_dbname
"""
)
os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath
cli.connect_service("myservice", "another_user")
mock_connect.assert_called_with(
database="a_dbname",
host="a_host",
user="another_user",
port="5433",
passwd="much_secure",
)
with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
mock_pgexecute.return_value = None
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
os.environ["PGPASSWORD"] = "very_secure"
cli.connect_service("my_other_service", None)
mock_pgexecute.assert_called_with(
"b_dbname",
"b_user",
"very_secure",
"b_host",
"5435",
"",
application_name="pgcli",
)
del os.environ["PGPASSWORD"]
del os.environ["PGSERVICEFILE"]
def test_ssl_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
"postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?"
"sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
)
mock_connect.assert_called_with(
database="testdb[",
host="baz.com",
user="bar^",
passwd="]foo",
sslmode="verify-full",
sslcert="my.pem",
sslkey="my-key.pem",
sslrootcert="ca.pem",
)
def test_port_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb")
mock_connect.assert_called_with(
database="testdb", host="baz.com", user="bar", passwd="foo", port="2543"
)
def test_multihost_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
"postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb"
)
mock_connect.assert_called_with(
database="testdb",
host="baz1.com,baz2.com,baz3.com",
user="bar",
passwd="foo",
port="2543,2543,2543",
)
def test_application_name_db_uri(tmpdir):
with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
mock_pgexecute.return_value = None
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
mock_pgexecute.assert_called_with(
"bar", "bar", "", "baz.com", "", "", application_name="cow"
)

View file

@ -0,0 +1,133 @@
import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from utils import completions_to_set
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
return pgcompleter.PGCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_empty_string_completion(completer, complete_event):
text = ""
position = 0
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(map(Completion, completer.all_completions))
def test_select_keyword_completion(completer, complete_event):
text = "SEL"
position = len("SEL")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set([Completion(text="SELECT", start_position=-3)])
def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(
[
Completion(text="MATERIALIZED VIEW", start_position=-2),
Completion(text="MAX", start_position=-2),
Completion(text="MAXEXTENTS", start_position=-2),
Completion(text="MAKE_DATE", start_position=-2),
Completion(text="MAKE_TIME", start_position=-2),
Completion(text="MAKE_TIMESTAMPTZ", start_position=-2),
Completion(text="MAKE_INTERVAL", start_position=-2),
Completion(text="MASKLEN", start_position=-2),
Completion(text="MAKE_TIMESTAMP", start_position=-2),
]
)
def test_column_name_completion(completer, complete_event):
text = "SELECT FROM users"
position = len("SELECT ")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
assert result == completions_to_set(map(Completion, completer.all_completions))
def test_alter_well_known_keywords_completion(completer, complete_event):
text = "ALTER "
position = len(text)
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position),
complete_event,
smart_completion=True,
)
)
assert result > completions_to_set(
[
Completion(text="DATABASE", display_meta="keyword"),
Completion(text="TABLE", display_meta="keyword"),
Completion(text="SYSTEM", display_meta="keyword"),
]
)
assert (
completions_to_set([Completion(text="CREATE", display_meta="keyword")])
not in result
)
def test_special_name_completion(completer, complete_event):
text = "\\"
position = len("\\")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position), complete_event
)
)
# Special commands will NOT be suggested during naive completion mode.
assert result == completions_to_set([])
def test_datatype_name_completion(completer, complete_event):
text = "SELECT price::IN"
position = len("SELECT price::IN")
result = completions_to_set(
completer.get_completions(
Document(text=text, cursor_position=position),
complete_event,
smart_completion=True,
)
)
assert result == completions_to_set(
[
Completion(text="INET", display_meta="datatype"),
Completion(text="INT", display_meta="datatype"),
Completion(text="INT2", display_meta="datatype"),
Completion(text="INT4", display_meta="datatype"),
Completion(text="INT8", display_meta="datatype"),
Completion(text="INTEGER", display_meta="datatype"),
Completion(text="INTERNAL", display_meta="datatype"),
Completion(text="INTERVAL", display_meta="datatype"),
]
)

542
tests/test_pgexecute.py Normal file
View file

@ -0,0 +1,542 @@
from textwrap import dedent
import psycopg2
import pytest
from mock import patch, MagicMock
from pgspecial.main import PGSpecial, NO_QUERY
from utils import run, dbtest, requires_json, requires_jsonb
from pgcli.main import PGCli
from pgcli.packages.parseutils.meta import FunctionMetadata
def function_meta_data(
func_name,
schema_name="public",
arg_names=None,
arg_types=None,
arg_modes=None,
return_type=None,
is_aggregate=False,
is_window=False,
is_set_returning=False,
is_extension=False,
arg_defaults=None,
):
return FunctionMetadata(
schema_name,
func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
)
@dbtest
def test_conn(executor):
run(executor, """create table test(a text)""")
run(executor, """insert into test values('abc')""")
assert run(executor, """select * from test""", join=True) == dedent(
"""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1"""
)
@dbtest
def test_copy(executor):
executor_copy = executor.copy()
run(executor_copy, """create table test(a text)""")
run(executor_copy, """insert into test values('abc')""")
assert run(executor_copy, """select * from test""", join=True) == dedent(
"""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1"""
)
@dbtest
def test_bools_are_treated_as_strings(executor):
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
assert run(executor, """select * from test""", join=True) == dedent(
"""\
+------+
| a |
|------|
| True |
+------+
SELECT 1"""
)
@dbtest
def test_expanded_slash_G(executor, pgspecial):
# Tests whether we reset the expanded output after a \G.
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
results = run(executor, """select * from test \G""", pgspecial=pgspecial)
assert pgspecial.expanded_output == False
@dbtest
def test_schemata_table_views_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
run(executor, "create view d as select 1 as e")
run(executor, "create schema schema1")
run(executor, "create table schema1.c (w text DEFAULT 'meow')")
run(executor, "create schema schema2")
# schemata
# don't enforce all members of the schemas since they may include postgres
# temporary schemas
assert set(executor.schemata()) >= set(
["public", "pg_catalog", "information_schema", "schema1", "schema2"]
)
assert executor.search_path() == ["pg_catalog", "public"]
# tables
assert set(executor.tables()) >= set(
[("public", "a"), ("public", "b"), ("schema1", "c")]
)
assert set(executor.table_columns()) >= set(
[
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
]
)
# views
assert set(executor.views()) >= set([("public", "d")])
assert set(executor.view_columns()) >= set(
[("public", "d", "e", "integer", False, None)]
)
@dbtest
def test_foreign_key_query(executor):
run(executor, "create schema schema1")
run(executor, "create schema schema2")
run(executor, "create table schema1.parent(parentid int PRIMARY KEY)")
run(
executor,
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
assert set(executor.foreignkeys()) >= set(
[("schema1", "parent", "parentid", "schema2", "child", "motherid")]
)
@dbtest
def test_functions_query(executor):
run(
executor,
"""create function func1() returns int
language sql as $$select 1$$""",
)
run(executor, "create schema schema1")
run(
executor,
"""create function schema1.func2() returns int
language sql as $$select 2$$""",
)
run(
executor,
"""create function func3()
returns table(x int, y int) language sql
as $$select 1, 2 from generate_series(1,5)$$;""",
)
run(
executor,
"""create function func4(x int) returns setof int language sql
as $$select generate_series(1,5)$$;""",
)
funcs = set(executor.functions())
assert funcs >= set(
[
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
]
)
@dbtest
def test_datatypes_query(executor):
run(executor, "create type foo AS (a int, b text)")
types = list(executor.datatypes())
assert types == [("public", "foo")]
@dbtest
def test_database_list(executor):
databases = executor.databases()
assert "_test_db" in databases
@dbtest
def test_invalid_syntax(executor, exception_formatter):
result = run(executor, "invalid syntax!", exception_formatter=exception_formatter)
assert 'syntax error at or near "invalid"' in result[0]
@dbtest
def test_invalid_column_name(executor, exception_formatter):
result = run(
executor, "select invalid command", exception_formatter=exception_formatter
)
assert 'column "invalid" does not exist' in result[0]
@pytest.fixture(params=[True, False])
def expanded(request):
return request.param
@dbtest
def test_unicode_support_in_output(executor, expanded):
run(executor, "create table unicodechars(t text)")
run(executor, "insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
assert "é" in run(
executor, "select * from unicodechars", join=True, expanded=expanded
)
@dbtest
def test_not_is_special(executor, pgspecial):
"""is_special is set to false for database queries."""
query = "select 1"
result = list(executor.run(query, pgspecial=pgspecial))
success, is_special = result[0][5:]
assert success == True
assert is_special == False
@dbtest
def test_execute_from_file_no_arg(executor, pgspecial):
"""\i without a filename returns an error."""
result = list(executor.run("\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert "missing required argument" in status
assert success == False
assert is_special == True
@dbtest
@patch("pgcli.main.os")
def test_execute_from_file_io_error(os, executor, pgspecial):
"""\i with an io_error returns an error."""
# Inject an IOError.
os.path.expanduser.side_effect = IOError("test")
# Check the result.
result = list(executor.run("\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == "test"
assert success == False
assert is_special == True
@dbtest
def test_multiple_queries_same_line(executor):
result = run(executor, "select 'foo'; select 'bar'")
assert len(result) == 12 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
assert "bar" in result[9]
@dbtest
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
result = run(executor, "select 'foo'; \d", pgspecial=pgspecial)
assert len(result) == 11 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
# This is a lame check. :(
assert "Schema" in result[7]
@dbtest
def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter):
result = run(
executor,
"select 'fooé'; invalid syntax é",
exception_formatter=exception_formatter,
)
assert "fooé" in result[3]
assert 'syntax error at or near "invalid"' in result[-1]
@pytest.fixture
def pgspecial():
return PGCli().pgspecial
@dbtest
def test_special_command_help(executor, pgspecial):
result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|")
assert "Command" in result[1]
assert "Description" in result[2]
@dbtest
def test_bytea_field_support_in_output(executor):
run(executor, "create table binarydata(c bytea)")
run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))")
assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True)
@dbtest
def test_unicode_support_in_unknown_type(executor):
assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True)
@dbtest
def test_unicode_support_in_enum_type(executor):
run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')")
run(executor, "CREATE TABLE person (name TEXT, current_mood mood)")
run(executor, "INSERT INTO person VALUES ('Moe', '日本語')")
assert "日本語" in run(executor, "SELECT * FROM person", join=True)
@requires_json
def test_json_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsontest(d json)")
run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""")
result = run(
executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded
)
assert '{"name": "Éowyn"}' in result
@requires_jsonb
def test_jsonb_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsonbtest(d jsonb)")
run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""")
result = run(
executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded
)
assert '{"name": "Éowyn"}' in result
@dbtest
def test_date_time_types(executor):
run(executor, "SET TIME ZONE UTC")
assert (
run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3]
== "| 00:00:00 |"
)
assert (
run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split(
"\n"
)[3]
== "| 00:00:00+14:59 |"
)
assert (
run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[
3
]
== "| 4713-01-01 BC |"
)
assert (
run(
executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True
).split("\n")[3]
== "| 4713-01-01 00:00:00 BC |"
)
assert (
run(
executor,
"SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))",
join=True,
).split("\n")[3]
== "| 4713-01-01 00:00:00+00 BC |"
)
assert (
run(
executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True
).split("\n")[3]
== "| -123456789 days, 12:23:56 |"
)
@dbtest
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
def test_large_numbers_render_directly(executor, value):
run(executor, "create table numbertest(a numeric)")
run(executor, "insert into numbertest (a) values ({0})".format(value))
assert value in run(executor, "select * from numbertest", join=True)
@dbtest
@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"])
@pytest.mark.parametrize("verbose", ["", "+"])
@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"])
def test_describe_special(executor, command, verbose, pattern, pgspecial):
# We don't have any tests for the output of any of the special commands,
# but we can at least make sure they run without error
sql = r"\{command}{verbose} {pattern}".format(**locals())
list(executor.run(sql, pgspecial=pgspecial))
@dbtest
@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"])
def test_raises_with_no_formatter(executor, sql):
with pytest.raises(psycopg2.ProgrammingError):
list(executor.run(sql))
@dbtest
def test_on_error_resume(executor, exception_formatter):
sql = "select 1; error; select 1;"
result = list(
executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)
)
assert len(result) == 3
@dbtest
def test_on_error_stop(executor, exception_formatter):
sql = "select 1; error; select 1;"
result = list(
executor.run(
sql, on_error_resume=False, exception_formatter=exception_formatter
)
)
assert len(result) == 2
# @dbtest
# def test_unicode_notices(executor):
# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;"
# result = list(executor.run(sql))
# assert result[0][0] == u'NOTICE: 有人更改\n'
@dbtest
def test_nonexistent_function_definition(executor):
with pytest.raises(RuntimeError):
result = executor.view_definition("there_is_no_such_function")
@dbtest
def test_function_definition(executor):
run(
executor,
"""
CREATE OR REPLACE FUNCTION public.the_number_three()
RETURNS int
LANGUAGE sql
AS $function$
select 3;
$function$
""",
)
result = executor.function_definition("the_number_three")
@dbtest
def test_view_definition(executor):
run(executor, "create table tbl1 (a text, b numeric)")
run(executor, "create view vw1 AS SELECT * FROM tbl1")
run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1")
result = executor.view_definition("vw1")
assert "FROM tbl1" in result
# import pytest; pytest.set_trace()
result = executor.view_definition("mvw1")
assert "MATERIALIZED VIEW" in result
@dbtest
def test_nonexistent_view_definition(executor):
with pytest.raises(RuntimeError):
result = executor.view_definition("there_is_no_such_view")
with pytest.raises(RuntimeError):
result = executor.view_definition("mvw1")
@dbtest
def test_short_host(executor):
with patch.object(executor, "host", "localhost"):
assert executor.short_host == "localhost"
with patch.object(executor, "host", "localhost.example.org"):
assert executor.short_host == "localhost"
with patch.object(
executor, "host", "localhost1.example.org,localhost2.example.org"
):
assert executor.short_host == "localhost1"
class BrokenConnection(object):
"""Mock a connection that failed."""
def cursor(self):
raise psycopg2.InterfaceError("I'm broken!")
@dbtest
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
pgspecial = PGSpecial()
pgspecial.register(
quit_handler,
"\\q",
"\\q",
"Quit pgcli.",
arg_type=NO_QUERY,
case_sensitive=True,
aliases=(":q",),
)
with patch.object(executor, "conn", BrokenConnection()):
# we should be able to quit the app, even without active connection
run(executor, "\\q", pgspecial=pgspecial)
quit_handler.assert_called_once()
# an exception should be raised when running a query without active connection
with pytest.raises(psycopg2.InterfaceError):
run(executor, "select 1", pgspecial=pgspecial)

78
tests/test_pgspecial.py Normal file
View file

@ -0,0 +1,78 @@
import pytest
from pgcli.packages.sqlcompletion import (
suggest_type,
Special,
Database,
Schema,
Table,
View,
Function,
Datatype,
)
def test_slash_suggests_special():
suggestions = suggest_type("\\", "\\")
assert set(suggestions) == set([Special()])
def test_slash_d_suggests_special():
suggestions = suggest_type("\\d", "\\d")
assert set(suggestions) == set([Special()])
def test_dn_suggests_schemata():
suggestions = suggest_type("\\dn ", "\\dn ")
assert suggestions == (Schema(),)
suggestions = suggest_type("\\dn xxx", "\\dn xxx")
assert suggestions == (Schema(),)
def test_d_suggests_tables_views_and_schemas():
suggestions = suggest_type("\d ", "\d ")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type("\d xxx", "\d xxx")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
def test_d_dot_suggests_schema_qualified_tables_or_views():
suggestions = suggest_type("\d myschema.", "\d myschema.")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
def test_df_suggests_schema_or_function():
suggestions = suggest_type("\\df xxx", "\\df xxx")
assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
assert suggestions == (Function(schema="myschema", usage="special"),)
def test_leading_whitespace_ok():
cmd = "\\dn "
whitespace = " "
suggestions = suggest_type(whitespace + cmd, whitespace + cmd)
assert suggestions == suggest_type(cmd, cmd)
def test_dT_suggests_schema_or_datatypes():
text = "\\dT "
suggestions = suggest_type(text, text)
assert set(suggestions) == set([Schema(), Datatype(schema=None)])
def test_schema_qualified_dT_suggests_datatypes():
text = "\\dT foo."
suggestions = suggest_type(text, text)
assert suggestions == (Datatype(schema="foo"),)
@pytest.mark.parametrize("command", ["\\c ", "\\connect "])
def test_c_suggests_databases(command):
suggestions = suggest_type(command, command)
assert suggestions == (Database(),)

38
tests/test_plan.wiki Normal file
View file

@ -0,0 +1,38 @@
= Gross Checks =
* [ ] Check connecting to a local database.
* [ ] Check connecting to a remote database.
* [ ] Check connecting to a database with a user/password.
* [ ] Check connecting to a non-existent database.
* [ ] Test changing the database.
== PGExecute ==
* [ ] Test successful execution given a cursor.
* [ ] Test unsuccessful execution with a syntax error.
* [ ] Test a series of executions with the same cursor without failure.
* [ ] Test a series of executions with the same cursor with failure.
* [ ] Test passing in a special command.
== Naive Autocompletion ==
* [ ] Input empty string, ask for completions - Everything.
* [ ] Input partial prefix, ask for completions - Stars with prefix.
* [ ] Input fully autocompleted string, ask for completions - Only full match
* [ ] Input non-existent prefix, ask for completions - nothing
* [ ] Input lowercase prefix - case insensitive completions
== Smart Autocompletion ==
* [ ] Input empty string and check if only keywords are returned.
* [ ] Input SELECT prefix and check if only columns are returned.
* [ ] Input SELECT blah - only keywords are returned.
* [ ] Input SELECT * FROM - Table names only
== PGSpecial ==
* [ ] Test \d
* [ ] Test \d tablename
* [ ] Test \d tablena*
* [ ] Test \d non-existent-tablename
* [ ] Test \d index
* [ ] Test \d sequence
* [ ] Test \d view
== Exceptionals ==
* [ ] Test the 'use' command to change db.

View file

@ -0,0 +1,20 @@
from pgcli.packages.prioritization import PrevalenceCounter
def test_prevalence_counter():
counter = PrevalenceCounter()
sql = """SELECT * FROM foo WHERE bar GROUP BY baz;
select * from foo;
SELECT * FROM foo WHERE bar GROUP
BY baz"""
counter.update(sql)
keywords = ["SELECT", "FROM", "GROUP BY"]
expected = [3, 3, 2]
kw_counts = [counter.keyword_count(x) for x in keywords]
assert kw_counts == expected
assert counter.keyword_count("NOSUCHKEYWORD") == 0
names = ["foo", "bar", "baz"]
name_counts = [counter.name_count(x) for x in names]
assert name_counts == [3, 2, 2]

View file

@ -0,0 +1,10 @@
import click
from pgcli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = "drop database foo;"
assert confirm_destructive_query(sql) is None

79
tests/test_rowlimit.py Normal file
View file

@ -0,0 +1,79 @@
import pytest
from mock import Mock
from pgcli.main import PGCli
# We need this fixtures beacause we need PGCli object to be created
# after test collection so it has config loaded from temp directory
@pytest.fixture(scope="module")
def default_pgcli_obj():
return PGCli()
@pytest.fixture(scope="module")
def DEFAULT(default_pgcli_obj):
return default_pgcli_obj.row_limit
@pytest.fixture(scope="module")
def LIMIT(DEFAULT):
return DEFAULT + 1000
@pytest.fixture(scope="module")
def over_default(DEFAULT):
over_default_cursor = Mock()
over_default_cursor.configure_mock(rowcount=DEFAULT + 10)
return over_default_cursor
@pytest.fixture(scope="module")
def over_limit(LIMIT):
over_limit_cursor = Mock()
over_limit_cursor.configure_mock(rowcount=LIMIT + 10)
return over_limit_cursor
@pytest.fixture(scope="module")
def low_count():
low_count_cursor = Mock()
low_count_cursor.configure_mock(rowcount=1)
return low_count_cursor
def test_row_limit_with_LIMIT_clause(LIMIT, over_limit):
cli = PGCli(row_limit=LIMIT)
stmt = "SELECT * FROM students LIMIT 1000"
result = cli._should_limit_output(stmt, over_limit)
assert result is False
cli = PGCli(row_limit=0)
result = cli._should_limit_output(stmt, over_limit)
assert result is False
def test_row_limit_without_LIMIT_clause(LIMIT, over_limit):
cli = PGCli(row_limit=LIMIT)
stmt = "SELECT * FROM students"
result = cli._should_limit_output(stmt, over_limit)
assert result is True
cli = PGCli(row_limit=0)
result = cli._should_limit_output(stmt, over_limit)
assert result is False
def test_row_limit_on_non_select(over_limit):
cli = PGCli()
stmt = "UPDATE students SET name='Boby'"
result = cli._should_limit_output(stmt, over_limit)
assert result is False
cli = PGCli(row_limit=0)
result = cli._should_limit_output(stmt, over_limit)
assert result is False

View file

@ -0,0 +1,727 @@
import itertools
from metadata import (
MetaData,
alias,
name_join,
fk_join,
join,
schema,
table,
function,
wildcard_expansion,
column,
get_result,
result_set,
qual,
no_qual,
parametrize,
)
from utils import completions_to_set
metadata = {
"tables": {
"public": {
"users": ["id", "email", "first_name", "last_name"],
"orders": ["id", "ordered_date", "status", "datestamp"],
"select": ["id", "localtime", "ABC"],
},
"custom": {
"users": ["id", "phone_number"],
"Users": ["userid", "username"],
"products": ["id", "product_name", "price"],
"shipments": ["id", "address", "user_id"],
},
"Custom": {"projects": ["projectid", "name"]},
"blog": {
"entries": ["entryid", "entrytitle", "entrytext"],
"tags": ["tagid", "name"],
"entrytags": ["entryid", "tagid"],
"entacclog": ["entryid", "username", "datestamp"],
},
},
"functions": {
"public": [
["func1", [], [], [], "", False, False, False, False],
["func2", [], [], [], "", False, False, False, False],
],
"custom": [
["func3", [], [], [], "", False, False, False, False],
[
"set_returning_func",
["x"],
["integer"],
["o"],
"integer",
False,
False,
True,
False,
],
],
"Custom": [["func4", [], [], [], "", False, False, False, False]],
"blog": [
[
"extract_entry_symbols",
["_entryid", "symbol"],
["integer", "text"],
["i", "o"],
"",
False,
False,
True,
False,
],
[
"enter_entry",
["_title", "_text", "entryid"],
["text", "text", "integer"],
["i", "i", "o"],
"",
False,
False,
False,
False,
],
],
},
"datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]},
"foreignkeys": {
"custom": [("public", "users", "id", "custom", "shipments", "user_id")],
"blog": [
("blog", "entries", "entryid", "blog", "entacclog", "entryid"),
("blog", "entries", "entryid", "blog", "entrytags", "entryid"),
("blog", "tags", "tagid", "blog", "entrytags", "tagid"),
],
},
"defaults": {
"public": {
("orders", "id"): "nextval('orders_id_seq'::regclass)",
("orders", "datestamp"): "now()",
("orders", "status"): "'PENDING'::text",
}
},
}
testdata = MetaData(metadata)
cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')]
casing = (
"SELECT",
"Orders",
"User_Emails",
"CUSTOM",
"Func1",
"Entries",
"Tags",
"EntryTags",
"EntAccLog",
"EntryID",
"EntryTitle",
"EntryText",
)
completers = testdata.get_completers(casing)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize("table", ["users", '"users"'])
def test_suggested_column_names_from_shadowed_visible_table(completer, table):
result = get_result(completer, "SELECT FROM " + table, len("SELECT "))
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("users")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize(
"text",
[
"SELECT from custom.users",
"WITH users as (SELECT 1 AS foo) SELECT from custom.users",
],
)
def test_suggested_column_names_from_qualified_shadowed_table(completer, text):
result = get_result(completer, text, position=text.find(" ") + 1)
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("users", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"])
def test_suggested_column_names_from_cte(completer, text):
result = completions_to_set(get_result(completer, text, text.find(" ") + 1))
assert result == completions_to_set(
[column("foo")] + testdata.functions_and_keywords()
)
@parametrize("completer", completers(casing=False))
@parametrize(
"text",
[
"SELECT * FROM users JOIN custom.shipments ON ",
"""SELECT *
FROM public.users
JOIN custom.shipments ON """,
],
)
def test_suggested_join_conditions(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[
alias("users"),
alias("shipments"),
name_join("shipments.id = users.id"),
fk_join("shipments.user_id = users.id"),
]
)
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
@parametrize(
("query", "tbl"),
itertools.product(
(
"SELECT * FROM public.{0} RIGHT OUTER JOIN ",
"""SELECT *
FROM {0}
JOIN """,
),
("users", '"users"', "Users"),
),
)
def test_suggested_joins(completer, query, tbl):
result = get_result(completer, query.format(tbl))
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
+ [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))]
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_column_names_from_schema_qualifed_table(completer):
result = get_result(completer, "SELECT from custom.products", len("SELECT "))
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("products", "custom")
)
@parametrize(
"text",
[
"INSERT INTO orders(",
"INSERT INTO orders (",
"INSERT INTO public.orders(",
"INSERT INTO public.orders (",
],
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_columns_with_insert(completer, text):
assert completions_to_set(get_result(completer, text)) == completions_to_set(
testdata.columns("orders")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_column_names_in_function(completer):
result = get_result(
completer, "SELECT MAX( from custom.products", len("SELECT MAX(")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("products", "custom")
)
@parametrize("completer", completers(casing=False, aliasing=False))
@parametrize(
"text",
["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'],
)
@parametrize("use_leading_double_quote", [False, True])
def test_suggested_table_names_with_schema_dot(
completer, text, use_leading_double_quote
):
if use_leading_double_quote:
text += '"'
start_position = -1
else:
start_position = 0
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.from_clause_items("custom", start_position)
)
@parametrize("completer", completers(casing=False, aliasing=False))
@parametrize("text", ['SELECT * FROM "Custom".'])
@parametrize("use_leading_double_quote", [False, True])
def test_suggested_table_names_with_schema_dot2(
completer, text, use_leading_double_quote
):
if use_leading_double_quote:
text += '"'
start_position = -1
else:
start_position = 0
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.from_clause_items("Custom", start_position)
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_column_names_with_qualified_alias(completer):
result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p."))
assert completions_to_set(result) == completions_to_set(
testdata.columns("products", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_multiple_column_names(completer):
result = get_result(
completer, "SELECT id, from custom.products", len("SELECT id, ")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns_functions_and_keywords("products", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_multiple_column_names_with_alias(completer):
result = get_result(
completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns("products", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ",
"SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id",
],
)
def test_suggestions_after_on(completer, text):
position = len(
"SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON "
)
result = get_result(completer, text, position)
assert completions_to_set(result) == completions_to_set(
[
alias("x"),
alias("y"),
name_join("y.price = x.price"),
name_join("y.product_name = x.product_name"),
name_join("y.id = x.id"),
]
)
@parametrize("completer", completers())
def test_suggested_aliases_after_on_right_side(completer):
text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = "
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")])
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
def test_table_names_after_from(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_schema_qualified_function_name(completer):
text = "SELECT custom.func"
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[
function("func3()", -len("func")),
function("set_returning_func()", -len("func")),
]
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"SELECT 1::custom.",
"CREATE TABLE foo (bar custom.",
"CREATE FUNCTION foo (bar INT, baz custom.",
"ALTER TABLE foo ALTER COLUMN bar TYPE custom.",
],
)
def test_schema_qualified_type_name(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(testdata.types("custom"))
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggest_columns_from_aliased_set_returning_function(completer):
result = get_result(
completer, "select f. from custom.set_returning_func() f", len("select f.")
)
assert completions_to_set(result) == completions_to_set(
testdata.columns("set_returning_func", "custom", "functions")
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize(
"text",
[
"SELECT * FROM custom.set_returning_func()",
"SELECT * FROM Custom.set_returning_func()",
"SELECT * FROM Custom.Set_Returning_Func()",
],
)
def test_wildcard_column_expansion_with_function(completer, text):
position = len("SELECT *")
completions = get_result(completer, text, position)
col_list = "x"
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
def test_wildcard_column_expansion_with_alias_qualifier(completer):
text = "SELECT p.* FROM custom.products p"
position = len("SELECT p.*")
completions = get_result(completer, text, position)
col_list = "id, p.product_name, p.price"
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"""
SELECT count(1) FROM users;
CREATE FUNCTION foo(custom.products _products) returns custom.shipments
LANGUAGE SQL
AS $foo$
SELECT 1 FROM custom.shipments;
INSERT INTO public.orders(*) values(-1, now(), 'preliminary');
SELECT 2 FROM custom.users;
$foo$;
SELECT count(1) FROM custom.shipments;
""",
"INSERT INTO public.orders(*",
"INSERT INTO public.Orders(*",
"INSERT INTO public.orders (*",
"INSERT INTO public.Orders (*",
"INSERT INTO orders(*",
"INSERT INTO Orders(*",
"INSERT INTO orders (*",
"INSERT INTO Orders (*",
"INSERT INTO public.orders(*)",
"INSERT INTO public.Orders(*)",
"INSERT INTO public.orders (*)",
"INSERT INTO public.Orders (*)",
"INSERT INTO orders(*)",
"INSERT INTO Orders(*)",
"INSERT INTO orders (*)",
"INSERT INTO Orders (*)",
],
)
def test_wildcard_column_expansion_with_insert(completer, text):
position = text.index("*") + 1
completions = get_result(completer, text, position)
expected = [wildcard_expansion("ordered_date, status")]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
def test_wildcard_column_expansion_with_table_qualifier(completer):
text = 'SELECT "select".* FROM public."select"'
position = len('SELECT "select".*')
completions = get_result(completer, text, position)
col_list = 'id, "select"."localtime", "select"."ABC"'
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False, qualify=qual))
def test_wildcard_column_expansion_with_two_tables(completer):
text = 'SELECT * FROM public."select" JOIN custom.users ON true'
position = len("SELECT *")
completions = get_result(completer, text, position)
cols = (
'"select".id, "select"."localtime", "select"."ABC", '
"users.id, users.phone_number"
)
expected = [wildcard_expansion(cols)]
assert completions == expected
@parametrize("completer", completers(filtr=True, casing=False))
def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
position = len('SELECT "select".*')
completions = get_result(completer, text, position)
col_list = 'id, "select"."localtime", "select"."ABC"'
expected = [wildcard_expansion(col_list)]
assert expected == completions
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
[
"SELECT U. FROM custom.Users U",
"SELECT U. FROM custom.USERS U",
"SELECT U. FROM custom.users U",
'SELECT U. FROM "custom".Users U',
'SELECT U. FROM "custom".USERS U',
'SELECT U. FROM "custom".users U',
],
)
def test_suggest_columns_from_unquoted_table(completer, text):
position = len("SELECT U.")
result = get_result(completer, text, position)
assert completions_to_set(result) == completions_to_set(
testdata.columns("users", "custom")
)
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U']
)
def test_suggest_columns_from_quoted_table(completer, text):
position = len("SELECT U.")
result = get_result(completer, text, position)
assert completions_to_set(result) == completions_to_set(
testdata.columns("Users", "custom")
)
texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "]
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
@parametrize("text", texts)
def test_schema_or_visible_table_completion(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
)
@parametrize("completer", completers(aliasing=True, casing=False, filtr=True))
@parametrize("text", texts)
def test_table_aliases(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
testdata.schemas()
+ [
table("users u"),
table("orders o" if text == "SELECT * FROM " else "orders o2"),
table('"select" s'),
function("func1() f"),
function("func2() f"),
]
)
@parametrize("completer", completers(aliasing=True, casing=True, filtr=True))
@parametrize("text", texts)
def test_aliases_with_casing(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
cased_schemas
+ [
table("users u"),
table("Orders O" if text == "SELECT * FROM " else "Orders O2"),
table('"select" s'),
function("Func1() F"),
function("func2() f"),
]
)
@parametrize("completer", completers(aliasing=False, casing=True, filtr=True))
@parametrize("text", texts)
def test_table_casing(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
cased_schemas
+ [
table("users"),
table("Orders"),
table('"select"'),
function("Func1()"),
function("func2()"),
]
)
@parametrize("completer", completers(aliasing=False, casing=True))
def test_alias_search_without_aliases2(completer):
text = "SELECT * FROM blog.et"
result = get_result(completer, text)
assert result[0] == table("EntryTags", -2)
@parametrize("completer", completers(aliasing=False, casing=True))
def test_alias_search_without_aliases1(completer):
text = "SELECT * FROM blog.e"
result = get_result(completer, text)
assert result[0] == table("Entries", -1)
@parametrize("completer", completers(aliasing=True, casing=True))
def test_alias_search_with_aliases2(completer):
text = "SELECT * FROM blog.et"
result = get_result(completer, text)
assert result[0] == table("EntryTags ET", -2)
@parametrize("completer", completers(aliasing=True, casing=True))
def test_alias_search_with_aliases1(completer):
text = "SELECT * FROM blog.e"
result = get_result(completer, text)
assert result[0] == table("Entries E", -1)
@parametrize("completer", completers(aliasing=True, casing=True))
def test_join_alias_search_with_aliases1(completer):
text = "SELECT * FROM blog.Entries E JOIN blog.e"
result = get_result(completer, text)
assert result[:2] == [
table("Entries E2", -1),
join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1),
]
@parametrize("completer", completers(aliasing=False, casing=True))
def test_join_alias_search_without_aliases1(completer):
text = "SELECT * FROM blog.Entries JOIN blog.e"
result = get_result(completer, text)
assert result[:2] == [
table("Entries", -1),
join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1),
]
@parametrize("completer", completers(aliasing=True, casing=True))
def test_join_alias_search_with_aliases2(completer):
text = "SELECT * FROM blog.Entries E JOIN blog.et"
result = get_result(completer, text)
assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2)
@parametrize("completer", completers(aliasing=False, casing=True))
def test_join_alias_search_without_aliases2(completer):
text = "SELECT * FROM blog.Entries JOIN blog.et"
result = get_result(completer, text)
assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2)
@parametrize("completer", completers())
def test_function_alias_search_without_aliases(completer):
text = "SELECT blog.ees"
result = get_result(completer, text)
first = result[0]
assert first.start_position == -3
assert first.text == "extract_entry_symbols()"
assert first.display_text == "extract_entry_symbols(_entryid)"
@parametrize("completer", completers())
def test_function_alias_search_with_aliases(completer):
text = "SELECT blog.ee"
result = get_result(completer, text)
first = result[0]
assert first.start_position == -2
assert first.text == "enter_entry(_title := , _text := )"
assert first.display_text == "enter_entry(_title, _text)"
@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual))
def test_column_alias_search(completer):
result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et"))
cols = ("EntryText", "EntryTitle", "EntryID")
assert result[:3] == [column(c, -2) for c in cols]
@parametrize("completer", completers(casing=True))
def test_column_alias_search_qualified(completer):
result = get_result(
completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei")
)
cols = ("EntryID", "EntryTitle")
assert result[:3] == [column(c, -2) for c in cols]
@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
def test_schema_object_order(completer):
result = get_result(completer, "SELECT * FROM u")
assert result[:3] == [
table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")
]
@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
def test_all_schema_objects(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
[table(x) for x in ("orders", '"select"', "custom.shipments")]
+ [function(x + "()") for x in ("func2",)]
)
@parametrize("completer", completers(filtr=False, aliasing=False, casing=True))
def test_all_schema_objects_with_casing(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
[table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")]
+ [function(x + "()") for x in ("func2",)]
)
@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
def test_all_schema_objects_with_aliases(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
[table(x) for x in ("orders o", '"select" s', "custom.shipments s")]
+ [function(x) for x in ("func2() f",)]
)
@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
def test_set_schema(completer):
text = "SET SCHEMA "
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
[schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")]
)

Some files were not shown because too many files have changed in this diff Show more