127 lines
3.3 KiB
Python
127 lines
3.3 KiB
Python
import abc
|
|
|
|
from sqlglot import exp
|
|
from sqlglot.errors import OptimizeError
|
|
from sqlglot.helper import csv_reader
|
|
|
|
|
|
class Schema(abc.ABC):
    """Interface that every database schema implementation must satisfy."""

    @abc.abstractmethod
    def column_names(self, table):
        """
        Look up the names of the columns that belong to a table.

        Args:
            table (sqlglot.expressions.Table): the table expression to resolve

        Returns:
            list[str]: the table's column names
        """
|
|
|
|
|
|
class MappingSchema(Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            An empty or missing schema (`{}` or `None`) is accepted as well;
            such a schema supports no table parts at all.

    Raises:
        OptimizeError: if the nesting depth of a non-empty mapping matches
            none of the forms above.
    """

    def __init__(self, schema):
        self.schema = schema

        depth = _dict_depth(schema)

        # `_dict_depth({})` is 1 rather than 0, so an empty mapping has to be
        # detected explicitly. `not depth` is tested first so that non-mapping
        # objects are never asked for their truth value.
        if not depth or not schema:  # None or {}
            self.supported_table_args = ()
        elif depth == 2:  # {table: {col: type}}
            self.supported_table_args = ("this",)
        elif depth == 3:  # {db: {table: {col: type}}}
            self.supported_table_args = ("db", "this")
        elif depth == 4:  # {catalog: {db: {table: {col: type}}}}
            self.supported_table_args = ("catalog", "db", "this")
        else:
            raise OptimizeError(f"Invalid schema shape. Depth: {depth}")

        # Table parts that this mapping is too shallow to resolve.
        self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)

    def column_names(self, table):
        """
        Get the column names for a table.

        Tables whose `this` is not an `exp.Identifier` are delegated to `fs_get`.

        Args:
            table (sqlglot.expressions.Table): Table expression instance

        Returns:
            list[str]: list of column names

        Raises:
            ValueError: if the table sets a part (catalog/db/this) that the
                mapping does not support, or if a part is not found in the mapping.
        """
        if not isinstance(table.this, exp.Identifier):
            return fs_get(table)

        args = tuple(table.text(p) for p in self.supported_table_args)

        # Iterate in sorted order: a set of strings has no stable iteration
        # order across interpreter runs, which would make the reported part
        # vary when a table sets several forbidden parts.
        for forbidden in sorted(self.forbidden_args):
            if table.text(forbidden):
                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
        return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
|
|
|
|
|
def ensure_schema(schema):
    """
    Coerce the argument into a `Schema` instance.

    Args:
        schema: either a `Schema` instance or a value accepted by `MappingSchema`

    Returns:
        Schema: `schema` itself when it already is a `Schema`, otherwise a new
            `MappingSchema` wrapping it
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema)
|
|
|
|
|
|
def fs_get(table):
    """
    Derive the columns of a table that is not a plain identifier.

    Only tables whose `this.name` equals READ_CSV (compared case-insensitively)
    are handled: the table is opened with `csv_reader` and the first item the
    reader yields is returned (presumably the CSV header row).

    Args:
        table: table expression whose `this` exposes a `name`

    Returns:
        the first item produced by the CSV reader

    Raises:
        ValueError: if the table is not a READ_CSV call
    """
    source_kind = table.this.name.upper()

    # Guard clause: anything but READ_CSV cannot be inspected.
    if source_kind != "READ_CSV":
        raise ValueError(f"Cannot read schema for {table}")

    with csv_reader(table) as reader:
        return next(reader)
|
|
|
|
|
|
def _nested_get(d, *path):
|
|
"""
|
|
Get a value for a nested dictionary.
|
|
|
|
Args:
|
|
d (dict): dictionary
|
|
*path (tuple[str, str]): tuples of (name, key)
|
|
`key` is the key in the dictionary to get.
|
|
`name` is a string to use in the error if `key` isn't found.
|
|
"""
|
|
for name, key in path:
|
|
d = d.get(key)
|
|
if d is None:
|
|
name = "table" if name == "this" else name
|
|
raise ValueError(f"Unknown {name}")
|
|
return d
|
|
|
|
|
|
def _dict_depth(d):
|
|
"""
|
|
Get the nesting depth of a dictionary.
|
|
|
|
For example:
|
|
>>> _dict_depth(None)
|
|
0
|
|
>>> _dict_depth({})
|
|
1
|
|
>>> _dict_depth({"a": "b"})
|
|
1
|
|
>>> _dict_depth({"a": {}})
|
|
2
|
|
>>> _dict_depth({"a": {"b": {}}})
|
|
3
|
|
|
|
Args:
|
|
d (dict): dictionary
|
|
Returns:
|
|
int: depth
|
|
"""
|
|
try:
|
|
return 1 + _dict_depth(next(iter(d.values())))
|
|
except AttributeError:
|
|
# d doesn't have attribute "values"
|
|
return 0
|
|
except StopIteration:
|
|
# d.values() returns an empty sequence
|
|
return 1
|