2025-02-24 20:26:40 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2025-03-04 08:18:53 +01:00
|
|
|
from contextlib import suppress
|
|
|
|
from typing import TYPE_CHECKING, Any, Sequence
|
2025-02-24 20:26:40 +01:00
|
|
|
|
2025-03-04 08:18:53 +01:00
|
|
|
import pyodbc
|
2025-02-24 20:26:40 +01:00
|
|
|
from harlequin import (
|
|
|
|
HarlequinAdapter,
|
|
|
|
HarlequinConnection,
|
|
|
|
HarlequinCursor,
|
|
|
|
)
|
|
|
|
from harlequin.autocomplete.completion import HarlequinCompletion
|
|
|
|
from harlequin.catalog import Catalog, CatalogItem
|
|
|
|
from harlequin.exception import (
|
|
|
|
HarlequinConfigError,
|
|
|
|
HarlequinConnectionError,
|
|
|
|
HarlequinQueryError,
|
|
|
|
)
|
|
|
|
from textual_fastdatatable.backend import AutoBackendType
|
|
|
|
|
2025-03-04 08:18:53 +01:00
|
|
|
from harlequin_odbc.catalog import (
|
|
|
|
DatabaseCatalogItem,
|
|
|
|
RelationCatalogItem,
|
|
|
|
SchemaCatalogItem,
|
|
|
|
)
|
2025-02-24 20:26:40 +01:00
|
|
|
from harlequin_odbc.cli_options import ODBC_OPTIONS
|
|
|
|
|
2025-03-04 08:18:53 +01:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
pass
|
|
|
|
|
2025-02-24 20:26:40 +01:00
|
|
|
|
|
|
|
class HarlequinOdbcCursor(HarlequinCursor):
    """Harlequin cursor wrapping a pyodbc cursor that holds a result set."""

    # Short type labels for Harlequin's results header, keyed by the
    # ``__name__`` of the Python type pyodbc puts in ``cursor.description``.
    # Defined once on the class instead of being rebuilt on every call.
    # todo: use getTypeInfo
    _TYPE_LABELS = {
        "bool": "t/f",
        "int": "##",
        "float": "#.#",
        "Decimal": "#.#",
        "str": "s",
        "bytes": "0b",
        # pyodbc's cursor.description reports binary SQL types
        # (BINARY/VARBINARY/LONGVARBINARY) as ``bytearray`` even though the
        # fetched values are ``bytes``; without this entry they showed "?".
        "bytearray": "0b",
        "date": "d",
        "time": "t",
        "datetime": "dt",
        "UUID": "uid",
    }

    def __init__(self, cur: pyodbc.Cursor) -> None:
        """Wrap ``cur``; no row limit applies until ``set_limit`` is called."""
        self.cur = cur
        self._limit: int | None = None

    def columns(self) -> list[tuple[str, str]]:
        """Return ``(name, type_label)`` for each column of the result set.

        Columns whose name is empty or None are labelled "(No column name)";
        Python types missing from ``_TYPE_LABELS`` are labelled "?".
        """
        return [
            (
                col_name if col_name else "(No column name)",
                self._TYPE_LABELS.get(col_type.__name__, "?"),
            )
            for col_name, col_type, *_ in self.cur.description
        ]

    def set_limit(self, limit: int) -> HarlequinOdbcCursor:
        """Cap the rows returned by ``fetchall`` at ``limit``; return self."""
        self._limit = limit
        return self

    def fetchall(self) -> AutoBackendType:
        """Fetch all remaining rows, or at most ``limit`` rows if one was set.

        Raises:
            HarlequinQueryError: the driver raised while fetching.
        """
        try:
            if self._limit is None:
                return self.cur.fetchall()
            return self.cur.fetchmany(self._limit)
        except Exception as e:
            # Same message format as HarlequinOdbcConnection.execute, so the
            # driver's exception class is shown wherever the error surfaces.
            raise HarlequinQueryError(
                msg=f"{e.__class__.__name__}: {e}",
                title="Harlequin encountered an error while executing your query.",
            ) from e
|
|
|
|
|
|
|
|
|
|
|
|
class HarlequinOdbcConnection(HarlequinConnection):
    """Harlequin connection backed by pyodbc.

    Two autocommit connections are opened from the same ODBC connection
    string. ``conn`` executes user queries. ``aux_conn`` serves the catalog
    helpers (``_list_tables`` / ``_list_columns_in_relation``), presumably
    so that metadata lookups never touch a cursor the user is still
    reading from (confirm intent).
    """

    def __init__(
        self,
        conn_str: Sequence[str],
        init_message: str = "",
    ) -> None:
        """Open the query connection and the auxiliary catalog connection.

        Args:
            conn_str: Sequence holding exactly one ODBC connection string.
            init_message: Stored as ``self.init_message`` for Harlequin
                (presumably shown after connecting; confirm upstream).

        Raises:
            HarlequinConfigError: ``conn_str`` does not hold exactly one string.
            HarlequinConnectionError: either pyodbc connection fails to open.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently use only the first string.
        if len(conn_str) != 1:
            raise HarlequinConfigError(
                title="Harlequin could not initialize the ODBC adapter.",
                msg=(
                    "The ODBC adapter expects exactly one connection string. "
                    f"It received {len(conn_str)}."
                ),
            )
        self.init_message = init_message
        try:
            self.conn = pyodbc.connect(conn_str[0], autocommit=True)
            try:
                self.aux_conn = pyodbc.connect(conn_str[0], autocommit=True)
            except Exception:
                # Don't leak the already-open query connection when the
                # auxiliary one cannot be established.
                with suppress(Exception):
                    self.conn.close()
                raise
        except Exception as e:
            raise HarlequinConnectionError(
                msg=str(e), title="Harlequin could not connect to your database."
            ) from e

    def execute(self, query: str) -> HarlequinOdbcCursor | None:
        """Run ``query`` on the main connection.

        Returns:
            A ``HarlequinOdbcCursor`` when the statement produced a result set
            (``cursor.description`` is not None), otherwise ``None``.

        Raises:
            HarlequinQueryError: creating the cursor or executing the query
                failed.
        """
        cur = None
        try:
            cur = self.conn.cursor()
            cur.execute(query)
        except Exception as e:
            # Release the cursor of a failed statement instead of leaking it.
            if cur is not None:
                with suppress(Exception):
                    cur.close()
            raise HarlequinQueryError(
                msg=f"{e.__class__.__name__}: {e}",
                title="Harlequin encountered an error while executing your query.",
            ) from e
        if cur.description is None:
            # No result set: nothing will ever be fetched, so close it now.
            with suppress(Exception):
                cur.close()
            return None
        return HarlequinOdbcCursor(cur)

    def get_catalog(self) -> Catalog:
        """Build Harlequin's database -> schema -> relation catalog tree."""
        raw_catalog = self._list_tables()
        db_items: list[CatalogItem] = []
        for db, schemas in raw_catalog.items():
            schema_items: list[CatalogItem] = []
            for schema, relations in schemas.items():
                rel_items: list[CatalogItem] = []
                for rel, rel_type in relations:
                    rel_items.append(
                        RelationCatalogItem.from_label(
                            label=rel,
                            schema_label=schema,
                            db_label=db,
                            rel_type=rel_type,
                            connection=self,
                        )
                    )
                schema_items.append(
                    SchemaCatalogItem.from_label(
                        label=schema,
                        db_label=db,
                        connection=self,
                        children=rel_items,
                    )
                )
            db_items.append(
                DatabaseCatalogItem.from_label(
                    label=db,
                    connection=self,
                    children=schema_items,
                )
            )
        return Catalog(items=db_items)

    def close(self) -> None:
        """Close both connections, ignoring errors (best-effort shutdown)."""
        with suppress(Exception):
            self.conn.close()
        with suppress(Exception):
            self.aux_conn.close()

    def _list_tables(self) -> dict[str, dict[str, list[tuple[str, str]]]]:
        """Return ``{catalog: {schema: [(relation, relation_type), ...]}}``.

        Rows come from ``cursor.tables(catalog="%")`` on ``aux_conn``; only the
        first four columns of each row are used. Insertion order follows the
        driver's row order.
        """
        cur = self.aux_conn.cursor()
        catalog: dict[str, dict[str, list[tuple[str, str]]]] = {}
        try:
            for db_name, schema_name, rel_name, rel_type, *_ in cur.tables(
                catalog="%"
            ):
                catalog.setdefault(db_name, {}).setdefault(schema_name, []).append(
                    (rel_name, rel_type)
                )
        finally:
            # Always release the metadata cursor, even if iteration fails.
            with suppress(Exception):
                cur.close()
        return catalog

    def _list_columns_in_relation(
        self, catalog_name: str, schema_name: str, rel_name: str
    ) -> list[tuple[str, str]]:
        """Return ``(column_name, type_name)`` pairs for one relation."""
        cur = self.aux_conn.cursor()
        try:
            raw_cols = cur.columns(
                table=rel_name, catalog=catalog_name, schema=schema_name
            )
            # ODBC SQLColumns layout: index 3 is COLUMN_NAME, 5 is TYPE_NAME.
            # The list is fully materialized before the cursor is closed.
            return [(col[3], col[5]) for col in raw_cols]
        finally:
            with suppress(Exception):
                cur.close()

    def get_completions(self) -> list[HarlequinCompletion]:
        """Return adapter-specific completions (none are provided yet)."""
        return []
|
|
|
|
|
|
|
|
|
|
|
|
class HarlequinOdbcAdapter(HarlequinAdapter):
    """Harlequin adapter entry point for ODBC data sources."""

    ADAPTER_OPTIONS = ODBC_OPTIONS

    def __init__(self, conn_str: Sequence[str], **_: Any) -> None:
        """Remember the connection string(s) and validate their count.

        Any extra keyword options are accepted and ignored.

        Raises:
            HarlequinConfigError: ``conn_str`` does not hold exactly one string.
        """
        self.conn_str = conn_str
        received = len(conn_str)
        if received == 1:
            return
        raise HarlequinConfigError(
            title="Harlequin could not initialize the ODBC adapter.",
            msg=(
                "The ODBC adapter expects exactly one connection string. "
                f"It received:\n{conn_str}"
            ),
        )

    def connect(self) -> HarlequinOdbcConnection:
        """Open and return a new connection for the stored connection string."""
        return HarlequinOdbcConnection(self.conn_str)
|