import platform
import warnings

from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
from .config import skip_initial_comment

warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")

import atexit
import os
import re
import sys
import traceback
import logging
import threading
import shutil
import functools
import pendulum
import datetime as dt
import itertools
import platform
from time import time, sleep
from typing import Optional
from urllib.parse import urlparse

keyring = None  # keyring will be loaded later

from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
from cli_helpers.utils import strip_ansi
import click

try:
    import setproctitle
except ImportError:
    setproctitle = None
from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
from prompt_toolkit.document import Document
from prompt_toolkit.filters import HasFocus, IsDone
from prompt_toolkit.formatted_text import ANSI
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.layout.processors import (
    ConditionalProcessor,
    HighlightMatchingBracketProcessor,
    TabsProcessor,
)
from prompt_toolkit.history import FileHistory
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from pygments.lexers.sql import PostgresLexer

from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
import pgspecial as special

from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory, style_factory_output
from .pgexecute import PGExecute
from .completion_refresher import CompletionRefresher
from .config import (
    get_casing_file,
    load_config,
    config_location,
    ensure_dir_exists,
    get_config,
    get_config_filename,
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
from .__init__ import __version__

click.disable_unicode_literals_warning = True

try:
    from urlparse import urlparse, unquote, parse_qs
except ImportError:
    from urllib.parse import urlparse, unquote, parse_qs

from getpass import getuser
from psycopg2 import OperationalError, InterfaceError

# pg3: https://www.psycopg.org/psycopg3/docs/api/conninfo.html
from psycopg2.extensions import make_dsn, parse_dsn
import psycopg2

from collections import namedtuple

from textwrap import dedent

try:
    import sshtunnel

    SSH_TUNNEL_SUPPORT = True
except ImportError:
    SSH_TUNNEL_SUPPORT = False


# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
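# The regex matches both CSI escapes (e.g. "\x1b[31m") and OSC escapes, so
# e.g. COLOR_CODE_REGEX.sub("", "\x1b[31mred\x1b[0m") -> "red".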
DEFAULT_MAX_FIELD_WIDTH = 500

# Query tuples are used for maintaining history
MetaQuery = namedtuple(
    "Query",
    [
        "query",  # The entire text of the command
        "successful",  # True If all subqueries were successful
        "total_time",  # Time elapsed executing the query and formatting results
        "execution_time",  # Time elapsed executing the query
        "meta_changed",  # True if any subquery executed create/alter/drop
        "db_changed",  # True if any subquery changed the database
        "path_changed",  # True if any subquery changed the search path
        "mutated",  # True if any subquery executed insert/update/delete
        "is_special",  # True if the query is a special command
    ],
)
MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False)

OutputSettings = namedtuple(
    "OutputSettings",
    "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output max_field_width",
)
OutputSettings.__new__.__defaults__ = (
    None,
    None,
    None,
    "<null>",
    False,
    None,
    lambda x: x,
    None,
    DEFAULT_MAX_FIELD_WIDTH,
)
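# Both namedtuples are usually constructed with only a few fields overridden;
# the rest fall back to the defaults above, e.g.
# OutputSettings(table_format="ascii", missingval="<null>") still gets
# expanded=False and max_field_width=DEFAULT_MAX_FIELD_WIDTH.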


class PgCliQuitError(Exception):
    pass


class PGCli:
    default_prompt = "\\u@\\h:\\d> "
    max_len_prompt = 30

    def set_default_pager(self, config):
        configured_pager = config["main"].get("pager")
        os_environ_pager = os.environ.get("PAGER")

        if configured_pager:
            self.logger.info(
                'Default pager found in config file: "%s"', configured_pager
            )
            os.environ["PAGER"] = configured_pager
        elif os_environ_pager:
            self.logger.info(
                'Default pager found in PAGER environment variable: "%s"',
                os_environ_pager,
            )
            os.environ["PAGER"] = os_environ_pager
        else:
            self.logger.info(
                "No default pager found in environment. Using os default pager"
            )

        # Set default set of less recommended options, if they are not already set.
        # They are ignored if pager is different than less.
        if not os.environ.get("LESS"):
            os.environ["LESS"] = "-SRXF"

    def __init__(
        self,
        force_passwd_prompt=False,
        never_passwd_prompt=False,
        pgexecute=None,
        pgclirc_file=None,
        row_limit=None,
        single_connection=False,
        less_chatty=None,
        prompt=None,
        prompt_dsn=None,
        auto_vertical_output=False,
        warn=None,
        ssh_tunnel_url: Optional[str] = None,
    ):
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None
        self.watch_command = None

        # Load config.
        c = self.config = get_config(pgclirc_file)

        # at this point, config should be written to pgclirc_file if it did not exist. Read it.
        self.config_writer = load_config(get_config_filename(pgclirc_file))

        # make sure to use self.config_writer, not self.config
        NamedQueries.instance = NamedQueries.from_config(self.config_writer)

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(c)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = c["main"].as_bool("multi_line")
        self.multiline_mode = c["main"].get("multi_line_mode", "psql")
        self.vi_mode = c["main"].as_bool("vi")
        self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
        self.expanded_output = c["main"].as_bool("expand")
        self.pgspecial.timing_enabled = c["main"].as_bool("timing")
        if row_limit is not None:
            self.row_limit = row_limit
        else:
            self.row_limit = c["main"].as_int("row_limit")

        # if not specified, set to DEFAULT_MAX_FIELD_WIDTH
        # if specified but empty, set to None to disable truncation
        # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0
        max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH)
        if max_field_width and max_field_width.lower() != "none":
            max_field_width = max(3, abs(int(max_field_width)))
        else:
            max_field_width = None
        self.max_field_width = max_field_width

        self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
        self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
        self.table_format = c["main"]["table_format"]
        self.syntax_style = c["main"]["syntax_style"]
        self.cli_style = c["colors"]
        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
        self.destructive_warning = warn or c["main"]["destructive_warning"]
        # also handle boolean format of destructive warning
        self.destructive_warning = {"true": "all", "false": "off"}.get(
            self.destructive_warning.lower(), self.destructive_warning
        )
        self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
        self.null_string = c["main"].get("null_string", "<null>")
        self.prompt_format = (
            prompt
            if prompt is not None
            else c["main"].get("prompt", self.default_prompt)
        )
        self.prompt_dsn_format = prompt_dsn
        self.on_error = c["main"]["on_error"].upper()
        self.decimal_format = c["data_formats"]["decimal"]
        self.float_format = c["data_formats"]["float"]
        self.initialize_keyring()
        self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")

        self.pgspecial.pset_pager(
            self.config["main"].as_bool("enable_pager") and "on" or "off"
        )

        self.style_output = style_factory_output(self.syntax_style, c["colors"])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = c["main"].as_bool("smart_completion")
        keyword_casing = c["main"]["keyword_casing"]
        self.settings = {
            "casing_file": get_casing_file(c),
            "generate_casing_file": c["main"].as_bool("generate_casing_file"),
            "generate_aliases": c["main"].as_bool("generate_aliases"),
            "asterisk_column_order": c["main"]["asterisk_column_order"],
            "qualify_columns": c["main"]["qualify_columns"],
            "case_column_headers": c["main"].as_bool("case_column_headers"),
            "search_path_filter": c["main"].as_bool("search_path_filter"),
            "single_connection": single_connection,
            "less_chatty": less_chatty,
            "keyword_casing": keyword_casing,
        }

        completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial, settings=self.settings
        )
        self.completer = completer
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.prompt_app = None

        self.ssh_tunnel_config = c.get("ssh tunnels")
        self.ssh_tunnel_url = ssh_tunnel_url
        self.ssh_tunnel = None

    def quit(self):
        raise PgCliQuitError

    def register_special_commands(self):

        self.pgspecial.register(
            self.change_db,
            "\\c",
            "\\c[onnect] database_name",
            "Change to a new database.",
            aliases=("use", "\\connect", "USE"),
        )

        refresh_callback = lambda: self.refresh_completions(persist_priorities="all")

        self.pgspecial.register(
            self.quit,
            "\\q",
            "\\q",
            "Quit pgcli.",
            arg_type=NO_QUERY,
            case_sensitive=True,
            aliases=(":q",),
        )
        self.pgspecial.register(
            self.quit,
            "quit",
            "quit",
            "Quit pgcli.",
            arg_type=NO_QUERY,
            case_sensitive=False,
            aliases=("exit",),
        )
        self.pgspecial.register(
            refresh_callback,
            "\\#",
            "\\#",
            "Refresh auto-completions.",
            arg_type=NO_QUERY,
        )
        self.pgspecial.register(
            refresh_callback,
            "\\refresh",
            "\\refresh",
            "Refresh auto-completions.",
            arg_type=NO_QUERY,
        )
        self.pgspecial.register(
            self.execute_from_file, "\\i", "\\i filename", "Execute commands from file."
        )
        self.pgspecial.register(
            self.write_to_file,
            "\\o",
            "\\o [filename]",
            "Send all query results to file.",
        )
        self.pgspecial.register(
            self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
        )
        self.pgspecial.register(
            self.change_table_format,
            "\\T",
            "\\T [format]",
            "Change the table format used to output results",
        )

    def change_table_format(self, pattern, **_):
        try:
            if pattern not in TabularOutputFormatter().supported_formats:
                raise ValueError()
            self.table_format = pattern
            yield (None, None, None, f"Changed table format to {pattern}")
        except ValueError:
            msg = f"Table format {pattern} not recognized. Allowed formats:"
            for table_type in TabularOutputFormatter().supported_formats:
                msg += f"\n\t{table_type}"
            msg += "\nCurrently set to: %s" % self.table_format
            yield (None, None, None, msg)

    def info_connection(self, **_):
        if self.pgexecute.host.startswith("/"):
            host = 'socket "%s"' % self.pgexecute.host
        else:
            host = 'host "%s"' % self.pgexecute.host

        yield (
            None,
            None,
            None,
            'You are connected to database "%s" as user '
            '"%s" on %s at port "%s".'
            % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
        )

    def change_db(self, pattern, **_):
        if pattern:
            # Get all the parameters in pattern, handling double quotes if any.
            infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
            # Now remove the surrounding quotes. str.strip returns new strings,
            # so the stripped values have to be assigned back.
            infos = [s.strip('"') for s in infos]

            infos.extend([None] * (4 - len(infos)))
            db, user, host, port = infos
            try:
                self.pgexecute.connect(
                    database=db,
                    user=user,
                    host=host,
                    port=port,
                    **self.pgexecute.extra_args,
                )
            except OperationalError as e:
                click.secho(str(e), err=True, fg="red")
                click.echo("Previous connection kept")
        else:
            self.pgexecute.connect()

        yield (
            None,
            None,
            None,
            'You are now connected to database "%s" as '
            'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
        )

    def execute_from_file(self, pattern, **_):
        if not pattern:
            message = "\\i: missing required argument"
            return [(None, None, None, message, "", False, True)]
        try:
            with open(os.path.expanduser(pattern), encoding="utf-8") as f:
                query = f.read()
        except OSError as e:
            return [(None, None, None, str(e), "", False, True)]

        if (
            self.destructive_warning != "off"
            and confirm_destructive_query(query, self.destructive_warning) is False
        ):
            message = "Wise choice. Command execution stopped."
            return [(None, None, None, message)]

        on_error_resume = self.on_error == "RESUME"
        return self.pgexecute.run(
            query, self.pgspecial, on_error_resume=on_error_resume
        )

    def write_to_file(self, pattern, **_):
        if not pattern:
            self.output_file = None
            message = "File output disabled"
            return [(None, None, None, message, "", True, True)]
        filename = os.path.abspath(os.path.expanduser(pattern))
        if not os.path.isfile(filename):
            try:
                open(filename, "w").close()
            except OSError as e:
                self.output_file = None
                message = str(e) + "\nFile output disabled"
                return [(None, None, None, message, "", False, True)]
        self.output_file = filename
        message = 'Writing to file "%s"' % self.output_file
        return [(None, None, None, message, "", True, True)]

    def initialize_logging(self):

        log_file = self.config["main"]["log_file"]
        if log_file == "default":
            log_file = config_location() + "log"
        ensure_dir_exists(log_file)
        log_level = self.config["main"]["log_level"]

        # Disable logging if value is NONE by switching to a no-op handler.
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if log_level.upper() == "NONE":
            handler = logging.NullHandler()
        else:
            handler = logging.FileHandler(os.path.expanduser(log_file))

        level_map = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
            "NONE": logging.CRITICAL,
        }

        log_level = level_map[log_level.upper()]

        formatter = logging.Formatter(
            "%(asctime)s (%(process)d/%(threadName)s) "
            "%(name)s %(levelname)s - %(message)s"
        )

        handler.setFormatter(formatter)

        root_logger = logging.getLogger("pgcli")
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        root_logger.debug("Initializing pgcli logging.")
        root_logger.debug("Log file %r.", log_file)

        pgspecial_logger = logging.getLogger("pgspecial")
        pgspecial_logger.addHandler(handler)
        pgspecial_logger.setLevel(log_level)

    def initialize_keyring(self):
        global keyring

        keyring_enabled = self.config["main"].as_bool("keyring")
        if keyring_enabled:
            # Try best to load keyring (issue #1041).
            import importlib

            try:
                keyring = importlib.import_module("keyring")
            except Exception as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
                self.logger.warning("import keyring failed: %r.", e)

    def connect_dsn(self, dsn, **kwargs):
        self.connect(dsn=dsn, **kwargs)

    def connect_service(self, service, user):
        service_config, file = parse_service_info(service)
        if service_config is None:
            click.secho(
                f"service '{service}' was not found in {file}", err=True, fg="red"
            )
            exit(1)
        self.connect(
            database=service_config.get("dbname"),
            host=service_config.get("host"),
            user=user or service_config.get("user"),
            port=service_config.get("port"),
            passwd=service_config.get("password"),
        )
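
    # connect_uri accepts a libpq/psycopg2 connection URI and forwards it to
    # connect() with psycopg2's key names remapped to this class's parameter
    # names, e.g. "postgresql://alice:secret@localhost:5432/appdb" becomes
    # connect(database="appdb", user="alice", passwd="secret",
    # host="localhost", port="5432").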
    def connect_uri(self, uri):
        kwargs = psycopg2.extensions.parse_dsn(uri)
        remap = {"dbname": "database", "password": "passwd"}
        kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
        self.connect(**kwargs)

    def connect(
        self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
    ):
        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        kwargs.setdefault("application_name", "pgcli")

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get("PGPASSWORD", "")

        # Find password from store
        key = f"{user}@{host}"
        keyring_error_message = dedent(
            """\
            {}
            {}
            To remove this message do one of the following:
            - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
            - uninstall keyring: pip uninstall keyring
            - disable keyring in our configuration: add keyring = False to [main]"""
        )

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt(
                "Password for %s" % user, hide_input=True, show_default=False, type=str
            )

        if not passwd and keyring:

            try:
                passwd = keyring.get_password("pgcli", key)
            except (RuntimeError, keyring.errors.InitError) as e:
                click.secho(
                    keyring_error_message.format(
                        "Load your password from keyring returned:", str(e)
                    ),
                    err=True,
                    fg="red",
                )

        def should_ask_for_password(exc):
            # Prompt for a password after 1st attempt to connect
            # fails. Don't prompt if the -w flag is supplied
            if self.never_passwd_prompt:
                return False
            error_msg = exc.args[0]
            if "no password supplied" in error_msg:
                return True
            if "password authentication failed" in error_msg:
                return True
            return False

        if dsn:
            parsed_dsn = parse_dsn(dsn)
            if "host" in parsed_dsn:
                host = parsed_dsn["host"]
            if "port" in parsed_dsn:
                port = parsed_dsn["port"]
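
        # The optional [ssh tunnels] config section maps database host regexes
        # to SSH tunnel URLs; the first pattern matching the target host wins.
        # A hypothetical entry (assuming the configobj syntax used by the rest
        # of the config file) could look like:
        #   [ssh tunnels]
        #   .*\.internal\.example\.com = ssh://jumpuser:secret@bastion:2222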
        if self.ssh_tunnel_config and not self.ssh_tunnel_url:
            for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
                if re.search(db_host_regex, host):
                    self.ssh_tunnel_url = tunnel_url
                    break

        if self.ssh_tunnel_url:
            # We add the protocol as urlparse doesn't find it by itself
            if "://" not in self.ssh_tunnel_url:
                self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"

            tunnel_info = urlparse(self.ssh_tunnel_url)
            params = {
                "local_bind_address": ("127.0.0.1",),
                "remote_bind_address": (host, int(port or 5432)),
                "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
                "logger": self.logger,
            }
            if tunnel_info.username:
                params["ssh_username"] = tunnel_info.username
            if tunnel_info.password:
                params["ssh_password"] = tunnel_info.password

            # Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
            # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
            logger_handlers = self.logger.handlers.copy()
            try:
                self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
                self.ssh_tunnel.start()
            except Exception as e:
                self.logger.handlers = logger_handlers
                self.logger.error("traceback: %r", traceback.format_exc())
                click.secho(str(e), err=True, fg="red")
                exit(1)
            self.logger.handlers = logger_handlers

            atexit.register(self.ssh_tunnel.stop)
            host = "127.0.0.1"
            port = self.ssh_tunnel.local_bind_ports[0]

            if dsn:
                dsn = make_dsn(dsn, host=host, port=port)

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing or incorrect password, but we're allowed to
        # prompt for a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
            except (OperationalError, InterfaceError) as e:
                if should_ask_for_password(e):
                    passwd = click.prompt(
                        "Password for %s" % user,
                        hide_input=True,
                        show_default=False,
                        type=str,
                    )
                    pgexecute = PGExecute(
                        database, user, passwd, host, port, dsn, **kwargs
                    )
                else:
                    raise e
            if passwd and keyring:
                try:
                    keyring.set_password("pgcli", key, passwd)
                except (RuntimeError, keyring.errors.KeyringError) as e:
                    click.secho(
                        keyring_error_message.format(
                            "Set password in keyring returned:", str(e)
                        ),
                        err=True,
                        fg="red",
                    )

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg="red")
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, text):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param text: Document
        :return: Document
        """
        editor_command = special.editor_command(text)
        while editor_command:
            if editor_command == "\\e":
                filename = special.get_filename(text)
                query = special.get_editor_query(text) or self.get_last_query()
            else:  # \ev or \ef
                filename = None
                spec = text.split()[1]
                if editor_command == "\\ev":
                    query = self.pgexecute.view_definition(spec)
                elif editor_command == "\\ef":
                    query = self.pgexecute.function_definition(spec)
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            while True:
                try:
                    text = self.prompt_app.prompt(default=sql)
                    break
                except KeyboardInterrupt:
                    sql = ""

            editor_command = special.editor_command(text)
        return text

    def execute_command(self, text):
        logger = self.logger

        query = MetaQuery(query=text, successful=False)

        try:
            if self.destructive_warning != "off":
                destroy = confirm = confirm_destructive_query(
                    text, self.destructive_warning
                )
                if destroy is False:
                    click.secho("Wise choice!")
                    raise KeyboardInterrupt
                elif destroy:
                    click.secho("Your call!")
            output, query = self._evaluate_command(text)
        except KeyboardInterrupt:
            # Restart connection to the database
            self.pgexecute.connect()
            logger.debug("cancelled query, sql: %r", text)
            click.secho("cancelled query", err=True, fg="red")
        except NotImplementedError:
            click.secho("Not Yet Implemented.", fg="yellow")
        except OperationalError as e:
            logger.error("sql: %r, error: %r", text, e)
            logger.error("traceback: %r", traceback.format_exc())
            self._handle_server_closed_connection(text)
        except (PgCliQuitError, EOFError) as e:
            raise
        except Exception as e:
            logger.error("sql: %r, error: %r", text, e)
            logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg="red")
        else:
            try:
                if self.output_file and not text.startswith(("\\o ", "\\? ")):
                    try:
                        with open(self.output_file, "a", encoding="utf-8") as f:
                            click.echo(text, file=f)
                            click.echo("\n".join(output), file=f)
                            click.echo("", file=f)  # extra newline
                    except OSError as e:
                        click.secho(str(e), err=True, fg="red")
                else:
                    if output:
                        self.echo_via_pager("\n".join(output))
            except KeyboardInterrupt:
                pass

            if self.pgspecial.timing_enabled:
                # Only add humanized time display if > 1 second
                if query.total_time > 1:
                    print(
                        "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
                        % (
                            query.total_time,
                            pendulum.Duration(seconds=query.total_time).in_words(),
                            query.execution_time,
                            pendulum.Duration(seconds=query.execution_time).in_words(),
                        )
                    )
                else:
                    print("Time: %0.03fs" % query.total_time)

        # Check if we need to update completions, in order of most
        # to least drastic changes
        if query.db_changed:
            with self._completer_lock:
                self.completer.reset_completions()
            self.refresh_completions(persist_priorities="keywords")
        elif query.meta_changed:
            self.refresh_completions(persist_priorities="all")
        elif query.path_changed:
            logger.debug("Refreshing search path")
            with self._completer_lock:
                self.completer.set_search_path(self.pgexecute.search_path())
            logger.debug("Search path: %r", self.completer.search_path)
        return query

    def run_cli(self):
        logger = self.logger

        history_file = self.config["main"]["history_file"]
        if history_file == "default":
            history_file = config_location() + "history"
        history = FileHistory(os.path.expanduser(history_file))
        self.refresh_completions(history=history, persist_priorities="none")

        self.prompt_app = self._build_cli(history)

        if not self.less_chatty:
            print("Server: PostgreSQL", self.pgexecute.server_version)
            print("Version:", __version__)
            print("Home: http://pgcli.com")

        try:
            while True:
                try:
                    text = self.prompt_app.prompt()
                except KeyboardInterrupt:
                    continue

                try:
                    text = self.handle_editor_command(text)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg="red")
                    continue

                self.handle_watch_command(text)

                self.now = dt.datetime.today()

                # Allow PGCompleter to learn user's preferred keywords, etc.
                with self._completer_lock:
                    self.completer.extend_query_history(text)

        except (PgCliQuitError, EOFError):
            if not self.less_chatty:
                print("Goodbye!")

    def handle_watch_command(self, text):
        # Initialize default metaquery in case execution fails
        self.watch_command, timing = special.get_watch_command(text)

        # If we run \watch without a command, apply it to the last query run.
        if self.watch_command is not None and not self.watch_command.strip():
            try:
                self.watch_command = self.query_history[-1].query
            except IndexError:
                click.secho(
                    "\\watch cannot be used with an empty query", err=True, fg="red"
                )
                self.watch_command = None

        # If there's a command to \watch, run it in a loop.
        if self.watch_command:
            while self.watch_command:
                try:
                    query = self.execute_command(self.watch_command)
                    click.echo(f"Waiting for {timing} seconds before repeating")
                    sleep(timing)
                except KeyboardInterrupt:
                    self.watch_command = None

        # Otherwise, execute it as a regular command.
        else:
            query = self.execute_command(text)

        self.query_history.append(query)

    def _build_cli(self, history):
        key_bindings = pgcli_bindings(self)

        def get_message():
            if self.dsn_alias and self.prompt_dsn_format is not None:
                prompt_format = self.prompt_dsn_format
            else:
                prompt_format = self.prompt_format

            prompt = self.get_prompt(prompt_format)

            if (
                prompt_format == self.default_prompt
                and len(prompt) > self.max_len_prompt
            ):
                prompt = self.get_prompt("\\d> ")

            prompt = prompt.replace("\\x1b", "\x1b")
            return ANSI(prompt)

        def get_continuation(width, line_number, is_soft_wrap):
            continuation = self.multiline_continuation_char * (width - 1) + " "
            return [("class:continuation", continuation)]

        get_toolbar_tokens = create_toolbar_tokens_func(self)

        if self.wider_completion_menu:
            complete_style = CompleteStyle.MULTI_COLUMN
        else:
            complete_style = CompleteStyle.COLUMN

        with self._completer_lock:
            prompt_app = PromptSession(
                lexer=PygmentsLexer(PostgresLexer),
                reserve_space_for_menu=self.min_num_menu_lines,
                message=get_message,
                prompt_continuation=get_continuation,
                bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
                complete_style=complete_style,
                input_processors=[
                    # Highlight matching brackets while editing.
                    ConditionalProcessor(
                        processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                    ),
                    # Render \t as 4 spaces instead of "^I"
                    TabsProcessor(char1=" ", char2=" "),
                ],
                auto_suggest=AutoSuggestFromHistory(),
                tempfile_suffix=".sql",
                # N.b. pgcli's multi-line mode controls submit-on-Enter (which
                # overrides the default behaviour of prompt_toolkit) and is
                # distinct from prompt_toolkit's multiline mode here, which
                # controls layout/display of the prompt/buffer
                multiline=True,
                history=history,
                completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
                complete_while_typing=True,
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
                search_ignore_case=True,
            )

        return prompt_app
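
    # Output limiting: plain SELECTs without an explicit LIMIT are truncated
    # to row_limit rows (row_limit = 0 disables this). For example, with
    # row_limit = 1000 a query returning 250000 rows is sliced down to the
    # first 1000 and the reported status becomes "SELECT 1000".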
    def _should_limit_output(self, sql, cur):
        """returns True if the output should be truncated, False otherwise."""
        if not is_select(sql):
            return False

        return (
            not self._has_limit(sql)
            and self.row_limit != 0
            and cur
            and cur.rowcount > self.row_limit
        )

    def _has_limit(self, sql):
        if not sql:
            return False
        return "limit " in sql.lower()

    def _limit_output(self, cur):
        limit = min(self.row_limit, cur.rowcount)
        new_cur = itertools.islice(cur, limit)
        new_status = "SELECT " + str(limit)
        click.secho("The result was limited to %s rows" % limit, fg="red")

        return new_cur, new_status

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        returns (results, MetaQuery)
        """
        logger = self.logger
        logger.debug("sql: %r", text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0
        execution = 0

        # Run the query.
        start = time()
        on_error_resume = self.on_error == "RESUME"
        res = self.pgexecute.run(
            text, self.pgspecial, exception_formatter, on_error_resume
        )

        is_special = None

        for title, cur, headers, status, sql, success, is_special in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)

            if self._should_limit_output(sql, cur):
                cur, status = self._limit_output(cur)

            if self.pgspecial.auto_expand or self.auto_expand:
                max_width = self.prompt_app.output.get_size().columns
            else:
                max_width = None

            expanded = self.pgspecial.expanded_output or self.expanded_output
            settings = OutputSettings(
                table_format=self.table_format,
                dcmlfmt=self.decimal_format,
                floatfmt=self.float_format,
                missingval=self.null_string,
                expanded=expanded,
                max_width=max_width,
                case_function=(
                    self.completer.case
                    if self.settings["case_column_headers"]
                    else lambda x: x
                ),
                style_output=self.style_output,
                max_field_width=self.max_field_width,
            )
            execution = time() - start
            formatted = format_output(title, cur, headers, status, settings)

            output.extend(formatted)
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(
            text,
            all_success,
            total,
            execution,
            meta_changed,
            db_changed,
            path_changed,
            mutated,
            is_special,
        )

        return output, meta_query

    def _handle_server_closed_connection(self, text):
        """Used during CLI execution."""
        try:
            click.secho("Reconnecting...", fg="green")
            self.pgexecute.connect()
            click.secho("Reconnected!", fg="green")
            self.execute_command(text)
        except OperationalError as e:
            click.secho("Reconnect Failed", fg="red")
            click.secho(str(e), err=True, fg="red")

    def refresh_completions(self, history=None, persist_priorities="all"):
        """Refresh outdated completions

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: 'all' or 'keywords'
        """

        callback = functools.partial(
            self._on_completions_refreshed, persist_priorities=persist_priorities
        )
        return self.completion_refresher.refresh(
            self.pgexecute,
            self.pgspecial,
            callback,
            history=history,
            settings=self.settings,
        )

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        self._swap_completer_objects(new_completer, persist_priorities)

        if self.prompt_app:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.prompt_app.app.invalidate()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Swap the completer object with the newly created completer.

        persist_priorities is a string specifying how the old completer's
        learned prioritizer should be transferred to the new completer.

        'none' - The new prioritizer is left in a new/clean state

        'all' - The new prioritizer is updated to exactly reflect
                the old one

        'keywords' - The new prioritizer is updated with old keyword
                     priorities, but not any other.

        """
        with self._completer_lock:
            old_completer = self.completer
            self.completer = new_completer

            if persist_priorities == "all":
                # Just swap over the entire prioritizer
                new_completer.prioritizer = old_completer.prioritizer
            elif persist_priorities == "keywords":
                # Swap over the entire prioritizer, but clear name priorities,
                # leaving learned keyword priorities alone
                new_completer.prioritizer = old_completer.prioritizer
                new_completer.prioritizer.clear_names()
            elif persist_priorities == "none":
                # Leave the new prioritizer as is
                pass
            self.completer = new_completer

    def get_completions(self, text, cursor_position):
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_position), None
            )
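
    # get_prompt expands the psql-style escapes in the prompt template, e.g.
    # with user "alice", host "db1.example.com" and database "appdb" the
    # default "\u@\h:\d> " renders roughly as "alice@db1:appdb> "
    # (\h is the short host name, \H the full one).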
    def get_prompt(self, string):
        # should be before replacing \\d
        string = string.replace("\\dsn_alias", self.dsn_alias or "")
        string = string.replace("\\t", self.now.strftime("%x %X"))
        string = string.replace("\\u", self.pgexecute.user or "(none)")
        string = string.replace("\\H", self.pgexecute.host or "(none)")
        string = string.replace("\\h", self.pgexecute.short_host or "(none)")
        string = string.replace("\\d", self.pgexecute.dbname or "(none)")
        string = string.replace(
            "\\p",
            str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
        )
        string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
        string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
        string = string.replace("\\n", "\n")
        return string

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None

    def is_too_wide(self, line):
        """Will this line be too wide to fit into terminal?"""
        if not self.prompt_app:
            return False
        return (
            len(COLOR_CODE_REGEX.sub("", line))
            > self.prompt_app.output.get_size().columns
        )

    def is_too_tall(self, lines):
        """Are there too many lines to fit into terminal?"""
        if not self.prompt_app:
            return False
        return len(lines) >= (self.prompt_app.output.get_size().rows - 4)

    def echo_via_pager(self, text, color=None):
        if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
            click.echo(text, color=color)
        elif (
            self.pgspecial.pager_config == PAGER_LONG_OUTPUT
            and self.table_format != "csv"
        ):
            lines = text.split("\n")

            # The last 4 lines are reserved for the pgcli menu and padding
            if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines):
                click.echo_via_pager(text, color=color)
            else:
                click.echo(text, color=color)
        else:
            click.echo_via_pager(text, color)


@click.command()
# Default host is '' so psycopg2 can default to either localhost or unix socket
@click.option(
    "-h",
    "--host",
    default="",
    envvar="PGHOST",
    help="Host address of the postgres database.",
)
@click.option(
    "-p",
    "--port",
    default=5432,
    help="Port number at which the postgres instance is listening.",
    envvar="PGPORT",
    type=click.INT,
)
@click.option(
    "-U",
    "--username",
    "username_opt",
    help="Username to connect to the postgres database.",
)
@click.option(
    "-u", "--user", "username_opt", help="Username to connect to the postgres database."
)
@click.option(
    "-W",
    "--password",
    "prompt_passwd",
    is_flag=True,
    default=False,
    help="Force password prompt.",
)
@click.option(
    "-w",
    "--no-password",
    "never_prompt",
    is_flag=True,
    default=False,
    help="Never prompt for password.",
)
@click.option(
    "--single-connection",
    "single_connection",
    is_flag=True,
    default=False,
    help="Do not use a separate connection for completions.",
)
@click.option("-v", "--version", is_flag=True, help="Version of pgcli.")
@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.")
@click.option(
    "--pgclirc",
    default=config_location() + "config",
    envvar="PGCLIRC",
    help="Location of pgclirc file.",
    type=click.Path(dir_okay=False),
)
@click.option(
    "-D",
    "--dsn",
    default="",
    envvar="DSN",
    help="Use DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--list-dsn",
    "list_dsn",
    is_flag=True,
    help="list of DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--row-limit",
    default=None,
    envvar="PGROWLIMIT",
    type=click.INT,
    help="Set threshold for row limit prompt. Use 0 to disable prompt.",
)
@click.option(
    "--less-chatty",
    "less_chatty",
    is_flag=True,
    default=False,
    help="Skip intro on startup and goodbye on exit.",
)
@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").')
@click.option(
    "--prompt-dsn",
    help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").',
)
@click.option(
    "-l",
    "--list",
    "list_databases",
    is_flag=True,
    help="list available databases, then exit.",
)
@click.option(
    "--auto-vertical-output",
    is_flag=True,
    help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
    "--warn",
    default=None,
    type=click.Choice(["all", "moderate", "off"]),
    help="Warn before running a destructive query.",
)
@click.option(
    "--ssh-tunnel",
    default=None,
    help="Open an SSH tunnel to the given address and connect to the database from it.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
    dbname,
    username_opt,
    host,
    port,
    prompt_passwd,
    never_prompt,
    single_connection,
    dbname_opt,
    username,
    version,
    pgclirc,
    dsn,
    row_limit,
    less_chatty,
    prompt,
    prompt_dsn,
    list_databases,
    auto_vertical_output,
    list_dsn,
    warn,
    ssh_tunnel: str,
):
    if version:
        print("Version:", __version__)
        sys.exit(0)

    config_dir = os.path.dirname(config_location())
    if not os.path.exists(config_dir):
        os.makedirs(config_dir)

    # Migrate the config file from old location.
    config_full_path = config_location() + "config"
    if os.path.exists(os.path.expanduser("~/.pgclirc")):
        if not os.path.exists(config_full_path):
            shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path)
            print("Config file (~/.pgclirc) moved to new location", config_full_path)
        else:
            print("Config file is now located at", config_full_path)
            print(
                "Please move the existing config file ~/.pgclirc to",
                config_full_path,
            )
    if list_dsn:
        try:
            cfg = load_config(pgclirc, config_full_path)
            for alias in cfg["alias_dsn"]:
                click.secho(alias + " : " + cfg["alias_dsn"][alias])
            sys.exit(0)
        except Exception as err:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            exit(1)

    if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
        click.secho(
            'Cannot open SSH tunnel, "sshtunnel" package was not found. '
            "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
            err=True,
            fg="red",
        )
        exit(1)

    pgcli = PGCli(
        prompt_passwd,
        never_prompt,
        pgclirc_file=pgclirc,
        row_limit=row_limit,
        single_connection=single_connection,
        less_chatty=less_chatty,
        prompt=prompt,
        prompt_dsn=prompt_dsn,
        auto_vertical_output=auto_vertical_output,
        warn=warn,
        ssh_tunnel_url=ssh_tunnel,
    )

    # Choose whichever one has a valid value.
    if dbname_opt and dbname:
        # work as psql: when database is given as option and argument use the argument as user
        username = dbname
    database = dbname_opt or dbname or ""
    user = username_opt or username
    service = None
    if database.startswith("service="):
        service = database[8:]
    elif os.getenv("PGSERVICE") is not None:
        service = os.getenv("PGSERVICE")
    # because option --list or -l are not supposed to have a db name
    if list_databases:
        database = "postgres"

    if dsn != "":
        try:
            cfg = load_config(pgclirc, config_full_path)
            dsn_config = cfg["alias_dsn"][dsn]
        except KeyError:
            click.secho(
                f"Could not find a DSN with alias {dsn}. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            exit(1)
        except Exception:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            exit(1)
        pgcli.connect_uri(dsn_config)
        pgcli.dsn_alias = dsn
    elif "://" in database:
        pgcli.connect_uri(database)
    elif "=" in database and service is None:
        pgcli.connect_dsn(database, user=user)
    elif service is not None:
        pgcli.connect_service(service, user)
    else:
        pgcli.connect(database, host, user, port)

    if list_databases:
        cur, headers, status = pgcli.pgexecute.full_databases()

        title = "List of databases"
        settings = OutputSettings(table_format="ascii", missingval="<null>")
        formatted = format_output(title, cur, headers, status, settings)
        pgcli.echo_via_pager("\n".join(formatted))

        sys.exit(0)

    pgcli.logger.debug(
        "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
        database,
        user,
        host,
        port,
    )

    if setproctitle:
        obfuscate_process_password()

    pgcli.run_cli()
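

# Obfuscation example: a process title like
# "pgcli postgres://alice:hunter2@db.example.com/appdb" is rewritten to
# "pgcli postgres://alice:xxxx@db.example.com/appdb" so the password does not
# show up in process listings; key=value style DSNs get "password=xxxx" instead.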
def obfuscate_process_password():
    process_title = setproctitle.getproctitle()
    if "://" in process_title:
        process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title)
    elif "=" in process_title:
        process_title = re.sub(
            r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title
        )

    setproctitle.setproctitle(process_title)


def has_meta_cmd(query):
    """Determines if the completion needs a refresh by checking if the sql
    statement is an alter, create, drop, commit or rollback."""
    try:
        first_token = query.split()[0]
        if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"):
            return True
    except Exception:
        return False

    return False


def has_change_db_cmd(query):
    """Determines if the statement is a database switch such as 'use' or '\\c'"""
    try:
        first_token = query.split()[0]
        if first_token.lower() in ("use", "\\c", "\\connect"):
            return True
    except Exception:
        return False

    return False


def has_change_path_cmd(sql):
    """Determines if the search_path should be refreshed by checking if the
    sql has 'set search_path'."""
    return "set search_path" in sql.lower()


def is_mutating(status):
    """Determines if the statement is mutating based on the status."""
    if not status:
        return False

    mutating = {"insert", "update", "delete"}
    return status.split(None, 1)[0].lower() in mutating


def is_select(status):
    """Returns true if the first word in status is 'select'."""
    if not status:
        return False
    return status.split(None, 1)[0].lower() == "select"


def exception_formatter(e):
    return click.style(str(e), fg="red")


def format_output(title, cur, headers, status, settings):
    output = []
    expanded = settings.expanded or settings.table_format == "vertical"
    table_format = "vertical" if settings.expanded else settings.table_format
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return "{" + ",".join(str(format_array(e)) for e in val) + "}"

    def format_arrays(data, headers, **_):
        data = list(data)
        for row in data:
            row[:] = [
                format_array(val) if isinstance(val, list) else val for val in row
            ]

        return data, headers

    def format_status(cur, status):
        # redshift does not return rowcount as part of status.
        # See https://github.com/dbcli/pgcli/issues/1320
        if cur and hasattr(cur, "rowcount") and cur.rowcount is not None:
            if status and not status.endswith(str(cur.rowcount)):
                status += " %s" % cur.rowcount
        return status

    output_kwargs = {
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": settings.missingval,
        "integer_format": settings.dcmlfmt,
        "float_format": settings.floatfmt,
        "preprocessors": (format_numbers, format_arrays),
        "disable_numparse": True,
        "preserve_whitespace": True,
        "style": settings.style_output,
        "max_field_width": settings.max_field_width,
    }
    if not settings.floatfmt:
        output_kwargs["preprocessors"] = (align_decimals,)

    if table_format == "csv":
        # The default CSV dialect is "excel" which is not handling newline values correctly
        # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n'
        # as the line terminator
        # https://github.com/dbcli/pgcli/issues/1102
        dialect = "excel" if platform.system() == "Windows" else "unix"
        output_kwargs["dialect"] = dialect

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(x) for x in headers]
        if max_width is not None:
            cur = list(cur)
        column_types = None
        if hasattr(cur, "description"):
            column_types = []
            for d in cur.description:
                # pg3: type_name = cur.adapters.types[d.type_code].name
                if (
                    # pg3: type_name in ("numeric", "float4", "float8")
                    d[1] in psycopg2.extensions.DECIMAL.values
                    or d[1] in psycopg2.extensions.FLOAT.values
                ):
                    column_types.append(float)
                elif (
                    # pg3: type_name in ("int2", "int4", "int8")
                    d[1] in psycopg2.extensions.INTEGER.values
                    or d[1] in psycopg2.extensions.LONGINTEGER.values
                ):
                    column_types.append(int)
                else:
                    column_types.append(str)

        formatted = formatter.format_output(cur, headers, **output_kwargs)
        if isinstance(formatted, str):
            formatted = iter(formatted.splitlines())
        first_line = next(formatted)
        formatted = itertools.chain([first_line], formatted)
        if (
            not expanded
            and max_width
            and len(strip_ansi(first_line)) > max_width
            and headers
        ):
            formatted = formatter.format_output(
                cur, headers, format_name="vertical", column_types=None, **output_kwargs
            )
            if isinstance(formatted, str):
                formatted = iter(formatted.splitlines())

        output = itertools.chain(output, formatted)

    # Only print the status if it's not None and we are not producing CSV
    if status and table_format != "csv":
        output = itertools.chain(output, [format_status(cur, status)])

    return output
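

# parse_service_info resolves a libpq "service" definition. A minimal
# ~/.pg_service.conf entry looks roughly like this (standard libpq syntax;
# the keys shown are the ones connect_service() reads):
#   [my_service]
#   host=db.example.com
#   port=5432
#   dbname=appdb
#   user=alice
#   password=secret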
def parse_service_info(service):
    service = service or os.getenv("PGSERVICE")
    service_file = os.getenv("PGSERVICEFILE")
    if not service_file:
        # try ~/.pg_service.conf (if that exists)
        if platform.system() == "Windows":
            service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf"
        elif os.getenv("PGSYSCONFDIR"):
            service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
        else:
            service_file = os.path.expanduser("~/.pg_service.conf")
    if not service or not os.path.exists(service_file):
        # nothing to do
        return None, service_file
    with open(service_file, newline="") as f:
        skipped_lines = skip_initial_comment(f)
        try:
            service_file_config = ConfigObj(f)
        except ParseError as err:
            err.line_number += skipped_lines
            raise err
    if service not in service_file_config:
        return None, service_file
    service_conf = service_file_config.get(service)
    return service_conf, service_file


if __name__ == "__main__":
    cli()