428 lines
14 KiB
Python
428 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from contextlib import suppress
|
|
from typing import Any, Sequence
|
|
|
|
from harlequin import (
|
|
HarlequinAdapter,
|
|
HarlequinConnection,
|
|
HarlequinCursor,
|
|
)
|
|
from harlequin.autocomplete.completion import HarlequinCompletion
|
|
from harlequin.catalog import Catalog, CatalogItem
|
|
from harlequin.exception import (
|
|
HarlequinConfigError,
|
|
HarlequinConnectionError,
|
|
HarlequinQueryError,
|
|
)
|
|
from mysql.connector import FieldType
|
|
from mysql.connector.cursor import MySQLCursor
|
|
from mysql.connector.errors import InternalError, PoolError
|
|
from mysql.connector.pooling import (
|
|
MySQLConnectionPool,
|
|
PooledMySQLConnection,
|
|
)
|
|
from textual_fastdatatable.backend import AutoBackendType
|
|
|
|
from harlequin_mysql.catalog import DatabaseCatalogItem
|
|
from harlequin_mysql.cli_options import MYSQLADAPTER_OPTIONS
|
|
from harlequin_mysql.completions import load_completions
|
|
|
|
# Matches a `USE <database>` statement at the start of a query. Group 1 is the
# database name only: surrounding backticks are excluded, and an unquoted name
# stops at the first semicolon or whitespace. This keeps `use mydb;` or
# ``use `mydb` `` from being stored verbatim in the pool config. A name that
# opens a backtick but never closes it does not match at all.
USE_DATABASE_PROG = re.compile(
    r"\s*use\s+`?("
    r"(?<=`)[^`\\/?%*:|\"<>.]{1,64}(?=`)"  # backtick-quoted name (may contain spaces)
    r"|(?<!`)[^`\\/?%*:|\"<>.;\s]{1,64}"  # bare name
    r")",
    flags=re.IGNORECASE,
)

# str() of the error raised when a running statement is stopped by KILL QUERY
# (MySQL error 1317); such errors are treated as a cancellation, not a failure.
QUERY_INTERRUPT_MSG = "1317 (70100): Query execution was interrupted"
|
|
|
|
|
|
class HarlequinMySQLCursor(HarlequinCursor):
    """Harlequin cursor wrapping an executed MySQL cursor and its pooled connection.

    fetchall() reads the rows, then closes the cursor and returns the
    connection to the pool, so it can only usefully be called once.
    """

    # Short labels for result-column types, keyed by mysql-connector FieldType
    # codes. Integer widths agree with HarlequinMySQLConnection's catalog labels
    # (mediumint/int -> "##", bigint -> "###", timestamp -> "ts").
    _SHORT_TYPES = {
        FieldType.BIT: "010",
        FieldType.BLOB: "0b",
        FieldType.DATE: "d",
        FieldType.DATETIME: "dt",
        FieldType.DECIMAL: "#.#",
        FieldType.DOUBLE: "#.#",
        FieldType.ENUM: "enum",
        FieldType.FLOAT: "#.#",
        FieldType.GEOMETRY: "▽□",
        FieldType.INT24: "##",
        FieldType.JSON: "{}",
        FieldType.LONG: "##",
        FieldType.LONGLONG: "###",
        FieldType.LONG_BLOB: "00b",
        FieldType.MEDIUM_BLOB: "00b",
        FieldType.NEWDATE: "d",
        FieldType.NEWDECIMAL: "#.#",
        FieldType.NULL: "∅",
        FieldType.SET: "set",
        FieldType.SHORT: "#",
        FieldType.STRING: "s",
        FieldType.TIME: "t",
        FieldType.TIMESTAMP: "ts",
        FieldType.TINY: "#",
        FieldType.TINY_BLOB: "b",
        FieldType.VARCHAR: "s",
        FieldType.VAR_STRING: "s",
        FieldType.YEAR: "y",
    }

    def __init__(
        self,
        cur: MySQLCursor,
        conn: PooledMySQLConnection,
        harlequin_conn: HarlequinMySQLConnection,
        *_: Any,
        **__: Any,
    ) -> None:
        """Wrap ``cur`` (already executed on ``conn``).

        Args:
            cur: cursor whose statement produced a result set.
            conn: pooled connection that owns ``cur``; returned to the pool
                by fetchall().
            harlequin_conn: owning connection, whose in-use set is updated
                when this cursor is released.
        """
        self.cur = cur

        # copy description in case the cursor is closed before columns() is called
        assert cur.description is not None
        self.description = cur.description.copy()

        self.conn = conn
        self.harlequin_conn = harlequin_conn
        # MySQL connection id of ``conn``; removed from the owner's in-use set
        # once this cursor has been fetched.
        self.connection_id = conn._cnx.connection_id
        # Maximum number of rows fetchall() returns; None means no limit.
        self._limit: int | None = None

    def columns(self) -> list[tuple[str, str]]:
        """Return (column name, short type label) pairs for the result set."""
        return [
            (name, self._get_short_type(type_code))
            for name, type_code, *_ in self.description
        ]

    def set_limit(self, limit: int) -> "HarlequinMySQLCursor":
        """Cap the number of rows fetchall() will return; returns self for chaining."""
        self._limit = limit
        return self

    def fetchall(self) -> AutoBackendType:
        """Fetch all rows (or at most the configured limit) and release resources.

        Returns an empty list if the query was interrupted (e.g. by a KILL
        QUERY issued on cancel).

        Raises:
            HarlequinQueryError: for any other error raised while fetching.
        """
        try:
            if self._limit is None:
                return self.cur.fetchall()
            return self.cur.fetchmany(self._limit)
        except Exception as e:
            if str(e) == QUERY_INTERRUPT_MSG:
                return []
            raise HarlequinQueryError(
                msg=str(e),
                title="Harlequin encountered an error while executing your query.",
            ) from e
        finally:
            self._release()

    def _release(self) -> None:
        """Drain and close the cursor and return the connection to the pool.

        Every step is attempted even if an earlier one fails. Otherwise a
        single error would leak a pooled connection and leave its id marked
        in use. The first error encountered is re-raised afterwards.
        """
        first_error: Exception | None = None
        # consume_results discards rows left unread (e.g. beyond the limit);
        # a connection with an unread result cannot open a new cursor.
        for step in (self.conn.consume_results, self.cur.close, self.conn.close):
            try:
                step()
            except Exception as e:
                if first_error is None:
                    first_error = e
        if self.connection_id:
            self.harlequin_conn._in_use_connections.discard(self.connection_id)
        if first_error is not None:
            raise first_error

    @staticmethod
    def _get_short_type(type_id: int) -> str:
        """Return the short label for a FieldType code, or "?" if unknown."""
        return HarlequinMySQLCursor._SHORT_TYPES.get(type_id, "?")
|
|
|
|
|
|
class HarlequinMySQLConnection(HarlequinConnection):
    """Harlequin connection backed by a ``MySQLConnectionPool``.

    MySQL only supports a single (unfetched) cursor per connection. Every user
    query and every catalog lookup therefore checks out its own pooled
    connection and returns it once its results have been read.
    """

    # Short labels for information_schema DATA_TYPE values, used by
    # _short_column_type(); unknown types map to "?".
    _SHORT_COLUMN_TYPES = {
        "bigint": "###",
        "binary": "010",
        "bit": "010",
        "blob": "0b",
        "char": "c",
        "date": "d",
        "datetime": "dt",
        "decimal": "#.#",
        "double": "#.#",
        "enum": "enum",
        "float": "#.#",
        "geomcollection": "▽□",
        "geometry": "▽□",
        "geometrycollection": "▽□",
        "int": "##",
        "json": "{}",
        "linestring": "▽□",
        "longblob": "00b",
        "longtext": "ss",
        "mediumblob": "00b",
        "mediumint": "##",
        "mediumtext": "s",
        "multilinestring": "▽□",
        "multipoint": "▽□",
        "multipolygon": "▽□",
        "point": "▽□",
        "polygon": "▽□",
        "set": "set",
        "smallint": "#",
        "text": "s",
        "time": "t",
        "timestamp": "ts",
        "tinyblob": "b",
        "tinyint": "#",
        "tinytext": "s",
        "varbinary": "010",
        "varchar": "s",
        "year": "y",
    }

    def __init__(
        self,
        conn_str: Sequence[str],
        *_: Any,
        init_message: str = "",
        options: dict[str, Any],
    ) -> None:
        """Create the connection pool from ``options``.

        Raises:
            HarlequinConnectionError: if the pool cannot be created.
        """
        self.init_message = init_message
        # MySQL connection ids of pooled connections currently running a user
        # query; cancel() issues KILL QUERY for each of them.
        self._in_use_connections: set[int] = set()
        try:
            self._pool: MySQLConnectionPool = MySQLConnectionPool(
                pool_name="harlequin",
                pool_reset_session=False,
                autocommit=True,
                **options,
            )
        except Exception as e:
            raise HarlequinConnectionError(
                msg=str(e), title="Harlequin could not connect to your database."
            ) from e

    def safe_get_mysql_cursor(
        self, buffered: bool = False
    ) -> tuple[PooledMySQLConnection | None, MySQLCursor | None]:
        """
        Check a connection out of the pool and open a cursor on it.

        Return (None, None) if the connection pool is exhausted, to avoid
        getting in an unrecoverable state.
        """
        try:
            conn = self._pool.get_connection()
        except (InternalError, PoolError):
            # if we're out of connections, we can't raise a query error,
            # or we get in a state where we have cursors without fetched
            # results, which requires a restart of Harlequin. Instead,
            # just return None and silently fail (there isn't a sensible
            # way to show an error to the user without aborting processing
            # all the other cursors).
            return None, None

        try:
            cur: MySQLCursor = conn.cursor(buffered=buffered)
        except InternalError:
            # The connection has an unread result. Consume it and try again;
            # if that also fails, hand the connection back to the pool
            # instead of leaking it.
            try:
                conn.consume_results()
                cur = conn.cursor(buffered=buffered)
            except Exception:
                with suppress(Exception):
                    conn.close()
                raise

        return conn, cur

    def set_pool_config(self, **config: Any) -> None:
        """
        Updates the config of the MySQL connection pool.
        """
        self._pool.set_config(**config)

    def execute(self, query: str) -> HarlequinCursor | None:
        """Run ``query`` on a pooled connection.

        Returns a HarlequinMySQLCursor when the statement produced a result set
        (its connection stays checked out until the cursor is fetched). Returns
        None when the statement produced no result set, was interrupted, or no
        pooled connection was available.

        Raises:
            HarlequinQueryError: if the statement fails for any other reason.
        """
        conn, cur = self.safe_get_mysql_cursor()
        if conn is None or cur is None:
            return None

        connection_id = conn._cnx.connection_id
        if connection_id:
            self._in_use_connections.add(connection_id)

        try:
            cur.execute(query)
        except Exception as e:
            self._release(conn, cur, connection_id)
            if str(e) == QUERY_INTERRUPT_MSG:
                return None
            raise HarlequinQueryError(
                msg=str(e),
                title="Harlequin encountered an error while executing your query.",
            ) from e

        retval: HarlequinCursor | None = None
        if cur.description is not None:
            retval = HarlequinMySQLCursor(cur, conn=conn, harlequin_conn=self)
        else:
            self._release(conn, cur, connection_id)

        # this is a hack to update all connections in the pool if the user
        # changes the database for the active connection.
        # it is impossible to check the database or other config
        # of a connection with an open cursor, and we can't use a dedicated
        # connection for user queries, since mysql only supports a single
        # (unfetched) cursor per connection.
        if match := USE_DATABASE_PROG.match(query):
            self.set_pool_config(database=match.group(1))
        return retval

    def cancel(self) -> None:
        """Best-effort KILL QUERY for every connection running a user query."""
        # get a new cursor to execute the KILL statements
        conn, cur = self.safe_get_mysql_cursor()
        if conn is None or cur is None:
            return None

        try:
            # Iterate over a snapshot: the queries being killed run
            # concurrently, and their error handling discards their ids from
            # this same set while we loop.
            for connection_id in tuple(self._in_use_connections):
                try:
                    cur.execute("KILL QUERY %s", (connection_id,))
                except Exception:
                    # The query may already have finished; keep killing the rest.
                    continue
        finally:
            self._release(conn, cur, None)
            self._in_use_connections = set()

    def close(self) -> None:
        """Close all pooled connections, ignoring pool errors."""
        with suppress(PoolError):
            self._pool._remove_connections()

    def get_catalog(self) -> Catalog:
        """Build the catalog tree from the non-system databases."""
        databases = self._get_databases()
        db_items: list[CatalogItem] = [
            DatabaseCatalogItem.from_label(label=db, connection=self)
            for (db,) in databases
        ]
        return Catalog(items=db_items)

    def get_completions(self) -> list[HarlequinCompletion]:
        """Return the static MySQL keyword/function completions."""
        return load_completions()

    def _get_databases(self) -> list[tuple[str]]:
        """Return one-tuples of database names, excluding MySQL system schemas."""
        return self._fetch_catalog_rows(
            """
            show databases
            where `Database` not in (
                'sys', 'information_schema', 'performance_schema', 'mysql'
            )
            """
        )

    def _get_relations(self, db_name: str) -> list[tuple[str, str]]:
        """Return (table_name, table_type) rows for ``db_name``, sorted by name."""
        return self._fetch_catalog_rows(
            """
            select
                table_name,
                table_type
            from information_schema.tables
            where table_schema = %s
                and table_type != 'SYSTEM VIEW'
            order by table_name asc
            ;""",
            (db_name,),
        )

    def _get_columns(self, db_name: str, rel_name: str) -> list[tuple[str, str]]:
        """Return visible (column_name, data_type) rows for a relation, in order."""
        return self._fetch_catalog_rows(
            """
            select column_name, data_type
            from information_schema.columns
            where
                table_schema = %s
                and table_name = %s
                and extra not like %s
            order by ordinal_position asc
            ;""",
            (db_name, rel_name, "%INVISIBLE%"),
        )

    def _fetch_catalog_rows(
        self, query: str, params: tuple[Any, ...] | None = None
    ) -> list[Any]:
        """Run a metadata query on a buffered pooled cursor and return all rows.

        Identifiers are passed as bound ``params`` rather than formatted into
        the SQL, so names containing quotes cannot break the query. The cursor
        and connection are released even if the query fails.

        Raises:
            HarlequinConnectionError: if no pooled connection is available.
        """
        conn, cur = self.safe_get_mysql_cursor(buffered=True)
        if conn is None or cur is None:
            raise HarlequinConnectionError(
                title="Connection pool exhausted",
                msg=(
                    "Connection pool exhausted. Try restarting Harlequin "
                    "with a larger pool or running fewer queries at once."
                ),
            )
        try:
            cur.execute(query, params)
            return cur.fetchall()  # type: ignore
        finally:
            self._release(conn, cur, None)

    def _release(
        self,
        conn: PooledMySQLConnection,
        cur: MySQLCursor,
        connection_id: int | None,
    ) -> None:
        """Close ``cur``, return ``conn`` to the pool, and untrack ``connection_id``.

        Every step is attempted even if an earlier one fails, so a single
        error cannot leak a pooled connection. The first error is re-raised.
        """
        first_error: Exception | None = None
        for step in (cur.close, conn.close):
            try:
                step()
            except Exception as e:
                if first_error is None:
                    first_error = e
        if connection_id:
            self._in_use_connections.discard(connection_id)
        if first_error is not None:
            raise first_error

    @staticmethod
    def _short_column_type(info_schema_type: str) -> str:
        """Return the short label for an information_schema DATA_TYPE, or "?"."""
        return HarlequinMySQLConnection._SHORT_COLUMN_TYPES.get(info_schema_type, "?")
|
|
|
|
|
|
class HarlequinMySQLAdapter(HarlequinAdapter):
    """Harlequin adapter that normalizes user-supplied options for the MySQL pool."""

    ADAPTER_OPTIONS = MYSQLADAPTER_OPTIONS
    IMPLEMENTS_CANCEL = True

    # Accepted spellings for boolean options that arrive as strings.
    _TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
    _FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off"})

    def __init__(
        self,
        conn_str: Sequence[str],
        host: str | None = None,
        port: str | int | None = 3306,
        unix_socket: str | None = None,
        database: str | None = None,
        user: str | None = None,
        password: str | None = None,
        password2: str | None = None,
        password3: str | None = None,
        connection_timeout: str | int | None = None,
        ssl_ca: str | None = None,
        ssl_cert: str | None = None,
        ssl_disabled: str | bool | None = False,
        ssl_key: str | None = None,
        openid_token_file: str | None = None,
        pool_size: str | int | None = 5,
        **_: Any,
    ) -> None:
        """Validate options and store them in ``self.options``.

        Raises:
            HarlequinConnectionError: if a connection string (DSN) is given.
            HarlequinConfigError: if a numeric or boolean option cannot be parsed.
        """
        if conn_str:
            raise HarlequinConnectionError(
                f"Cannot provide a DSN to the MySQL adapter. Got:\n{conn_str}"
            )
        try:
            self.options = {
                "host": host,
                "port": int(port) if port is not None else 3306,
                "unix_socket": unix_socket,
                "database": database,
                "user": user,
                "password": password,
                "password2": password2,
                "password3": password3,
                "connection_timeout": int(connection_timeout)
                if connection_timeout is not None
                else None,
                "ssl_ca": ssl_ca,
                "ssl_cert": ssl_cert,
                # Parsed explicitly: passing the raw string "false" through
                # would be truthy and silently disable SSL.
                "ssl_disabled": self._parse_flag("ssl_disabled", ssl_disabled),
                "ssl_key": ssl_key,
                "openid_token_file": openid_token_file,
                "pool_size": int(pool_size) if pool_size is not None else 5,
            }
        except (ValueError, TypeError) as e:
            raise HarlequinConfigError(
                msg=f"MySQL adapter received bad config value: {e}",
                title="Harlequin could not initialize the selected adapter.",
            ) from e

    @classmethod
    def _parse_flag(cls, name: str, value: str | bool | None) -> bool:
        """Interpret a flag that may be None, a bool, or a string.

        None means False. Strings are matched case-insensitively against
        common true/false spellings.

        Raises:
            ValueError: for a string that is not a recognized boolean.
        """
        if value is None:
            return False
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in cls._TRUE_STRINGS:
                return True
            if normalized in cls._FALSE_STRINGS:
                return False
            raise ValueError(f"{name} must be a boolean, got {value!r}")
        return bool(value)

    @property
    def connection_id(self) -> str | None:
        """Identifier of the form ``{host}{socket}:{port}/{database}``.

        The host defaults to 127.0.0.1 when neither a host nor a socket is set.
        """
        host = self.options.get("host", "") or ""
        sock = self.options.get("unix_socket", "") or ""
        host = host if host or sock else "127.0.0.1"

        port = self.options.get("port", 3306)
        database = self.options.get("database", "") or ""

        return f"{host}{sock}:{port}/{database}"

    def connect(self) -> HarlequinMySQLConnection:
        """Create a pooled HarlequinMySQLConnection from the stored options."""
        conn = HarlequinMySQLConnection(conn_str=tuple(), options=self.options)
        return conn
|