1
0
Fork 0
harlequin-mysql/src/harlequin_mysql/adapter.py
Daniel Baumann 7269eaf22a
Adding upstream version 1.1.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-24 20:14:26 +01:00

428 lines
14 KiB
Python

from __future__ import annotations
import re
from contextlib import suppress
from typing import Any, Sequence
from harlequin import (
HarlequinAdapter,
HarlequinConnection,
HarlequinCursor,
)
from harlequin.autocomplete.completion import HarlequinCompletion
from harlequin.catalog import Catalog, CatalogItem
from harlequin.exception import (
HarlequinConfigError,
HarlequinConnectionError,
HarlequinQueryError,
)
from mysql.connector import FieldType
from mysql.connector.cursor import MySQLCursor
from mysql.connector.errors import InternalError, PoolError
from mysql.connector.pooling import (
MySQLConnectionPool,
PooledMySQLConnection,
)
from textual_fastdatatable.backend import AutoBackendType
from harlequin_mysql.catalog import DatabaseCatalogItem
from harlequin_mysql.cli_options import MYSQLADAPTER_OPTIONS
from harlequin_mysql.completions import load_completions
# Matches a statement that starts (after optional whitespace) with
# `use <name>`, case-insensitively. Group 1 captures 1-64 characters that are
# none of: \ / ? % * : | " < > .
USE_DATABASE_PROG = re.compile(
    r"\s*use\s+([^\\/?%*:|\"<>.]{1,64})",
    re.IGNORECASE,
)

# Compared against str(exception) to recognise a query that was interrupted
# rather than one that genuinely failed.
QUERY_INTERRUPT_MSG = "1317 (70100): Query execution was interrupted"
class HarlequinMySQLCursor(HarlequinCursor):
    """
    Cursor over the (not yet fetched) results of one executed statement.

    The cursor keeps the pooled connection it was executed on. Calling
    fetchall() always finishes by consuming any remaining results, closing
    the cursor and the connection, and removing the connection's id from the
    owning HarlequinMySQLConnection's set of in-use connections.
    """

    # Short display label for each mysql.connector FieldType id.
    # NOTE(review): some labels differ from
    # HarlequinMySQLConnection._short_column_type (e.g. TIMESTAMP is "#ts"
    # here and "ts" there) -- confirm whether that is intentional.
    _SHORT_TYPES: dict[int, str] = {
        FieldType.BIT: "010",
        FieldType.BLOB: "0b",
        FieldType.DATE: "d",
        FieldType.DATETIME: "dt",
        FieldType.DECIMAL: "#.#",
        FieldType.DOUBLE: "#.#",
        FieldType.ENUM: "enum",
        FieldType.FLOAT: "#.#",
        FieldType.GEOMETRY: "▽□",
        FieldType.INT24: "###",
        FieldType.JSON: "{}",
        FieldType.LONG: "##",
        FieldType.LONGLONG: "##",
        FieldType.LONG_BLOB: "00b",
        FieldType.MEDIUM_BLOB: "00b",
        FieldType.NEWDATE: "d",
        FieldType.NEWDECIMAL: "#.#",
        FieldType.NULL: "",
        FieldType.SET: "set",
        FieldType.SHORT: "#",
        FieldType.STRING: "s",
        FieldType.TIME: "t",
        FieldType.TIMESTAMP: "#ts",
        FieldType.TINY: "#",
        FieldType.TINY_BLOB: "b",
        FieldType.VARCHAR: "s",
        FieldType.VAR_STRING: "s",
        FieldType.YEAR: "y",
    }

    def __init__(
        self,
        cur: MySQLCursor,
        conn: PooledMySQLConnection,
        harlequin_conn: HarlequinMySQLConnection,
        *_: Any,
        **__: Any,
    ) -> None:
        """
        Args:
            cur: cursor on which the statement has already been executed; its
                description must not be None.
            conn: the pooled connection that `cur` belongs to.
            harlequin_conn: the connection object whose in-use set tracks
                `conn`'s connection id.
        """
        self.cur = cur
        # copy description in case the cursor is closed before columns() is called
        assert cur.description is not None
        self.description = cur.description.copy()
        self.conn = conn
        self.harlequin_conn = harlequin_conn
        # id of the underlying connection, read through the pooled wrapper
        self.connection_id = conn._cnx.connection_id
        self._limit: int | None = None

    def columns(self) -> list[tuple[str, str]]:
        """Return (column name, short type label) for every result column."""
        return [(col[0], self._get_short_type(col[1])) for col in self.description]

    def set_limit(self, limit: int) -> "HarlequinMySQLCursor":
        """Cap the number of rows fetchall() returns; returns self."""
        self._limit = limit
        return self

    def fetchall(self) -> AutoBackendType:
        """
        Fetch the result rows (at most the configured limit, if one is set).

        Returns:
            The fetched rows, or an empty list if the query was interrupted.

        Raises:
            HarlequinQueryError: if fetching fails for any other reason.
        """
        try:
            if self._limit is None:
                return self.cur.fetchall()
            return self.cur.fetchmany(self._limit)
        except Exception as e:
            if str(e) == QUERY_INTERRUPT_MSG:
                return []
            raise HarlequinQueryError(
                msg=str(e),
                title="Harlequin encountered an error while executing your query.",
            ) from e
        finally:
            self._release()

    def _release(self) -> None:
        """
        Consume leftover results, close cursor and connection, and untrack
        the connection id.

        The steps are chained with try/finally so that a failure in one of
        them cannot leave the pooled connection open or leave a stale id in
        the owner's in-use set.
        """
        try:
            self.conn.consume_results()
        finally:
            try:
                self.cur.close()
            finally:
                try:
                    self.conn.close()
                finally:
                    if self.connection_id:
                        self.harlequin_conn._in_use_connections.discard(
                            self.connection_id
                        )

    @staticmethod
    def _get_short_type(type_id: int) -> str:
        """Return the short label for a FieldType id, or "?" if unknown."""
        return HarlequinMySQLCursor._SHORT_TYPES.get(type_id, "?")
class HarlequinMySQLConnection(HarlequinConnection):
    """
    Harlequin connection backed by a mysql.connector connection pool.

    Every statement runs on its own pooled connection. The ids of connections
    that are currently running or holding unfetched results are tracked in
    `_in_use_connections` so that cancel() can issue KILL QUERY for them.
    """

    # Short display label for each information_schema data_type value.
    _SHORT_COLUMN_TYPES: dict[str, str] = {
        "bigint": "###",
        "binary": "010",
        "blob": "0b",
        "char": "c",
        "datetime": "dt",
        "decimal": "#.#",
        "double": "#.#",
        "enum": "enum",
        "float": "#.#",
        "int": "##",
        "json": "{}",
        "longblob": "00b",
        "longtext": "ss",
        "mediumblob": "00b",
        "mediumint": "##",
        "mediumtext": "s",
        "set": "set",
        "smallint": "#",
        "text": "s",
        "time": "t",
        "timestamp": "ts",
        "tinyint": "#",
        "varbinary": "010",
        "varchar": "s",
    }

    def __init__(
        self,
        conn_str: Sequence[str],
        *_: Any,
        init_message: str = "",
        options: dict[str, Any],
    ) -> None:
        """
        Create the connection pool.

        Args:
            conn_str: accepted for interface compatibility; not used here.
            init_message: stored on the instance as `init_message`.
            options: keyword arguments forwarded to MySQLConnectionPool.

        Raises:
            HarlequinConnectionError: if the pool cannot be created.
        """
        self.init_message = init_message
        self._in_use_connections: set[int] = set()
        try:
            self._pool: MySQLConnectionPool = MySQLConnectionPool(
                pool_name="harlequin",
                pool_reset_session=False,
                autocommit=True,
                **options,
            )
        except Exception as e:
            raise HarlequinConnectionError(
                msg=str(e), title="Harlequin could not connect to your database."
            ) from e

    def safe_get_mysql_cursor(
        self, buffered: bool = False
    ) -> tuple[PooledMySQLConnection | None, MySQLCursor | None]:
        """
        Return None if the connection pool is exhausted, to avoid getting
        in an unrecoverable state.
        """
        try:
            conn = self._pool.get_connection()
        except (InternalError, PoolError):
            # if we're out of connections, we can't raise a query error,
            # or we get in a state where we have cursors without fetched
            # results, which requires a restart of Harlequin. Instead,
            # just return None and silently fail (there isn't a sensible
            # way to show an error to the user without aborting processing
            # all the other cursors).
            return None, None

        try:
            cur: MySQLCursor = conn.cursor(buffered=buffered)
        except InternalError:
            # cursor has an unread result. Try to consume the results,
            # and try again.
            conn.consume_results()
            cur = conn.cursor(buffered=buffered)

        return conn, cur

    def set_pool_config(self, **config: Any) -> None:
        """
        Updates the config of the MySQL connection pool.
        """
        self._pool.set_config(**config)

    def _release(
        self,
        conn: PooledMySQLConnection,
        cur: MySQLCursor,
        connection_id: int | None,
    ) -> None:
        """
        Close `cur` and `conn` and stop tracking `connection_id`.

        Chained with try/finally so a failing close cannot skip the
        remaining cleanup steps.
        """
        try:
            cur.close()
        finally:
            try:
                conn.close()
            finally:
                if connection_id:
                    self._in_use_connections.discard(connection_id)

    def execute(self, query: str) -> HarlequinCursor | None:
        """
        Execute one statement on a pooled connection.

        Returns:
            A HarlequinMySQLCursor if the statement produced a result set;
            None if it produced none, if the pool was exhausted, or if the
            query was interrupted.

        Raises:
            HarlequinQueryError: if execution fails for any other reason.
        """
        conn, cur = self.safe_get_mysql_cursor()
        if conn is None or cur is None:
            return None

        connection_id = conn._cnx.connection_id
        if connection_id:
            self._in_use_connections.add(connection_id)

        try:
            cur.execute(query)
        except Exception as e:
            self._release(conn, cur, connection_id)
            if str(e) == QUERY_INTERRUPT_MSG:
                return None
            raise HarlequinQueryError(
                msg=str(e),
                title="Harlequin encountered an error while executing your query.",
            ) from e

        retval: HarlequinCursor | None = None
        if cur.description is not None:
            # the returned cursor takes over ownership of conn and cur
            retval = HarlequinMySQLCursor(cur, conn=conn, harlequin_conn=self)
        else:
            self._release(conn, cur, connection_id)

        # this is a hack to update all connections in the pool if the user
        # changes the database for the active connection.
        # it is impossible to check the database or other config
        # of a connection with an open cursor, and we can't use a dedicated
        # connection for user queries, since mysql only supports a single
        # (unfetched) cursor per connection.
        if match := USE_DATABASE_PROG.match(query):
            # the pattern's character class does not exclude whitespace or
            # ";", so "use db;\n" would otherwise yield "db;\n" as the name.
            new_db = match.group(1).strip().rstrip(";").rstrip()
            if new_db:
                self.set_pool_config(database=new_db)
        return retval

    def cancel(self) -> None:
        """
        Issue KILL QUERY for every connection currently tracked as in use,
        then reset the tracking set. Does nothing if no pooled connection is
        available to send the statements on.
        """
        # get a new cursor to execute the KILL statements
        conn, cur = self.safe_get_mysql_cursor()
        if conn is None or cur is None:
            return None

        try:
            # iterate over a snapshot: cursors discard their ids from the
            # live set when they finish fetching, and mutating a set while
            # iterating it raises RuntimeError.
            for connection_id in tuple(self._in_use_connections):
                try:
                    cur.execute("KILL QUERY %s", (connection_id,))
                except Exception:
                    # best effort: keep going with the remaining connections
                    continue
        finally:
            try:
                cur.close()
            finally:
                conn.close()
        self._in_use_connections = set()

    def close(self) -> None:
        """Remove all connections from the pool, ignoring PoolError."""
        with suppress(PoolError):
            self._pool._remove_connections()

    def get_catalog(self) -> Catalog:
        """Build a catalog with one item per database from _get_databases()."""
        databases = self._get_databases()
        db_items: list[CatalogItem] = [
            DatabaseCatalogItem.from_label(label=db, connection=self)
            for (db,) in databases
        ]
        return Catalog(items=db_items)

    def get_completions(self) -> list[HarlequinCompletion]:
        """Return the completions provided by load_completions()."""
        return load_completions()

    def _fetch_catalog_rows(
        self, query: str, params: tuple[Any, ...] = ()
    ) -> list[Any]:
        """
        Run a catalog query on a buffered cursor and return all of its rows.

        Values in `params` are bound to the `%s` placeholders in `query` by
        the driver, so names never have to be spliced into the SQL text.
        The cursor and connection are closed even if the query fails.

        Raises:
            HarlequinConnectionError: if no pooled connection is available.
        """
        conn, cur = self.safe_get_mysql_cursor(buffered=True)
        if conn is None or cur is None:
            raise HarlequinConnectionError(
                title="Connection pool exhausted",
                msg=(
                    "Connection pool exhausted. Try restarting Harlequin "
                    "with a larger pool or running fewer queries at once."
                ),
            )
        try:
            cur.execute(query, params)
            return cur.fetchall()  # type: ignore
        finally:
            try:
                cur.close()
            finally:
                conn.close()

    def _get_databases(self) -> list[tuple[str]]:
        """Return the database names, excluding the four system schemas."""
        return self._fetch_catalog_rows(
            """
            show databases
            where `Database` not in (
                'sys', 'information_schema', 'performance_schema', 'mysql'
            )
            """
        )

    def _get_relations(self, db_name: str) -> list[tuple[str, str]]:
        """Return (table_name, table_type) for `db_name`, ordered by name."""
        return self._fetch_catalog_rows(
            """
            select
                table_name,
                table_type
            from information_schema.tables
            where table_schema = %s
                and table_type != 'SYSTEM VIEW'
            order by table_name asc
            ;""",
            (db_name,),
        )

    def _get_columns(self, db_name: str, rel_name: str) -> list[tuple[str, str]]:
        """
        Return (column_name, data_type) for `db_name`.`rel_name` in ordinal
        order, skipping columns whose `extra` contains INVISIBLE.
        """
        # the LIKE pattern is bound as a parameter too, so the statement text
        # contains no literal "%" next to the %s placeholders.
        return self._fetch_catalog_rows(
            """
            select column_name, data_type
            from information_schema.columns
            where
                table_schema = %s
                and table_name = %s
                and extra not like %s
            order by ordinal_position asc
            ;""",
            (db_name, rel_name, "%INVISIBLE%"),
        )

    @staticmethod
    def _short_column_type(info_schema_type: str) -> str:
        """Return the short label for a data_type value, or "?" if unknown."""
        return HarlequinMySQLConnection._SHORT_COLUMN_TYPES.get(
            info_schema_type, "?"
        )
class HarlequinMySQLAdapter(HarlequinAdapter):
    """
    Harlequin adapter for MySQL.

    Collects the connection settings into `self.options` and hands them to
    HarlequinMySQLConnection when connect() is called.
    """

    ADAPTER_OPTIONS = MYSQLADAPTER_OPTIONS
    IMPLEMENTS_CANCEL = True

    def __init__(
        self,
        conn_str: Sequence[str],
        host: str | None = None,
        port: str | int | None = 3306,
        unix_socket: str | None = None,
        database: str | None = None,
        user: str | None = None,
        password: str | None = None,
        password2: str | None = None,
        password3: str | None = None,
        connection_timeout: str | int | None = None,
        ssl_ca: str | None = None,
        ssl_cert: str | None = None,
        ssl_disabled: str | bool | None = False,
        ssl_key: str | None = None,
        openid_token_file: str | None = None,
        pool_size: str | int | None = 5,
        **_: Any,
    ) -> None:
        """
        Validate and store the adapter configuration.

        `port`, `connection_timeout` and `pool_size` are converted with
        int(); a missing port falls back to 3306, a missing pool size to 5,
        and a missing `ssl_disabled` to False.

        Raises:
            HarlequinConnectionError: if a DSN (`conn_str`) was supplied.
            HarlequinConfigError: if one of the numeric options cannot be
                converted to int.
        """
        if conn_str:
            raise HarlequinConnectionError(
                f"Cannot provide a DSN to the MySQL adapter. Got:\n{conn_str}"
            )

        # Only the int() conversions can fail, so do them first (in the same
        # order the options are listed) and build the dict afterwards.
        try:
            port_number = 3306 if port is None else int(port)
            timeout = None if connection_timeout is None else int(connection_timeout)
            size = 5 if pool_size is None else int(pool_size)
        except (ValueError, TypeError) as e:
            raise HarlequinConfigError(
                msg=f"MySQL adapter received bad config value: {e}",
                title="Harlequin could not initialize the selected adapter.",
            ) from e

        self.options = {
            "host": host,
            "port": port_number,
            "unix_socket": unix_socket,
            "database": database,
            "user": user,
            "password": password,
            "password2": password2,
            "password3": password3,
            "connection_timeout": timeout,
            "ssl_ca": ssl_ca,
            "ssl_cert": ssl_cert,
            "ssl_disabled": False if ssl_disabled is None else ssl_disabled,
            "ssl_key": ssl_key,
            "openid_token_file": openid_token_file,
            "pool_size": size,
        }

    @property
    def connection_id(self) -> str | None:
        """
        Identifier of the form "<host><socket>:<port>/<database>".

        The host defaults to 127.0.0.1 only when neither a host nor a unix
        socket is configured.
        """
        opts = self.options
        host = opts.get("host", "") or ""
        sock = opts.get("unix_socket", "") or ""
        if not (host or sock):
            host = "127.0.0.1"
        port = opts.get("port", 3306)
        database = opts.get("database", "") or ""
        return f"{host}{sock}:{port}/{database}"

    def connect(self) -> HarlequinMySQLConnection:
        """Create a HarlequinMySQLConnection from the stored options."""
        return HarlequinMySQLConnection(conn_str=tuple(), options=self.options)