1
0
Fork 0

Merging upstream version 3.5.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 20:01:39 +01:00
parent 7a56138e00
commit 6bbbbdf0c7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
43 changed files with 1272 additions and 430 deletions

View file

@ -1,155 +1,45 @@
import logging
import select
import traceback
from collections import namedtuple
import pgspecial as special
import psycopg2
import psycopg2.errorcodes
import psycopg2.extensions as ext
import psycopg2.extras
import psycopg
import psycopg.sql
from psycopg.conninfo import make_conninfo
import sqlparse
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
# Cast all database input to unicode automatically.
# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
# pg3: These should be automatic: unicode is the default
ext.register_type(ext.UNICODE)
ext.register_type(ext.UNICODEARRAY)
ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE))
# See https://github.com/dbcli/pgcli/issues/426 for more details.
# This registers a unicode type caster for datatype 'RECORD'.
ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))
# Cast bytea fields to text. By default, this will render as hex strings with
# Postgres 9+ and as escaped binary in earlier versions.
ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
_wait_callback_is_set = False
ViewDef = namedtuple(
"ViewDef", "nspname relname relkind viewdef reloptions checkoption"
)
# pg3: it is already "green" but Ctrl-C breaks the query
# pg3: This should be fixed upstream: https://github.com/psycopg/psycopg/issues/231
def _wait_select(conn):
"""
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
try:
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except OSError as e:
errno = e.args[0]
if errno != 4:
raise
except psycopg2.OperationalError:
pass
def _set_wait_callback(is_virtual_database):
global _wait_callback_is_set
if _wait_callback_is_set:
return
_wait_callback_is_set = True
if is_virtual_database:
return
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
# pg3: You can do something like:
# pg3: cnn.adapters.register_loader("date", psycopg.types.string.TextLoader)
def register_date_typecasters(connection):
"""
Casts date and timestamp values to string, resolves issues with out of
range dates (e.g. BC) which psycopg2 can't handle
"""
def cast_date(value, cursor):
return value
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
if cursor.description is None:
return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp with time zone")
timestamptz_oid = cursor.description[0][1]
oids = (date_oid, timestamp_oid, timestamptz_oid)
new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
psycopg2.extensions.register_type(new_type)
def register_json_typecasters(conn, loads_fn):
"""Set the function for converting JSON data for a connection.
Use the supplied function to decode JSON data returned from the database
via the given connection. The function should accept a single argument of
the data as a string encoded in the database's character encoding.
psycopg2's default handler for JSON data is json.loads.
http://initd.org/psycopg/docs/extras.html#json-adaptation
This function attempts to register the typecaster for both JSON and JSONB
types.
Returns a set that is a subset of {'json', 'jsonb'} indicating which types
(if any) were successfully registered.
"""
available = set()
for name in ["json", "jsonb"]:
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
# pg3: Probably you don't need this because by default unknown -> unicode
def register_hstore_typecaster(conn):
"""
Instead of using register_hstore() which converts hstore into a python
dict, we query the 'oid' of hstore which will be different for each
database and register a type caster that converts it to unicode.
http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore
"""
with conn.cursor() as cur:
try:
cur.execute(
"select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
)
oid = cur.fetchone()[0]
ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
except Exception:
pass
def register_typecasters(connection):
"""Casts date and timestamp values to string, resolves issues with out-of-range
dates (e.g. BC) which psycopg can't handle"""
for forced_text_type in [
"date",
"time",
"timestamp",
"timestamptz",
"bytea",
"json",
"jsonb",
]:
connection.adapters.register_loader(
forced_text_type, psycopg.types.string.TextLoader
)
# pg3: I don't know what is this
class ProtocolSafeCursor(psycopg2.extensions.cursor):
class ProtocolSafeCursor(psycopg.Cursor):
"""This class wraps and suppresses Protocol Errors with pgbouncer database.
See https://github.com/dbcli/pgcli/pull/1097.
Pgbouncer database is a virtual database with its own set of commands."""
def __init__(self, *args, **kwargs):
self.protocol_error = False
self.protocol_message = ""
@ -170,14 +60,18 @@ class ProtocolSafeCursor(psycopg2.extensions.cursor):
return (self.protocol_message,)
return super().fetchone()
def execute(self, sql, args=None):
# def mogrify(self, query, params):
# args = [Literal(v).as_string(self.connection) for v in params]
# return query % tuple(args)
#
def execute(self, *args, **kwargs):
try:
psycopg2.extensions.cursor.execute(self, sql, args)
super().execute(*args, **kwargs)
self.protocol_error = False
self.protocol_message = ""
except psycopg2.errors.ProtocolViolation as ex:
except psycopg.errors.ProtocolViolation as ex:
self.protocol_error = True
self.protocol_message = ex.pgerror
self.protocol_message = str(ex)
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
@ -290,7 +184,7 @@ class PGExecute:
conn_params = self._conn_params.copy()
new_params = {
"database": database,
"dbname": database,
"user": user,
"password": password,
"host": host,
@ -303,15 +197,15 @@ class PGExecute:
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
if new_params["password"]:
new_params["dsn"] = make_dsn(
new_params["dsn"] = make_conninfo(
new_params["dsn"], password=new_params.pop("password")
)
conn_params.update({k: v for k, v in new_params.items() if v})
conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
conn.set_client_encoding("utf8")
conn_info = make_conninfo(**conn_params)
conn = psycopg.connect(conn_info)
conn.cursor_factory = ProtocolSafeCursor
self._conn_params = conn_params
if self.conn:
@ -322,19 +216,7 @@ class PGExecute:
# When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664.
libpq_version = psycopg2.__libpq_version__
dsn_parameters = {}
if libpq_version >= 93000:
# use actual connection info from psycopg2.extensions.Connection.info
# as libpq_version > 9.3 is available and required dependency
dsn_parameters = conn.info.dsn_parameters
else:
try:
dsn_parameters = conn.get_dsn_parameters()
except Exception as x:
# https://github.com/dbcli/pgcli/issues/1110
# PQconninfo not available in libpq < 9.3
_logger.info("Exception in get_dsn_parameters: %r", x)
dsn_parameters = conn.info.get_parameters()
if dsn_parameters:
self.dbname = dsn_parameters.get("dbname")
@ -357,16 +239,14 @@ class PGExecute:
else self.get_socket_directory()
)
self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
self.server_version = conn.get_parameter_status("server_version") or ""
self.pid = conn.info.backend_pid
self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1")
self.server_version = conn.info.parameter_status("server_version") or ""
_set_wait_callback(self.is_virtual_database())
# _set_wait_callback(self.is_virtual_database())
if not self.is_virtual_database():
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
register_typecasters(conn)
@property
def short_host(self):
@ -387,31 +267,23 @@ class PGExecute:
cur.execute(sql)
return cur.fetchone()
def _json_typecaster(self, json_data):
"""Interpret incoming JSON data as a string.
The raw data is decoded using the connection's encoding, which defaults
to the database's encoding.
See http://initd.org/psycopg/docs/connection.html#connection.encoding
"""
return json_data
def failed_transaction(self):
# pg3: self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
status = self.conn.get_transaction_status()
return status == ext.TRANSACTION_STATUS_INERROR
return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
def valid_transaction(self):
status = self.conn.get_transaction_status()
status = self.conn.info.transaction_status
return (
status == ext.TRANSACTION_STATUS_ACTIVE
or status == ext.TRANSACTION_STATUS_INTRANS
status == psycopg.pq.TransactionStatus.ACTIVE
or status == psycopg.pq.TransactionStatus.INTRANS
)
def run(
self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
self,
statement,
pgspecial=None,
exception_formatter=None,
on_error_resume=False,
explain_mode=False,
):
"""Execute the sql in the database and return the results.
@ -432,17 +304,38 @@ class PGExecute:
# Remove spaces and EOL
statement = statement.strip()
if not statement: # Empty string
yield (None, None, None, None, statement, False, False)
yield None, None, None, None, statement, False, False
# Split the sql into separate queries and run each one.
for sql in sqlparse.split(statement):
# sql parse doesn't split on a comment first + special
# so we're going to do it
sqltemp = []
sqlarr = []
if statement.startswith("--"):
sqltemp = statement.split("\n")
sqlarr.append(sqltemp[0])
for i in sqlparse.split(sqltemp[1]):
sqlarr.append(i)
elif statement.startswith("/*"):
sqltemp = statement.split("*/")
sqltemp[0] = sqltemp[0] + "*/"
for i in sqlparse.split(sqltemp[1]):
sqlarr.append(i)
else:
sqlarr = sqlparse.split(statement)
# run each sql query
for sql in sqlarr:
# Remove spaces, eol and semi-colons.
sql = sql.rstrip(";")
sql = sqlparse.format(sql, strip_comments=False).strip()
if not sql:
continue
try:
if pgspecial:
if explain_mode:
sql = self.explain_prefix() + sql
elif pgspecial:
# \G is treated specially since we have to set the expanded output.
if sql.endswith("\\G"):
if not pgspecial.expanded_output:
@ -454,7 +347,7 @@ class PGExecute:
_logger.debug("Trying a pgspecial command. sql: %r", sql)
try:
cur = self.conn.cursor()
except psycopg2.InterfaceError:
except psycopg.InterfaceError:
# edge case when connection is already closed, but we
# don't need cursor for special_cmd.arg_type == NO_QUERY.
# See https://github.com/dbcli/pgcli/issues/1014.
@ -478,7 +371,7 @@ class PGExecute:
# Not a special command, so execute as normal sql
yield self.execute_normal_sql(sql) + (sql, True, False)
except psycopg2.DatabaseError as e:
except psycopg.DatabaseError as e:
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
@ -498,7 +391,7 @@ class PGExecute:
"""Return true if e is an error that should not be caught in ``run``.
An uncaught error will prompt the user to reconnect; as long as we
detect that the connection is stil open, we catch the error, as
detect that the connection is still open, we catch the error, as
reconnecting won't solve that problem.
:param e: DatabaseError. An exception raised while executing a query.
@ -511,14 +404,24 @@ class PGExecute:
def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
_logger.debug("Regular sql statement. sql: %r", split_sql)
title = ""
def handle_notices(n):
nonlocal title
title = f"{n.message_primary}\n{n.message_detail}\n{title}"
self.conn.add_notice_handler(handle_notices)
if self.is_virtual_database() and "show help" in split_sql.lower():
# see https://github.com/psycopg/psycopg/issues/303
# special case "show help" in pgbouncer
res = self.conn.pgconn.exec_(split_sql.encode())
return title, None, None, res.command_status.decode()
cur = self.conn.cursor()
cur.execute(split_sql)
# conn.notices persist between queies, we use pop to clear out the list
title = ""
while len(self.conn.notices) > 0:
title = self.conn.notices.pop() + title
# cur.description will be None for operations that do not return
# rows.
if cur.description:
@ -539,7 +442,7 @@ class PGExecute:
_logger.debug("Search path query. sql: %r", self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
except psycopg2.ProgrammingError:
except psycopg.ProgrammingError:
fallback = "SELECT * FROM current_schemas(true)"
with self.conn.cursor() as cur:
_logger.debug("Search path query. sql: %r", fallback)
@ -549,9 +452,6 @@ class PGExecute:
def view_definition(self, spec):
"""Returns the SQL defining views described by `spec`"""
# pg3: you may want to use `psycopg.sql` for client-side composition
# pg3: (also available in psycopg2 by the way)
template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
# 2: relkind, v or m (materialized)
# 4: reloptions, null
# 5: checkoption: local or cascaded
@ -560,11 +460,21 @@ class PGExecute:
_logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
except psycopg.ProgrammingError:
raise RuntimeError(f"View {spec} does not exist.")
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
result = ViewDef(*cur.fetchone())
if result.relkind == "m":
template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}"
else:
template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}"
return (
psycopg.sql.SQL(template)
.format(
name=psycopg.sql.Identifier(f"{result.nspname}.{result.relname}"),
stmt=psycopg.sql.SQL(result.viewdef),
)
.as_string(self.conn)
)
def function_definition(self, spec):
"""Returns the SQL defining functions described by `spec`"""
@ -576,7 +486,7 @@ class PGExecute:
cur.execute(sql, (spec,))
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
except psycopg.ProgrammingError:
raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
@ -600,9 +510,9 @@ class PGExecute:
"""
with self.conn.cursor() as cur:
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
# sql = cur.mogrify(self.tables_query, kinds)
# _logger.debug("Tables Query. sql: %r", sql)
cur.execute(self.tables_query, [kinds])
yield from cur
def tables(self):
@ -628,7 +538,7 @@ class PGExecute:
:return: list of (schema_name, relation_name, column_name, column_type) tuples
"""
if self.conn.server_version >= 80400:
if self.conn.info.server_version >= 80400:
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
@ -669,9 +579,9 @@ class PGExecute:
ORDER BY 1, 2, att.attnum"""
with self.conn.cursor() as cur:
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
# sql = cur.mogrify(columns_query, kinds)
# _logger.debug("Columns Query. sql: %r", sql)
cur.execute(columns_query, [kinds])
yield from cur
def table_columns(self):
@ -712,7 +622,7 @@ class PGExecute:
def foreignkeys(self):
"""Yields ForeignKey named tuples"""
if self.conn.server_version < 90000:
if self.conn.info.server_version < 90000:
return
with self.conn.cursor() as cur:
@ -752,7 +662,7 @@ class PGExecute:
def functions(self):
"""Yields FunctionMetadata named tuples"""
if self.conn.server_version >= 110000:
if self.conn.info.server_version >= 110000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@ -772,7 +682,7 @@ class PGExecute:
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
elif self.conn.server_version > 90000:
elif self.conn.info.server_version > 90000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@ -792,7 +702,7 @@ class PGExecute:
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
elif self.conn.server_version >= 80400:
elif self.conn.info.server_version >= 80400:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@ -843,7 +753,7 @@ class PGExecute:
"""Yields tuples of (schema_name, type_name)"""
with self.conn.cursor() as cur:
if self.conn.server_version > 90000:
if self.conn.info.server_version > 90000:
query = """
SELECT n.nspname schema_name,
t.typname type_name
@ -931,3 +841,6 @@ class PGExecute:
cur.execute(query)
for row in cur:
yield row[0]
def explain_prefix(self):
return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) "