1
0
Fork 0

Merging upstream version 3.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-09 19:58:08 +01:00
parent a868bb3d29
commit 39b7cc8559
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
50 changed files with 952 additions and 634 deletions

View file

@ -1,13 +1,15 @@
import traceback
import logging
import select
import traceback
import pgspecial as special
import psycopg2
import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
import psycopg2.extras
import sqlparse
import pgspecial as special
import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
_wait_callback_is_set = False
def _wait_select(conn):
@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except select.error as e:
errno = e.args[0]
if errno != 4:
raise
try:
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except OSError as e:
errno = e.args[0]
if errno != 4:
raise
except psycopg2.OperationalError:
pass
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def _set_wait_callback(is_virtual_database):
global _wait_callback_is_set
if _wait_callback_is_set:
return
_wait_callback_is_set = True
if is_virtual_database:
return
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
if cursor.description is None:
return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
except psycopg2.ProgrammingError:
except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@ -127,7 +142,39 @@ def register_hstore_typecaster(conn):
pass
class PGExecute(object):
class ProtocolSafeCursor(psycopg2.extensions.cursor):
def __init__(self, *args, **kwargs):
self.protocol_error = False
self.protocol_message = ""
super().__init__(*args, **kwargs)
def __iter__(self):
if self.protocol_error:
raise StopIteration
return super().__iter__()
def fetchall(self):
if self.protocol_error:
return [(self.protocol_message,)]
return super().fetchall()
def fetchone(self):
if self.protocol_error:
return (self.protocol_message,)
return super().fetchone()
def execute(self, sql, args=None):
try:
psycopg2.extensions.cursor.execute(self, sql, args)
self.protocol_error = False
self.protocol_message = ""
except psycopg2.errors.ProtocolViolation as ex:
self.protocol_error = True
self.protocol_message = ex.pgerror
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
class PGExecute:
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
@ -190,8 +237,6 @@ class PGExecute(object):
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
version_query = "SELECT version();"
def __init__(
self,
database=None,
@ -203,6 +248,7 @@ class PGExecute(object):
**kwargs,
):
self._conn_params = {}
self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@ -214,6 +260,11 @@ class PGExecute(object):
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
def is_virtual_database(self):
if self._is_virtual_database is None:
self._is_virtual_database = self.is_protocol_error()
return self._is_virtual_database
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@ -250,9 +301,9 @@ class PGExecute(object):
)
conn_params.update({k: v for k, v in new_params.items() if v})
conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@ -293,16 +344,22 @@ class PGExecute(object):
self.extra_args = kwargs
if not self.host:
self.host = self.get_socket_directory()
self.host = (
"pgbouncer"
if self.is_virtual_database()
else self.get_socket_directory()
)
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
self.pid = pid
self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
self.server_version = conn.get_parameter_status("server_version")
self.server_version = conn.get_parameter_status("server_version") or ""
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
_set_wait_callback(self.is_virtual_database())
if not self.is_virtual_database():
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
@property
def short_host(self):
@ -395,7 +452,13 @@ class PGExecute(object):
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
for result in pgspecial.execute(cur, sql):
response = pgspecial.execute(cur, sql)
if cur and cur.protocol_error:
yield None, None, None, cur.protocol_message, statement, False, False
# this would close connection. We should reconnect.
self.connect()
continue
for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@ -453,6 +516,9 @@ class PGExecute(object):
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
elif cur.protocol_error:
_logger.debug("Protocol error, unsupported command.")
return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@ -485,7 +551,7 @@ class PGExecute(object):
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
raise RuntimeError("View {} does not exist.".format(spec))
raise RuntimeError(f"View {spec} does not exist.")
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
@ -501,7 +567,7 @@ class PGExecute(object):
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
raise RuntimeError("Function {} does not exist.".format(spec))
raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
"""Returns a list of schema names in the database"""
@ -527,21 +593,18 @@ class PGExecute(object):
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
yield from cur
def tables(self):
"""Yields (schema_name, table_name) tuples"""
for row in self._relations(kinds=["r", "p", "f"]):
yield row
yield from self._relations(kinds=["r", "p", "f"])
def views(self):
"""Yields (schema_name, view_name) tuples.
Includes both views and and materialized views
"""
for row in self._relations(kinds=["v", "m"]):
yield row
yield from self._relations(kinds=["v", "m"])
def _columns(self, kinds=("r", "p", "f", "v", "m")):
"""Get column metadata for tables and views
@ -599,16 +662,13 @@ class PGExecute(object):
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
yield from cur
def table_columns(self):
for row in self._columns(kinds=["r", "p", "f"]):
yield row
yield from self._columns(kinds=["r", "p", "f"])
def view_columns(self):
for row in self._columns(kinds=["v", "m"]):
yield row
yield from self._columns(kinds=["v", "m"])
def databases(self):
with self.conn.cursor() as cur:
@ -623,6 +683,13 @@ class PGExecute(object):
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
def is_protocol_error(self):
query = "SELECT 1"
with self.conn.cursor() as cur:
_logger.debug("Simple Query. sql: %r", query)
cur.execute(query)
return bool(cur.protocol_error)
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
@ -804,8 +871,7 @@ class PGExecute(object):
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row
yield from cur
def casing(self):
"""Yields the most common casing for names used in db functions"""