1
0
Fork 0
sqlglot/sqlglot/schema.py

408 lines
13 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import abc
import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
# Names needed only by static type checkers. `ColumnMapping` lives inside the guard
# because `StructType` is imported only when type checking; annotations that mention it
# rely on `from __future__ import annotations` to stay unevaluated at runtime.
if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType

    # Accepted ways of describing a table's columns (see `ensure_column_mapping`).
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table-expression argument names, ordered from the table name itself ("this") outwards.
TABLE_ARGS = ("this", "db", "catalog")

# Type of the value stored for each table in an `AbstractMappingSchema`.
T = t.TypeVar("T")
class Schema(abc.ABC):
    """Interface that every database schema implementation has to provide."""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Add a table to the schema, or update one that is already registered.

        Concrete schemas may insist on being given the column information as well.

        Args:
            table: the table, either as a `Table` expression or as its string form.
            column_mapping: description of the columns that make up the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Look up the names of the columns of a table.

        Args:
            table: the table, either as a `Table` expression or as its string form.
            only_visible: when true, only visible columns are requested; when false,
                invisible columns are included as well.

        Returns:
            The column names as a list.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Look up the :class:`sqlglot.exp.DataType` of a single column.

        Args:
            table: the table the column belongs to.
            column: the column whose type is wanted.

        Returns:
            The data type of the column.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments understood by this schema, e.g. `("this", "db", "catalog")`.

        The base implementation always raises `NotImplementedError`.
        """
        raise NotImplementedError
class AbstractMappingSchema(t.Generic[T]):
    """
    Base class for schemas stored as a nested mapping.

    Alongside the mapping itself, a trie of the table paths is kept. Each path is stored
    reversed (table name first), matching the order produced by `table_parts`.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        """
        Args:
            mapping: the nested mapping to wrap; a falsy value is replaced by an empty dict.
        """
        self.mapping = mapping or {}
        self.mapping_trie = self._build_trie(self.mapping)
        # Filled in lazily by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _build_trie(self, schema: t.Dict) -> t.Dict:
        """Build the lookup trie from every key path of `schema`, each path reversed."""
        key_paths = flatten_schema(schema, depth=self._depth())
        return new_trie(tuple(reversed(key_path)) for key_path in key_paths)

    def _depth(self) -> int:
        """Return the nesting depth of the mapping, as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The leading entries of `TABLE_ARGS` that the mapping's depth can address.

        The value is computed on first access with a non-empty mapping and then cached.

        Raises:
            SchemaError: if the depth is non-zero and outside the range 1..3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the non-empty name parts of `table`, in `TABLE_ARGS` order (table name first).

        A table whose `this` is an `exp.ReadCSV` yields just that expression's name.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the mapping entry for `table`.

        Args:
            table: the table to look for.
            trie: trie to search instead of `self.mapping_trie`.
            raise_on_missing: whether an ambiguous or failed lookup raises instead of
                returning None.

        Returns:
            The stored entry, or None when nothing matches (or the lookup is ambiguous and
            `raise_on_missing` is false).

        Raises:
            SchemaError: if the table parts match several entries and `raise_on_missing`
                is true.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        # `in_trie` status: 0 is treated as "no match"; 1 presumably means that `parts` is
        # only a prefix of stored paths, in which case the returned sub-trie holds the
        # possible completions.
        if value == 0:
            return None
        elif value == 1:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
            if len(possibilities) == 1:
                # A unique completion: fill in the parts the caller left out.
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(candidate) for candidate in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
        return self._nested_get(parts, raise_on_missing=raise_on_missing)

    def _nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Walk a nested mapping along `parts` and return what is found there.

        Args:
            parts: table parts, table name first (the order produced by `table_parts`).
            d: mapping to walk; a falsy value (None or an empty dict) falls back to
                `self.mapping`.
            raise_on_missing: whether a missing key raises instead of returning None.

        Returns:
            The value found, or None if a key is missing and `raise_on_missing` is false.
        """
        # The mapping is nested outermost-first, so the parts are walked in reverse. The
        # labels used in error messages are reversed as well so that each key is reported
        # under the level it was looked up at.
        return _nested_get(
            d or self.mapping,
            *zip(reversed(self.supported_table_args), reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: t.Optional[str] = None,
    ) -> None:
        # `dialect` has to be set before `_normalize` runs, because name normalization
        # parses identifiers with it.
        self.dialect = dialect
        self.visible = visible or {}
        # Parsed data types, keyed by the type string they were built from.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from this one.

        Args:
            **kwargs: constructor arguments (`schema`, `visible`, `dialect`) that override
                the values taken from this instance.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # One key path per table: everything above the column level.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # Each key serves as both the error label and the lookup key.
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._ensure_table(table)
        column_mapping = ensure_column_mapping(column_mapping)
        schema = self.find(table_, raise_on_missing=False)

        # The table is already known and no new columns were given: nothing to update.
        if schema and not column_mapping:
            return

        _nested_set(
            self.mapping,
            list(reversed(self.table_parts(table_))),
            column_mapping,
        )
        # Keep the lookup trie in sync with the mutated mapping.
        self.mapping_trie = self._build_trie(self.mapping)

    def _normalize_name(self, name: str) -> str:
        """
        Normalize an identifier: quoted identifiers keep their case, others are lowercased.

        Args:
            name: the raw identifier text.

        Returns:
            The normalized identifier name.
        """
        try:
            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
                name, read=self.dialect, into=exp.Identifier
            )
        except Exception:
            # Best effort: if the name cannot be parsed as an identifier, wrap it as one.
            # Only `Exception` is caught so that KeyboardInterrupt and SystemExit still
            # propagate.
            identifier = exp.to_identifier(name)
        assert isinstance(identifier, exp.Identifier)

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Convert `table` into a `Table` expression.

        Raises:
            SchemaError: if the conversion yields a falsy result.
        """
        table_ = exp.to_table(table)

        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to the columns listed in
                `self.visible`. Ignored when no visibility mapping was provided.

        Returns:
            The list of column names; empty if the table is not found.
        """
        table_ = self._ensure_table(table)
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a plain name.

        Returns:
            The column type, or an UNKNOWN data type when the table has no (non-empty)
            entry in the schema.

        Raises:
            SchemaError: if the table cannot be converted, or if the value stored for the
                column is neither a `DataType` nor a string (this includes a column that
                is missing from the table's entry).
        """
        column_name = column if isinstance(column, str) else column.name
        table_ = exp.to_table(table)
        if table_:
            table_schema = self.find(table_, raise_on_missing=False)
            if table_schema:
                column_type = table_schema.get(column_name)

                if isinstance(column_type, exp.DataType):
                    return column_type
                elif isinstance(column_type, str):
                    # Type strings are upper-cased before parsing, which also makes the
                    # type cache case-insensitive.
                    return self._to_data_type(column_type.upper())
                raise SchemaError(f"Unknown column type '{column_type}'")
            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
        raise SchemaError(f"Could not convert table '{table}'")

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are cached per type string.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if parsing fails with an `AttributeError`.
            ValueError: if the parser returns None. This is not converted to `SchemaError`,
                because only `AttributeError` is caught here.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError:
                raise SchemaError(f"Failed to convert type {schema_type}")

        return self._type_mapping_cache[schema_type]
def ensure_schema(schema: t.Any) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    An object that already is a `Schema` is returned unchanged; anything else is handed
    to the `MappingSchema` constructor.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
    """
    Convert any supported column description into a `{column_name: column_type}` dict.

    Args:
        mapping: one of
            - a dict, returned as is;
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: string"`;
            - an object with a `simpleString` attribute (treated as a DataFrame
              `StructType`), iterated for fields exposing `name` and `dataType`;
            - a list of column names, each mapped to a `None` type;
            - None, giving an empty dict.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if `mapping` has an unsupported type, or if an entry of a string
            mapping has no `:` separator.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # A blank description means "no columns", just like None.
        if not mapping.strip():
            return {}

        columns = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so that types which themselves contain a
            # colon (e.g. `struct<a:int>`) are kept whole.
            name, separator, type_ = name_type_str.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column definition '{name_type_str.strip()}', expected 'name: type'"
                )
            columns[name.strip()] = type_.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of a nested dictionary that are exactly `depth` levels long.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each path consists of; values below 1 give no paths.
        keys: keys already collected above `schema`, prepended to every path.

    Returns:
        One list of keys per path, in dictionary iteration order.
    """
    prefix = keys or []
    paths = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth >= 2:
            # Not deep enough yet: descend one more level.
            paths.extend(flatten_schema(child, depth - 1, path))
        elif depth == 1:
            paths.append(path)

    return paths
def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
"""
Get a value for a nested dictionary.
Args:
d: the dictionary to search.
*path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
"""
for name, key in path:
d = d.get(key) # type: ignore
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}: {key}")
return None
return d
def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
Example:
>>> _nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Args:
d: dictionary to update.
keys: the keys that makeup the path to `value`.
value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.
"""
if not keys:
return d
if len(keys) == 1:
d[keys[0]] = value
return d
subd = d
for key in keys[:-1]:
if key not in subd:
subd = subd.setdefault(key, {})
else:
subd = subd[key]
subd[keys[-1]] = value
return d