Merging upstream version 3.4.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
89e6c74b99
commit
3826c78c85
12 changed files with 328 additions and 40 deletions
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
|
@ -38,7 +38,7 @@ jobs:
|
||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
run: |
|
run: |
|
||||||
pip install -U pip setuptools
|
pip install -U pip setuptools
|
||||||
pip install --no-cache-dir .
|
pip install --no-cache-dir ".[sshtunnel]"
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
pip install keyrings.alt>=3.1
|
pip install keyrings.alt>=3.1
|
||||||
|
|
||||||
|
|
3
AUTHORS
3
AUTHORS
|
@ -116,8 +116,9 @@ Contributors:
|
||||||
* Kevin Marsh (kevinmarsh)
|
* Kevin Marsh (kevinmarsh)
|
||||||
* Eero Ruohola (ruohola)
|
* Eero Ruohola (ruohola)
|
||||||
* Miroslav Šedivý (eumiro)
|
* Miroslav Šedivý (eumiro)
|
||||||
* Eric R Young (ERYoung11)
|
* Eric R Young (ERYoung11)
|
||||||
* Paweł Sacawa (psacawa)
|
* Paweł Sacawa (psacawa)
|
||||||
|
* Bruno Inec (sweenu)
|
||||||
|
|
||||||
Creator:
|
Creator:
|
||||||
--------
|
--------
|
||||||
|
|
17
RELEASES.md
Normal file
17
RELEASES.md
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
Releasing pgcli
|
||||||
|
---------------
|
||||||
|
|
||||||
|
We have a script called `release.py` to automate the process.
|
||||||
|
|
||||||
|
The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps.
|
||||||
|
|
||||||
|
```
|
||||||
|
> python release.py --help
|
||||||
|
Usage: release.py [options]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
-c, --confirm-steps Confirm every step. If the step is not confirmed, it
|
||||||
|
will be skipped.
|
||||||
|
-d, --dry-run Print out, but not actually run any steps.
|
||||||
|
```
|
|
@ -1,3 +1,16 @@
|
||||||
|
TBD
|
||||||
|
===
|
||||||
|
|
||||||
|
* [List new changes here].
|
||||||
|
|
||||||
|
3.4.0 (2022/02/21)
|
||||||
|
==================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
---------
|
||||||
|
|
||||||
|
* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)).
|
||||||
|
|
||||||
3.3.1 (2022/01/18)
|
3.3.1 (2022/01/18)
|
||||||
==================
|
==================
|
||||||
|
|
||||||
|
@ -7,7 +20,6 @@ Bug fixes:
|
||||||
* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307.
|
* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307.
|
||||||
* Upgrade cli_helpers to 2.2.1
|
* Upgrade cli_helpers to 2.2.1
|
||||||
|
|
||||||
|
|
||||||
3.3.0 (2022/01/11)
|
3.3.0 (2022/01/11)
|
||||||
==================
|
==================
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = "3.3.1"
|
__version__ = "3.4.0"
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import platform
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
from os.path import expanduser
|
|
||||||
|
|
||||||
from configobj import ConfigObj, ParseError
|
from configobj import ConfigObj, ParseError
|
||||||
from pgspecial.namedqueries import NamedQueries
|
from pgspecial.namedqueries import NamedQueries
|
||||||
|
@ -8,6 +7,7 @@ from .config import skip_initial_comment
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
|
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
|
||||||
|
|
||||||
|
import atexit
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
@ -21,6 +21,8 @@ import datetime as dt
|
||||||
import itertools
|
import itertools
|
||||||
import platform
|
import platform
|
||||||
from time import time, sleep
|
from time import time, sleep
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
keyring = None # keyring will be loaded later
|
keyring = None # keyring will be loaded later
|
||||||
|
|
||||||
|
@ -78,12 +80,21 @@ except ImportError:
|
||||||
|
|
||||||
from getpass import getuser
|
from getpass import getuser
|
||||||
from psycopg2 import OperationalError, InterfaceError
|
from psycopg2 import OperationalError, InterfaceError
|
||||||
|
from psycopg2.extensions import make_dsn, parse_dsn
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sshtunnel
|
||||||
|
|
||||||
|
SSH_TUNNEL_SUPPORT = True
|
||||||
|
except ImportError:
|
||||||
|
SSH_TUNNEL_SUPPORT = False
|
||||||
|
|
||||||
|
|
||||||
# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
|
# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
|
||||||
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
|
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
|
||||||
DEFAULT_MAX_FIELD_WIDTH = 500
|
DEFAULT_MAX_FIELD_WIDTH = 500
|
||||||
|
@ -168,8 +179,8 @@ class PGCli:
|
||||||
prompt_dsn=None,
|
prompt_dsn=None,
|
||||||
auto_vertical_output=False,
|
auto_vertical_output=False,
|
||||||
warn=None,
|
warn=None,
|
||||||
|
ssh_tunnel_url: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.force_passwd_prompt = force_passwd_prompt
|
self.force_passwd_prompt = force_passwd_prompt
|
||||||
self.never_passwd_prompt = never_passwd_prompt
|
self.never_passwd_prompt = never_passwd_prompt
|
||||||
self.pgexecute = pgexecute
|
self.pgexecute = pgexecute
|
||||||
|
@ -275,6 +286,10 @@ class PGCli:
|
||||||
|
|
||||||
self.prompt_app = None
|
self.prompt_app = None
|
||||||
|
|
||||||
|
self.ssh_tunnel_config = c.get("ssh tunnels")
|
||||||
|
self.ssh_tunnel_url = ssh_tunnel_url
|
||||||
|
self.ssh_tunnel = None
|
||||||
|
|
||||||
def quit(self):
|
def quit(self):
|
||||||
raise PgCliQuitError
|
raise PgCliQuitError
|
||||||
|
|
||||||
|
@ -585,6 +600,56 @@ class PGCli:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if dsn:
|
||||||
|
parsed_dsn = parse_dsn(dsn)
|
||||||
|
if "host" in parsed_dsn:
|
||||||
|
host = parsed_dsn["host"]
|
||||||
|
if "port" in parsed_dsn:
|
||||||
|
port = parsed_dsn["port"]
|
||||||
|
|
||||||
|
if self.ssh_tunnel_config and not self.ssh_tunnel_url:
|
||||||
|
for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
|
||||||
|
if re.search(db_host_regex, host):
|
||||||
|
self.ssh_tunnel_url = tunnel_url
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.ssh_tunnel_url:
|
||||||
|
# We add the protocol as urlparse doesn't find it by itself
|
||||||
|
if "://" not in self.ssh_tunnel_url:
|
||||||
|
self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
|
||||||
|
|
||||||
|
tunnel_info = urlparse(self.ssh_tunnel_url)
|
||||||
|
params = {
|
||||||
|
"local_bind_address": ("127.0.0.1",),
|
||||||
|
"remote_bind_address": (host, int(port or 5432)),
|
||||||
|
"ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
|
||||||
|
"logger": self.logger,
|
||||||
|
}
|
||||||
|
if tunnel_info.username:
|
||||||
|
params["ssh_username"] = tunnel_info.username
|
||||||
|
if tunnel_info.password:
|
||||||
|
params["ssh_password"] = tunnel_info.password
|
||||||
|
|
||||||
|
# Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
|
||||||
|
# We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
|
||||||
|
logger_handlers = self.logger.handlers.copy()
|
||||||
|
try:
|
||||||
|
self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
|
||||||
|
self.ssh_tunnel.start()
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.handlers = logger_handlers
|
||||||
|
self.logger.error("traceback: %r", traceback.format_exc())
|
||||||
|
click.secho(str(e), err=True, fg="red")
|
||||||
|
exit(1)
|
||||||
|
self.logger.handlers = logger_handlers
|
||||||
|
|
||||||
|
atexit.register(self.ssh_tunnel.stop)
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port = self.ssh_tunnel.local_bind_ports[0]
|
||||||
|
|
||||||
|
if dsn:
|
||||||
|
dsn = make_dsn(dsn, host=host, port=port)
|
||||||
|
|
||||||
# Attempt to connect to the database.
|
# Attempt to connect to the database.
|
||||||
# Note that passwd may be empty on the first attempt. If connection
|
# Note that passwd may be empty on the first attempt. If connection
|
||||||
# fails because of a missing or incorrect password, but we're allowed to
|
# fails because of a missing or incorrect password, but we're allowed to
|
||||||
|
@ -1222,7 +1287,7 @@ class PGCli:
|
||||||
"--list",
|
"--list",
|
||||||
"list_databases",
|
"list_databases",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="list " "available databases, then exit.",
|
help="list available databases, then exit.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--auto-vertical-output",
|
"--auto-vertical-output",
|
||||||
|
@ -1235,6 +1300,11 @@ class PGCli:
|
||||||
type=click.Choice(["all", "moderate", "off"]),
|
type=click.Choice(["all", "moderate", "off"]),
|
||||||
help="Warn before running a destructive query.",
|
help="Warn before running a destructive query.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--ssh-tunnel",
|
||||||
|
default=None,
|
||||||
|
help="Open an SSH tunnel to the given address and connect to the database from it.",
|
||||||
|
)
|
||||||
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
|
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
|
||||||
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
|
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
|
||||||
def cli(
|
def cli(
|
||||||
|
@ -1258,6 +1328,7 @@ def cli(
|
||||||
auto_vertical_output,
|
auto_vertical_output,
|
||||||
list_dsn,
|
list_dsn,
|
||||||
warn,
|
warn,
|
||||||
|
ssh_tunnel: str,
|
||||||
):
|
):
|
||||||
if version:
|
if version:
|
||||||
print("Version:", __version__)
|
print("Version:", __version__)
|
||||||
|
@ -1294,6 +1365,15 @@ def cli(
|
||||||
)
|
)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
|
||||||
|
click.secho(
|
||||||
|
'Cannot open SSH tunnel, "sshtunnel" package was not found. '
|
||||||
|
"Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
|
||||||
|
err=True,
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
pgcli = PGCli(
|
pgcli = PGCli(
|
||||||
prompt_passwd,
|
prompt_passwd,
|
||||||
never_prompt,
|
never_prompt,
|
||||||
|
@ -1305,6 +1385,7 @@ def cli(
|
||||||
prompt_dsn=prompt_dsn,
|
prompt_dsn=prompt_dsn,
|
||||||
auto_vertical_output=auto_vertical_output,
|
auto_vertical_output=auto_vertical_output,
|
||||||
warn=warn,
|
warn=warn,
|
||||||
|
ssh_tunnel_url=ssh_tunnel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Choose which ever one has a valid value.
|
# Choose which ever one has a valid value.
|
||||||
|
@ -1548,7 +1629,7 @@ def parse_service_info(service):
|
||||||
elif os.getenv("PGSYSCONFDIR"):
|
elif os.getenv("PGSYSCONFDIR"):
|
||||||
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
|
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
|
||||||
else:
|
else:
|
||||||
service_file = expanduser("~/.pg_service.conf")
|
service_file = os.path.expanduser("~/.pg_service.conf")
|
||||||
if not service or not os.path.exists(service_file):
|
if not service or not os.path.exists(service_file):
|
||||||
# nothing to do
|
# nothing to do
|
||||||
return None, service_file
|
return None, service_file
|
||||||
|
|
|
@ -63,17 +63,13 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
||||||
yield item
|
yield item
|
||||||
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
||||||
item_val = item.value.upper()
|
item_val = item.value.upper()
|
||||||
if (
|
if item_val in (
|
||||||
item_val
|
"COPY",
|
||||||
in (
|
"FROM",
|
||||||
"COPY",
|
"INTO",
|
||||||
"FROM",
|
"UPDATE",
|
||||||
"INTO",
|
"TABLE",
|
||||||
"UPDATE",
|
) or item_val.endswith("JOIN"):
|
||||||
"TABLE",
|
|
||||||
)
|
|
||||||
or item_val.endswith("JOIN")
|
|
||||||
):
|
|
||||||
tbl_prefix_seen = True
|
tbl_prefix_seen = True
|
||||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
||||||
# So this check here is necessary.
|
# So this check here is necessary.
|
||||||
|
|
|
@ -491,11 +491,14 @@ class PGCompleter(Completer):
|
||||||
|
|
||||||
def get_column_matches(self, suggestion, word_before_cursor):
|
def get_column_matches(self, suggestion, word_before_cursor):
|
||||||
tables = suggestion.table_refs
|
tables = suggestion.table_refs
|
||||||
do_qualify = suggestion.qualifiable and {
|
do_qualify = (
|
||||||
"always": True,
|
suggestion.qualifiable
|
||||||
"never": False,
|
and {
|
||||||
"if_more_than_one_table": len(tables) > 1,
|
"always": True,
|
||||||
}[self.qualify_columns]
|
"never": False,
|
||||||
|
"if_more_than_one_table": len(tables) > 1,
|
||||||
|
}[self.qualify_columns]
|
||||||
|
)
|
||||||
qualify = lambda col, tbl: (
|
qualify = lambda col, tbl: (
|
||||||
(tbl + "." + self.case(col)) if do_qualify else self.case(col)
|
(tbl + "." + self.case(col)) if do_qualify else self.case(col)
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
# vi: ft=vimwiki
|
|
||||||
|
|
||||||
* Bump the version number in pgcli/__init__.py
|
|
||||||
* Commit with message: 'Releasing version X.X.X.'
|
|
||||||
* Create a tag: git tag vX.X.X
|
|
||||||
* Fix the image url in PyPI to point to github raw content. https://raw.githubusercontent.com/dbcli/pgcli/master/screenshots/image01.png
|
|
||||||
* Create source dist tar ball: python setup.py sdist
|
|
||||||
* Test this by installing it in a fresh new virtualenv. Run SanityChecks [./sanity_checks.txt].
|
|
||||||
* Upload the source dist to PyPI: https://pypi.python.org/pypi/pgcli
|
|
||||||
* pip install pgcli
|
|
||||||
* Run SanityChecks.
|
|
||||||
* Push the version back to github: git push --tags origin master
|
|
||||||
* Done!
|
|
5
setup.py
5
setup.py
|
@ -39,7 +39,10 @@ setup(
|
||||||
description=description,
|
description=description,
|
||||||
long_description=open("README.rst").read(),
|
long_description=open("README.rst").read(),
|
||||||
install_requires=install_requirements,
|
install_requires=install_requirements,
|
||||||
extras_require={"keyring": ["keyring >= 12.2.0"]},
|
extras_require={
|
||||||
|
"keyring": ["keyring >= 12.2.0"],
|
||||||
|
"sshtunnel": ["sshtunnel >= 0.4.0"],
|
||||||
|
},
|
||||||
python_requires=">=3.6",
|
python_requires=">=3.6",
|
||||||
entry_points="""
|
entry_points="""
|
||||||
[console_scripts]
|
[console_scripts]
|
||||||
|
|
|
@ -97,9 +97,9 @@ def step_see_error_message(context):
|
||||||
@when("we send source command")
|
@when("we send source command")
|
||||||
def step_send_source_command(context):
|
def step_send_source_command(context):
|
||||||
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
|
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
|
||||||
context.tmpfile_sql_help.write(br"\?")
|
context.tmpfile_sql_help.write(rb"\?")
|
||||||
context.tmpfile_sql_help.flush()
|
context.tmpfile_sql_help.flush()
|
||||||
context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
|
context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}")
|
||||||
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
|
188
tests/test_ssh_tunnel.py
Normal file
188
tests/test_ssh_tunnel.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch, MagicMock, ANY
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from configobj import ConfigObj
|
||||||
|
from click.testing import CliRunner
|
||||||
|
from sshtunnel import SSHTunnelForwarder
|
||||||
|
|
||||||
|
from pgcli.main import cli, PGCli
|
||||||
|
from pgcli.pgexecute import PGExecute
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ssh_tunnel_forwarder() -> MagicMock:
|
||||||
|
mock_ssh_tunnel_forwarder = MagicMock(
|
||||||
|
SSHTunnelForwarder, local_bind_ports=[1111], autospec=True
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"pgcli.main.sshtunnel.SSHTunnelForwarder",
|
||||||
|
return_value=mock_ssh_tunnel_forwarder,
|
||||||
|
) as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pgexecute() -> MagicMock:
|
||||||
|
with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute:
|
||||||
|
yield mock_pgexecute
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssh_tunnel(
|
||||||
|
mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
|
||||||
|
) -> None:
|
||||||
|
# Test with just a host
|
||||||
|
tunnel_url = "some.host"
|
||||||
|
db_params = {
|
||||||
|
"database": "dbname",
|
||||||
|
"host": "db.host",
|
||||||
|
"user": "db_user",
|
||||||
|
"passwd": "db_passwd",
|
||||||
|
}
|
||||||
|
expected_tunnel_params = {
|
||||||
|
"local_bind_address": ("127.0.0.1",),
|
||||||
|
"remote_bind_address": (db_params["host"], 5432),
|
||||||
|
"ssh_address_or_host": (tunnel_url, 22),
|
||||||
|
"logger": ANY,
|
||||||
|
}
|
||||||
|
|
||||||
|
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
|
||||||
|
pgcli.connect(**db_params)
|
||||||
|
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
|
||||||
|
mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
|
||||||
|
mock_pgexecute.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_pgexecute.call_args
|
||||||
|
assert call_args == (
|
||||||
|
db_params["database"],
|
||||||
|
db_params["user"],
|
||||||
|
db_params["passwd"],
|
||||||
|
"127.0.0.1",
|
||||||
|
pgcli.ssh_tunnel.local_bind_ports[0],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
mock_ssh_tunnel_forwarder.reset_mock()
|
||||||
|
mock_pgexecute.reset_mock()
|
||||||
|
|
||||||
|
# Test with a full url and with a specific db port
|
||||||
|
tunnel_user = "tunnel_user"
|
||||||
|
tunnel_passwd = "tunnel_pass"
|
||||||
|
tunnel_host = "some.other.host"
|
||||||
|
tunnel_port = 1022
|
||||||
|
tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
|
||||||
|
db_params["port"] = 1234
|
||||||
|
|
||||||
|
expected_tunnel_params["remote_bind_address"] = (
|
||||||
|
db_params["host"],
|
||||||
|
db_params["port"],
|
||||||
|
)
|
||||||
|
expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port)
|
||||||
|
expected_tunnel_params["ssh_username"] = tunnel_user
|
||||||
|
expected_tunnel_params["ssh_password"] = tunnel_passwd
|
||||||
|
|
||||||
|
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
|
||||||
|
pgcli.connect(**db_params)
|
||||||
|
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
|
||||||
|
mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
|
||||||
|
mock_pgexecute.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_pgexecute.call_args
|
||||||
|
assert call_args == (
|
||||||
|
db_params["database"],
|
||||||
|
db_params["user"],
|
||||||
|
db_params["passwd"],
|
||||||
|
"127.0.0.1",
|
||||||
|
pgcli.ssh_tunnel.local_bind_ports[0],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
mock_ssh_tunnel_forwarder.reset_mock()
|
||||||
|
mock_pgexecute.reset_mock()
|
||||||
|
|
||||||
|
# Test with DSN
|
||||||
|
dsn = (
|
||||||
|
f"user={db_params['user']} password={db_params['passwd']} "
|
||||||
|
f"host={db_params['host']} port={db_params['port']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
|
||||||
|
pgcli.connect(dsn=dsn)
|
||||||
|
|
||||||
|
expected_dsn = (
|
||||||
|
f"user={db_params['user']} password={db_params['passwd']} "
|
||||||
|
f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
|
||||||
|
mock_pgexecute.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_pgexecute.call_args
|
||||||
|
assert expected_dsn in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_with_tunnel() -> None:
|
||||||
|
runner = CliRunner()
|
||||||
|
tunnel_url = "mytunnel"
|
||||||
|
with patch.object(
|
||||||
|
PGCli, "__init__", autospec=True, return_value=None
|
||||||
|
) as mock_pgcli:
|
||||||
|
runner.invoke(cli, ["--ssh-tunnel", tunnel_url])
|
||||||
|
mock_pgcli.assert_called_once()
|
||||||
|
call_args, call_kwargs = mock_pgcli.call_args
|
||||||
|
assert call_kwargs["ssh_tunnel_url"] == tunnel_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_config(
|
||||||
|
tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
|
||||||
|
) -> None:
|
||||||
|
pgclirc = str(tmpdir.join("rcfile"))
|
||||||
|
|
||||||
|
tunnel_user = "tunnel_user"
|
||||||
|
tunnel_passwd = "tunnel_pass"
|
||||||
|
tunnel_host = "tunnel.host"
|
||||||
|
tunnel_port = 1022
|
||||||
|
tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
|
||||||
|
|
||||||
|
tunnel2_url = "tunnel2.host"
|
||||||
|
|
||||||
|
config = ConfigObj()
|
||||||
|
config.filename = pgclirc
|
||||||
|
config["ssh tunnels"] = {}
|
||||||
|
config["ssh tunnels"][r"\.com$"] = tunnel_url
|
||||||
|
config["ssh tunnels"][r"^hello-"] = tunnel2_url
|
||||||
|
config.write()
|
||||||
|
|
||||||
|
# Unmatched host
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="unmatched.host")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_not_called()
|
||||||
|
|
||||||
|
# Host matching first tunnel
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="matched.host.com")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once()
|
||||||
|
call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
|
||||||
|
assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
|
||||||
|
assert call_kwargs["ssh_username"] == tunnel_user
|
||||||
|
assert call_kwargs["ssh_password"] == tunnel_passwd
|
||||||
|
mock_ssh_tunnel_forwarder.reset_mock()
|
||||||
|
|
||||||
|
# Host matching second tunnel
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="hello-i-am-matched")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
|
||||||
|
assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
|
||||||
|
mock_ssh_tunnel_forwarder.reset_mock()
|
||||||
|
|
||||||
|
# Host matching both tunnels (will use the first one matched)
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="hello-i-am-matched.com")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
|
||||||
|
assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
|
||||||
|
assert call_kwargs["ssh_username"] == tunnel_user
|
||||||
|
assert call_kwargs["ssh_password"] == tunnel_passwd
|
Loading…
Add table
Reference in a new issue