Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
 18
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.List[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The list of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether or not `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @abc.abstractmethod
110    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
111        """
112        Returns the schema of a given table.
113
114        Args:
115            table: the target table.
116            raise_on_missing: whether or not to raise in case the schema is not found.
117
118        Returns:
119            The schema of the target table.
120        """
121
122    @property
123    @abc.abstractmethod
124    def supported_table_args(self) -> t.Tuple[str, ...]:
125        """
126        Table arguments this schema support, e.g. `("this", "db", "catalog")`
127        """
128
129    @property
130    def empty(self) -> bool:
131        """Returns whether or not the schema is empty."""
132        return True
133
134
class AbstractMappingSchema:
    """Stores a nested mapping of table paths and supports trie-backed lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}

        # Index the mapping by reversed table paths so lookups can start from
        # the table name and walk outwards to db / catalog.
        flattened = flatten_schema(self.mapping, depth=self.depth())
        self.mapping_trie = new_trie(tuple(reversed(path)) for path in flattened)

        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping contains no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table parts this mapping can resolve, computed lazily from its depth."""
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self.depth()

        if depth > 3:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        if depth:
            self._supported_table_args = exp.TABLE_PARTS[:depth]
        else:
            self._supported_table_args = tuple()

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Extract the non-empty name parts (this/db/catalog) of `table`."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        return [text for part in exp.TABLE_PARTS if (text := table.text(part))]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """Look up the mapping entry for `table`, resolving unambiguous partial paths."""
        lookup = self.table_parts(table)[0 : len(self.supported_table_args)]
        result, subtrie = in_trie(self.mapping_trie, lookup)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # Partial match: acceptable only if exactly one completion exists.
            possibilities = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)

            if len(possibilities) == 1:
                lookup.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(lookup, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Walk `parts` (table-first order) through `d`, defaulting to `self.mapping`."""
        target = d or self.mapping
        return nested_get(
            target,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
200
201
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema backed by a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.normalize = normalize
        self.visible = visible if visible is not None else {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached depth (0 means "not computed yet"); see `depth`.
        self._depth = 0

        raw_schema = schema if schema is not None else {}
        super().__init__(self._normalize(raw_schema) if self.normalize else raw_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` mirroring an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, optionally overriding constructor arguments."""
        options: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        options.update(kwargs)
        return MappingSchema(**options)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(norm_table.parts) != self.depth():
            raise SchemaError(
                f"Table {norm_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        norm_columns = {
            self._normalize_name(name, dialect=dialect, normalize=normalize): col_type
            for name, col_type in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is only updated when new columns are given.
        if self.find(norm_table, raise_on_missing=False) and not norm_columns:
            return

        parts = self.table_parts(norm_table)

        nested_set(self.mapping, tuple(reversed(parts)), norm_columns)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally filtered to visible columns."""
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        table_schema = self.find(norm_table)
        if table_schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(norm_table), self.visible) or []
            return [col for col in table_schema if col in visible]

        return list(table_schema)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or an UNKNOWN type if unresolved."""
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        norm_column = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(norm_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(norm_column)

            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Check whether `column` exists in the registered schema of `table`."""
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        norm_column = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(norm_table, raise_on_missing=False)
        if not table_schema:
            return False
        return norm_column in table_schema

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalize every identifier in `schema`.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized: t.Dict = {}
        flattened = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened[0])}."
                )

            table_path = [self._normalize_name(part, is_table=True) for part in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized,
                    table_path + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        norm_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                if isinstance(part := norm_table.args.get(arg), exp.Identifier):
                    norm_table.set(
                        arg,
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return norm_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its plain-string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Schema depth, excluding the innermost column mapping level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                built = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

            self._type_mapping_cache[schema_type] = built

        return self._type_mapping_cache[schema_type]
453
454
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect` when requested."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # normalize_identifier consults this flag; BigQuery has special rules
    # pertaining to table identifiers.
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
470
471
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` unchanged if it's already a `Schema`, else wrap it in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
477
478
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce a column mapping given as None, dict, "col: type" string, StructType or list into a dict."""
    if mapping is None:
        return {}

    if isinstance(mapping, dict):
        return mapping

    if isinstance(mapping, str):
        # "name: type, name: type, ..." — split on commas, then on the first colon.
        columns: t.Dict = {}
        for entry in mapping.split(","):
            pieces = entry.strip().split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns

    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}

    if isinstance(mapping, list):
        return {entry.strip(): None for entry in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
497
498
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Collect the key paths of `schema` down to (but excluding) the `depth`-th level."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for name, subtree in schema.items():
        if depth >= 2:
            paths.extend(flatten_schema(subtree, depth - 1, prefix + [name]))
        elif depth == 1:
            paths.append(prefix + [name])

    return paths
512
513
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d

    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is an internal arg name; report it to users as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
538
539
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        # Nothing to set; return the dictionary untouched.
        return d

    node = d
    for key in keys[:-1]:
        # setdefault either descends into the existing entry or creates a new dict.
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
class Schema(abc.ABC):
 20class Schema(abc.ABC):
 21    """Abstract base class for database schemas"""
 22
 23    dialect: DialectType
 24
 25    @abc.abstractmethod
 26    def add_table(
 27        self,
 28        table: exp.Table | str,
 29        column_mapping: t.Optional[ColumnMapping] = None,
 30        dialect: DialectType = None,
 31        normalize: t.Optional[bool] = None,
 32        match_depth: bool = True,
 33    ) -> None:
 34        """
 35        Register or update a table. Some implementing classes may require column information to also be provided.
 36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42            normalize: whether to normalize identifiers according to the dialect of interest.
 43            match_depth: whether to enforce that the table must match the schema's depth or not.
 44        """
 45
 46    @abc.abstractmethod
 47    def column_names(
 48        self,
 49        table: exp.Table | str,
 50        only_visible: bool = False,
 51        dialect: DialectType = None,
 52        normalize: t.Optional[bool] = None,
 53    ) -> t.List[str]:
 54        """
 55        Get the column names for a table.
 56
 57        Args:
 58            table: the `Table` expression instance.
 59            only_visible: whether to include invisible columns.
 60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 61            normalize: whether to normalize identifiers according to the dialect of interest.
 62
 63        Returns:
 64            The list of column names.
 65        """
 66
 67    @abc.abstractmethod
 68    def get_column_type(
 69        self,
 70        table: exp.Table | str,
 71        column: exp.Column | str,
 72        dialect: DialectType = None,
 73        normalize: t.Optional[bool] = None,
 74    ) -> exp.DataType:
 75        """
 76        Get the `sqlglot.exp.DataType` type of a column in the schema.
 77
 78        Args:
 79            table: the source table.
 80            column: the target column.
 81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 82            normalize: whether to normalize identifiers according to the dialect of interest.
 83
 84        Returns:
 85            The resulting column type.
 86        """
 87
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether or not `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)
109
110    @abc.abstractmethod
111    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
112        """
113        Returns the schema of a given table.
114
115        Args:
116            table: the target table.
117            raise_on_missing: whether or not to raise in case the schema is not found.
118
119        Returns:
120            The schema of the target table.
121        """
122
123    @property
124    @abc.abstractmethod
125    def supported_table_args(self) -> t.Tuple[str, ...]:
126        """
127        Table arguments this schema support, e.g. `("this", "db", "catalog")`
128        """
129
130    @property
131    def empty(self) -> bool:
132        """Returns whether or not the schema is empty."""
133        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
25    @abc.abstractmethod
26    def add_table(
27        self,
28        table: exp.Table | str,
29        column_mapping: t.Optional[ColumnMapping] = None,
30        dialect: DialectType = None,
31        normalize: t.Optional[bool] = None,
32        match_depth: bool = True,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
37
38        Args:
39            table: the `Table` expression instance or string representing the table.
40            column_mapping: a column mapping that describes the structure of the table.
41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
42            normalize: whether to normalize identifiers according to the dialect of interest.
43            match_depth: whether to enforce that the table must match the schema's depth or not.
44        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
46    @abc.abstractmethod
47    def column_names(
48        self,
49        table: exp.Table | str,
50        only_visible: bool = False,
51        dialect: DialectType = None,
52        normalize: t.Optional[bool] = None,
53    ) -> t.List[str]:
54        """
55        Get the column names for a table.
56
57        Args:
58            table: the `Table` expression instance.
59            only_visible: whether to include invisible columns.
60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
61            normalize: whether to normalize identifiers according to the dialect of interest.
62
63        Returns:
64            The list of column names.
65        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
67    @abc.abstractmethod
68    def get_column_type(
69        self,
70        table: exp.Table | str,
71        column: exp.Column | str,
72        dialect: DialectType = None,
73        normalize: t.Optional[bool] = None,
74    ) -> exp.DataType:
75        """
76        Get the `sqlglot.exp.DataType` type of a column in the schema.
77
78        Args:
79            table: the source table.
80            column: the target column.
81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
82            normalize: whether to normalize identifiers according to the dialect of interest.
83
84        Returns:
85            The resulting column type.
86        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether or not `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

@abc.abstractmethod
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True) -> Optional[Any]:
@abc.abstractmethod
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
    """
    Look up the schema entry for a given table.

    Args:
        table: the target table.
        raise_on_missing: whether or not to raise in case the schema is not found.

    Returns:
        The schema of the target table.
    """

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether or not to raise in case the schema is not found.
Returns:

The schema of the target table.

supported_table_args: Tuple[str, ...]
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
    """
    The table arguments this schema supports, e.g. `("this", "db", "catalog")`.
    """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
@property
def empty(self) -> bool:
    """Whether or not the schema is empty (always True for the base class)."""
    return True

Returns whether or not the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """Base class that stores a nested mapping plus a trie for fast table lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie keys are table paths stored innermost-first, hence the reversal.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping contains any entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table parts (e.g. "this", "db", "catalog") this mapping can resolve."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty name parts of `table`, innermost part first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Resolve `table` against the stored mapping.

        A trie lookup disambiguates partially-qualified tables; a unique prefix match
        is completed automatically, while an ambiguous one raises (or returns None).
        """
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        lookup, trie = in_trie(self.mapping_trie, parts)

        if lookup == TrieResult.FAILED:
            return None

        if lookup == TrieResult.PREFIX:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    message = ", ".join(".".join(candidate) for candidate in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value stored under `parts` (innermost first) from `d` or the mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
137    def __init__(
138        self,
139        mapping: t.Optional[t.Dict] = None,
140    ) -> None:
141        self.mapping = mapping or {}
142        self.mapping_trie = new_trie(
143            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
144        )
145        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
147    @property
148    def empty(self) -> bool:
149        return not self.mapping
def depth(self) -> int:
151    def depth(self) -> int:
152        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
154    @property
155    def supported_table_args(self) -> t.Tuple[str, ...]:
156        if not self._supported_table_args and self.mapping:
157            depth = self.depth()
158
159            if not depth:  # None
160                self._supported_table_args = tuple()
161            elif 1 <= depth <= 3:
162                self._supported_table_args = exp.TABLE_PARTS[:depth]
163            else:
164                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
165
166        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
168    def table_parts(self, table: exp.Table) -> t.List[str]:
169        if isinstance(table.this, exp.ReadCSV):
170            return [table.this.name]
171        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True) -> Optional[Any]:
173    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
174        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
175        value, trie = in_trie(self.mapping_trie, parts)
176
177        if value == TrieResult.FAILED:
178            return None
179
180        if value == TrieResult.PREFIX:
181            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
182
183            if len(possibilities) == 1:
184                parts.extend(possibilities[0])
185            else:
186                message = ", ".join(".".join(parts) for parts in possibilities)
187                if raise_on_missing:
188                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
189                return None
190
191        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
193    def nested_get(
194        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
195    ) -> t.Optional[t.Any]:
196        return nested_get(
197            d or self.mapping,
198            *zip(self.supported_table_args, reversed(parts)),
199            raise_on_missing=raise_on_missing,
200        )
class MappingSchema(AbstractMappingSchema, Schema):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided,
            all columns are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0

        raw_schema = schema if schema is not None else {}
        super().__init__(self._normalize(raw_schema) if self.normalize else raw_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` mirroring an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, optionally overriding constructor arguments."""
        init_args = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the
        schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Nothing to do if the table is already known and no new columns were supplied.
        existing = self.find(normalized_table, raise_on_missing=False)
        if existing and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table, optionally filtered to visible columns only.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names, or an empty list if the table is unknown.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        table_schema = self.find(normalized_table)
        if table_schema is None:
            return []

        if not only_visible or not self.visible:
            return list(table_schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [column for column in table_schema if column in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `exp.DataType` of `column` in `table`, falling back to the UNKNOWN type.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Check whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalize all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table's entry is not a column mapping (wrong nesting level).
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for key_path in flattened_schema:
            columns = nested_get(schema, *zip(key_path, key_path))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(key_path[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in key_path]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its plain-text name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        # Subtract one level because the innermost mapping holds columns, not tables.
        if not self.empty and not self._depth:
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType`
        object, caching results per type string.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a `DataType`.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                self._type_mapping_cache[schema_type] = exp.DataType.build(
                    schema_type, dialect=dialect, udt=udt
                )
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
222    def __init__(
223        self,
224        schema: t.Optional[t.Dict] = None,
225        visible: t.Optional[t.Dict] = None,
226        dialect: DialectType = None,
227        normalize: bool = True,
228    ) -> None:
229        self.dialect = dialect
230        self.visible = {} if visible is None else visible
231        self.normalize = normalize
232        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
233        self._depth = 0
234        schema = {} if schema is None else schema
235
236        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
238    @classmethod
239    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
240        return MappingSchema(
241            schema=mapping_schema.mapping,
242            visible=mapping_schema.visible,
243            dialect=mapping_schema.dialect,
244            normalize=mapping_schema.normalize,
245        )
def copy(self, **kwargs) -> MappingSchema:
247    def copy(self, **kwargs) -> MappingSchema:
248        return MappingSchema(
249            **{  # type: ignore
250                "schema": self.mapping.copy(),
251                "visible": self.visible.copy(),
252                "dialect": self.dialect,
253                "normalize": self.normalize,
254                **kwargs,
255            }
256        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
258    def add_table(
259        self,
260        table: exp.Table | str,
261        column_mapping: t.Optional[ColumnMapping] = None,
262        dialect: DialectType = None,
263        normalize: t.Optional[bool] = None,
264        match_depth: bool = True,
265    ) -> None:
266        """
267        Register or update a table. Updates are only performed if a new column mapping is provided.
268        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
269
270        Args:
271            table: the `Table` expression instance or string representing the table.
272            column_mapping: a column mapping that describes the structure of the table.
273            dialect: the SQL dialect that will be used to parse `table` if it's a string.
274            normalize: whether to normalize identifiers according to the dialect of interest.
275            match_depth: whether to enforce that the table must match the schema's depth or not.
276        """
277        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
278
279        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
280            raise SchemaError(
281                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
282                f"schema's nesting level: {self.depth()}."
283            )
284
285        normalized_column_mapping = {
286            self._normalize_name(key, dialect=dialect, normalize=normalize): value
287            for key, value in ensure_column_mapping(column_mapping).items()
288        }
289
290        schema = self.find(normalized_table, raise_on_missing=False)
291        if schema and not normalized_column_mapping:
292            return
293
294        parts = self.table_parts(normalized_table)
295
296        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
297        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
299    def column_names(
300        self,
301        table: exp.Table | str,
302        only_visible: bool = False,
303        dialect: DialectType = None,
304        normalize: t.Optional[bool] = None,
305    ) -> t.List[str]:
306        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
307
308        schema = self.find(normalized_table)
309        if schema is None:
310            return []
311
312        if not only_visible or not self.visible:
313            return list(schema)
314
315        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
316        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
318    def get_column_type(
319        self,
320        table: exp.Table | str,
321        column: exp.Column | str,
322        dialect: DialectType = None,
323        normalize: t.Optional[bool] = None,
324    ) -> exp.DataType:
325        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
326
327        normalized_column_name = self._normalize_name(
328            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
329        )
330
331        table_schema = self.find(normalized_table, raise_on_missing=False)
332        if table_schema:
333            column_type = table_schema.get(normalized_column_name)
334
335            if isinstance(column_type, exp.DataType):
336                return column_type
337            elif isinstance(column_type, str):
338                return self._to_data_type(column_type, dialect=dialect)
339
340        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
342    def has_column(
343        self,
344        table: exp.Table | str,
345        column: exp.Column | str,
346        dialect: DialectType = None,
347        normalize: t.Optional[bool] = None,
348    ) -> bool:
349        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
350
351        normalized_column_name = self._normalize_name(
352            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
353        )
354
355        table_schema = self.find(normalized_table, raise_on_missing=False)
356        return normalized_column_name in table_schema if table_schema else False

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
425    def depth(self) -> int:
426        if not self.empty and not self._depth:
427            # The columns themselves are a mapping, but we don't want to include those
428            self._depth = super().depth() - 1
429        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect`'s rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if normalize:
        # normalize_identifier consults this flag; BigQuery has special rules for tables.
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier)

    return identifier
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping plain mappings in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping in any supported form into a `{column_name: type}` dict.

    Supported forms: None (yields an empty mapping), a dict (returned as-is), a
    comma-separated "name: type" string, a Spark DataFrame StructType, or a list
    of column names (types default to None).

    Args:
        mapping: the column mapping to normalize.

    Returns:
        A dict of column names to type strings (or None for list inputs).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # Split on the first ":" only, so types that themselves contain
            # colons (e.g. "struct<a:int>") survive intact.
            name, data_type = name_type_str.split(":", 1)
            column_mapping[name.strip()] = data_type.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of length `depth` in a nested mapping.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels deep each returned path should be.
        keys: the path prefix accumulated by recursive calls.

    Returns:
        A list of key paths; empty when `depth` is less than 1.
    """
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth == 1:
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value from a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where `key` is looked up in the dictionary
            and `name` is used in the error message if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path is absent.

    Returns:
        The value, or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal name of the table part; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value in a nested dictionary, creating intermediate dicts as needed.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    node = d
    for key in keys[:-1]:
        # setdefault both descends into an existing entry and creates a missing one.
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.