Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
 18
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.List[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The list of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether or not `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether or not the schema is empty."""
119        return True
120
121
class AbstractMappingSchema:
    """Base class for schemas backed by a nested ``mapping`` dict plus a lookup trie."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie keys are reversed table paths so lookups go from table name outwards.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the schema contains no tables."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily and cached on first access.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty path parts of `table`, in `exp.TABLE_PARTS` order."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        parts = []
        for part in exp.TABLE_PARTS:
            text = table.text(part)
            if text:
                parts.append(text)
        return parts

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Look up the entry for `table` in the trie/mapping, or None if unresolvable."""
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        result, subtrie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # The given path is only a prefix; it is usable only if unambiguous.
            possibilities = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(candidate) for candidate in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value at the reversed `parts` path in `d` (defaults to the mapping)."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
189
190
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema backed by a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Caches string type -> parsed exp.DataType to avoid re-parsing the same type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0
        mapping = {} if schema is None else schema

        super().__init__(self._normalize(mapping) if self.normalize else mapping)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` that mirrors an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, optionally overriding constructor arguments."""
        options: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        options.update(kwargs)
        return MappingSchema(**options)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        table_expr = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(table_expr.parts) != self.depth():
            raise SchemaError(
                f"Table {table_expr.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        columns = {
            self._normalize_name(name, dialect=dialect, normalize=normalize): column_type
            for name, column_type in ensure_column_mapping(column_mapping).items()
        }

        # Without new column information, keep an existing entry untouched.
        existing = self.find(table_expr, raise_on_missing=False)
        if existing and not columns:
            return

        parts = self.table_parts(table_expr)

        nested_set(self.mapping, tuple(reversed(parts)), columns)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names for `table`, optionally restricted to visible columns."""
        table_expr = self._normalize_table(table, dialect=dialect, normalize=normalize)

        table_schema = self.find(table_expr)
        if table_schema is None:
            return []

        if not only_visible or not self.visible:
            return list(table_schema)

        visible = self.nested_get(self.table_parts(table_expr), self.visible) or []
        return [col for col in table_schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the `exp.DataType` of `column` in `table`, or the "unknown" type."""
        table_expr = self._normalize_table(table, dialect=dialect, normalize=normalize)

        column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(table_expr, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` is present in `table`'s schema entry."""
        table_expr = self._normalize_table(table, dialect=dialect, normalize=normalize)

        column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(table_expr, raise_on_missing=False)
        return column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalize all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        table_expr = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = table_expr.args.get(arg)
                if isinstance(value, exp.Identifier):
                    table_expr.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return table_expr

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its textual name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
442
443
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect` rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # Consumed by normalize_identifier; BigQuery has special rules pertaining to tables.
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
459
460
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping dicts/None in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
466
467
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce the various supported column-mapping representations into a plain dict."""
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # Format: "name1: type1, name2: type2, ..."
        columns = {}
        for entry in mapping.split(","):
            pieces = entry.strip().split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}
    if isinstance(mapping, list):
        return {column.strip(): None for column in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
486
487
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Collect the key paths of length `depth` from a nested schema mapping."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth == 1:
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
501
502
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value from a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression arg name for the table; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
527
528
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary, creating intermediate dicts as needed.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, last = keys
    node = d
    for key in parents:
        node = node.setdefault(key, {})

    node[last] = value
    return d
class Schema(abc.ABC):
 20class Schema(abc.ABC):
 21    """Abstract base class for database schemas"""
 22
 23    dialect: DialectType
 24
 25    @abc.abstractmethod
 26    def add_table(
 27        self,
 28        table: exp.Table | str,
 29        column_mapping: t.Optional[ColumnMapping] = None,
 30        dialect: DialectType = None,
 31        normalize: t.Optional[bool] = None,
 32        match_depth: bool = True,
 33    ) -> None:
 34        """
 35        Register or update a table. Some implementing classes may require column information to also be provided.
 36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42            normalize: whether to normalize identifiers according to the dialect of interest.
 43            match_depth: whether to enforce that the table must match the schema's depth or not.
 44        """
 45
 46    @abc.abstractmethod
 47    def column_names(
 48        self,
 49        table: exp.Table | str,
 50        only_visible: bool = False,
 51        dialect: DialectType = None,
 52        normalize: t.Optional[bool] = None,
 53    ) -> t.List[str]:
 54        """
 55        Get the column names for a table.
 56
 57        Args:
 58            table: the `Table` expression instance.
 59            only_visible: whether to include invisible columns.
 60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 61            normalize: whether to normalize identifiers according to the dialect of interest.
 62
 63        Returns:
 64            The list of column names.
 65        """
 66
 67    @abc.abstractmethod
 68    def get_column_type(
 69        self,
 70        table: exp.Table | str,
 71        column: exp.Column | str,
 72        dialect: DialectType = None,
 73        normalize: t.Optional[bool] = None,
 74    ) -> exp.DataType:
 75        """
 76        Get the `sqlglot.exp.DataType` type of a column in the schema.
 77
 78        Args:
 79            table: the source table.
 80            column: the target column.
 81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 82            normalize: whether to normalize identifiers according to the dialect of interest.
 83
 84        Returns:
 85            The resulting column type.
 86        """
 87
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether or not `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)
109
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
 114        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
115        """
116
117    @property
118    def empty(self) -> bool:
119        """Returns whether or not the schema is empty."""
120        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
25    @abc.abstractmethod
26    def add_table(
27        self,
28        table: exp.Table | str,
29        column_mapping: t.Optional[ColumnMapping] = None,
30        dialect: DialectType = None,
31        normalize: t.Optional[bool] = None,
32        match_depth: bool = True,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
37
38        Args:
39            table: the `Table` expression instance or string representing the table.
40            column_mapping: a column mapping that describes the structure of the table.
41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
42            normalize: whether to normalize identifiers according to the dialect of interest.
43            match_depth: whether to enforce that the table must match the schema's depth or not.
44        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
46    @abc.abstractmethod
47    def column_names(
48        self,
49        table: exp.Table | str,
50        only_visible: bool = False,
51        dialect: DialectType = None,
52        normalize: t.Optional[bool] = None,
53    ) -> t.List[str]:
54        """
55        Get the column names for a table.
56
57        Args:
58            table: the `Table` expression instance.
59            only_visible: whether to include invisible columns.
60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
61            normalize: whether to normalize identifiers according to the dialect of interest.
62
63        Returns:
64            The list of column names.
65        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
67    @abc.abstractmethod
68    def get_column_type(
69        self,
70        table: exp.Table | str,
71        column: exp.Column | str,
72        dialect: DialectType = None,
73        normalize: t.Optional[bool] = None,
74    ) -> exp.DataType:
75        """
76        Get the `sqlglot.exp.DataType` type of a column in the schema.
77
78        Args:
79            table: the source table.
80            column: the target column.
81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
82            normalize: whether to normalize identifiers according to the dialect of interest.
83
84        Returns:
85            The resulting column type.
86        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether or not `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
114        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
115        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
117    @property
118    def empty(self) -> bool:
119        """Returns whether or not the schema is empty."""
120        return True

Returns whether or not the schema is empty.

class AbstractMappingSchema:
123class AbstractMappingSchema:
124    def __init__(
125        self,
126        mapping: t.Optional[t.Dict] = None,
127    ) -> None:
128        self.mapping = mapping or {}
129        self.mapping_trie = new_trie(
130            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
131        )
132        self._supported_table_args: t.Tuple[str, ...] = tuple()
133
134    @property
135    def empty(self) -> bool:
136        return not self.mapping
137
138    def depth(self) -> int:
139        return dict_depth(self.mapping)
140
141    @property
142    def supported_table_args(self) -> t.Tuple[str, ...]:
143        if not self._supported_table_args and self.mapping:
144            depth = self.depth()
145
146            if not depth:  # None
147                self._supported_table_args = tuple()
148            elif 1 <= depth <= 3:
149                self._supported_table_args = exp.TABLE_PARTS[:depth]
150            else:
151                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
152
153        return self._supported_table_args
154
155    def table_parts(self, table: exp.Table) -> t.List[str]:
156        if isinstance(table.this, exp.ReadCSV):
157            return [table.this.name]
158        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
159
160    def find(
161        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
162    ) -> t.Optional[t.Any]:
163        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
164        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
165
166        if value == TrieResult.FAILED:
167            return None
168
169        if value == TrieResult.PREFIX:
170            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
171
172            if len(possibilities) == 1:
173                parts.extend(possibilities[0])
174            else:
175                message = ", ".join(".".join(parts) for parts in possibilities)
176                if raise_on_missing:
177                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
178                return None
179
180        return self.nested_get(parts, raise_on_missing=raise_on_missing)
181
182    def nested_get(
183        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
184    ) -> t.Optional[t.Any]:
185        return nested_get(
186            d or self.mapping,
187            *zip(self.supported_table_args, reversed(parts)),
188            raise_on_missing=raise_on_missing,
189        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
124    def __init__(
125        self,
126        mapping: t.Optional[t.Dict] = None,
127    ) -> None:
128        self.mapping = mapping or {}
129        self.mapping_trie = new_trie(
130            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
131        )
132        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
134    @property
135    def empty(self) -> bool:
136        return not self.mapping
def depth(self) -> int:
138    def depth(self) -> int:
139        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
141    @property
142    def supported_table_args(self) -> t.Tuple[str, ...]:
143        if not self._supported_table_args and self.mapping:
144            depth = self.depth()
145
146            if not depth:  # None
147                self._supported_table_args = tuple()
148            elif 1 <= depth <= 3:
149                self._supported_table_args = exp.TABLE_PARTS[:depth]
150            else:
151                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
152
153        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
155    def table_parts(self, table: exp.Table) -> t.List[str]:
156        if isinstance(table.this, exp.ReadCSV):
157            return [table.this.name]
158        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[Any]:
160    def find(
161        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
162    ) -> t.Optional[t.Any]:
163        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
164        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
165
166        if value == TrieResult.FAILED:
167            return None
168
169        if value == TrieResult.PREFIX:
170            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
171
172            if len(possibilities) == 1:
173                parts.extend(possibilities[0])
174            else:
175                message = ", ".join(".".join(parts) for parts in possibilities)
176                if raise_on_missing:
177                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
178                return None
179
180        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
182    def nested_get(
183        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
184    ) -> t.Optional[t.Any]:
185        return nested_get(
186            d or self.mapping,
187            *zip(self.supported_table_args, reversed(parts)),
188            raise_on_missing=raise_on_missing,
189        )
class MappingSchema(AbstractMappingSchema, Schema):
192class MappingSchema(AbstractMappingSchema, Schema):
193    """
194    Schema based on a nested mapping.
195
196    Args:
197        schema: Mapping in one of the following forms:
198            1. {table: {col: type}}
199            2. {db: {table: {col: type}}}
200            3. {catalog: {db: {table: {col: type}}}}
201            4. None - Tables will be added later
202        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
203            are assumed to be visible. The nesting should mirror that of the schema:
204            1. {table: set(*cols)}}
205            2. {db: {table: set(*cols)}}}
206            3. {catalog: {db: {table: set(*cols)}}}}
207        dialect: The dialect to be used for custom type mappings & parsing string arguments.
208        normalize: Whether to normalize identifier names according to the given dialect or not.
209    """
210
211    def __init__(
212        self,
213        schema: t.Optional[t.Dict] = None,
214        visible: t.Optional[t.Dict] = None,
215        dialect: DialectType = None,
216        normalize: bool = True,
217    ) -> None:
218        self.dialect = dialect
219        self.visible = {} if visible is None else visible
220        self.normalize = normalize
221        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
222        self._depth = 0
223        schema = {} if schema is None else schema
224
225        super().__init__(self._normalize(schema) if self.normalize else schema)
226
227    @classmethod
228    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
229        return MappingSchema(
230            schema=mapping_schema.mapping,
231            visible=mapping_schema.visible,
232            dialect=mapping_schema.dialect,
233            normalize=mapping_schema.normalize,
234        )
235
236    def copy(self, **kwargs) -> MappingSchema:
237        return MappingSchema(
238            **{  # type: ignore
239                "schema": self.mapping.copy(),
240                "visible": self.visible.copy(),
241                "dialect": self.dialect,
242                "normalize": self.normalize,
243                **kwargs,
244            }
245        )
246
247    def add_table(
248        self,
249        table: exp.Table | str,
250        column_mapping: t.Optional[ColumnMapping] = None,
251        dialect: DialectType = None,
252        normalize: t.Optional[bool] = None,
253        match_depth: bool = True,
254    ) -> None:
255        """
256        Register or update a table. Updates are only performed if a new column mapping is provided.
257        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
258
259        Args:
260            table: the `Table` expression instance or string representing the table.
261            column_mapping: a column mapping that describes the structure of the table.
262            dialect: the SQL dialect that will be used to parse `table` if it's a string.
263            normalize: whether to normalize identifiers according to the dialect of interest.
264            match_depth: whether to enforce that the table must match the schema's depth or not.
265        """
266        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
267
268        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
269            raise SchemaError(
270                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
271                f"schema's nesting level: {self.depth()}."
272            )
273
274        normalized_column_mapping = {
275            self._normalize_name(key, dialect=dialect, normalize=normalize): value
276            for key, value in ensure_column_mapping(column_mapping).items()
277        }
278
279        schema = self.find(normalized_table, raise_on_missing=False)
280        if schema and not normalized_column_mapping:
281            return
282
283        parts = self.table_parts(normalized_table)
284
285        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
286        new_trie([parts], self.mapping_trie)
287
288    def column_names(
289        self,
290        table: exp.Table | str,
291        only_visible: bool = False,
292        dialect: DialectType = None,
293        normalize: t.Optional[bool] = None,
294    ) -> t.List[str]:
295        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
296
297        schema = self.find(normalized_table)
298        if schema is None:
299            return []
300
301        if not only_visible or not self.visible:
302            return list(schema)
303
304        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
305        return [col for col in schema if col in visible]
306
307    def get_column_type(
308        self,
309        table: exp.Table | str,
310        column: exp.Column | str,
311        dialect: DialectType = None,
312        normalize: t.Optional[bool] = None,
313    ) -> exp.DataType:
314        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
315
316        normalized_column_name = self._normalize_name(
317            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
318        )
319
320        table_schema = self.find(normalized_table, raise_on_missing=False)
321        if table_schema:
322            column_type = table_schema.get(normalized_column_name)
323
324            if isinstance(column_type, exp.DataType):
325                return column_type
326            elif isinstance(column_type, str):
327                return self._to_data_type(column_type, dialect=dialect)
328
329        return exp.DataType.build("unknown")
330
331    def has_column(
332        self,
333        table: exp.Table | str,
334        column: exp.Column | str,
335        dialect: DialectType = None,
336        normalize: t.Optional[bool] = None,
337    ) -> bool:
338        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
339
340        normalized_column_name = self._normalize_name(
341            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
342        )
343
344        table_schema = self.find(normalized_table, raise_on_missing=False)
345        return normalized_column_name in table_schema if table_schema else False
346
347    def _normalize(self, schema: t.Dict) -> t.Dict:
348        """
349        Normalizes all identifiers in the schema.
350
351        Args:
352            schema: the schema to normalize.
353
354        Returns:
355            The normalized schema mapping.
356        """
357        normalized_mapping: t.Dict = {}
358        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
359
360        for keys in flattened_schema:
361            columns = nested_get(schema, *zip(keys, keys))
362
363            if not isinstance(columns, dict):
364                raise SchemaError(
365                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
366                )
367
368            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
369            for column_name, column_type in columns.items():
370                nested_set(
371                    normalized_mapping,
372                    normalized_keys + [self._normalize_name(column_name)],
373                    column_type,
374                )
375
376        return normalized_mapping
377
378    def _normalize_table(
379        self,
380        table: exp.Table | str,
381        dialect: DialectType = None,
382        normalize: t.Optional[bool] = None,
383    ) -> exp.Table:
384        dialect = dialect or self.dialect
385        normalize = self.normalize if normalize is None else normalize
386
387        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
388
389        if normalize:
390            for arg in exp.TABLE_PARTS:
391                value = normalized_table.args.get(arg)
392                if isinstance(value, exp.Identifier):
393                    normalized_table.set(
394                        arg,
395                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
396                    )
397
398        return normalized_table
399
400    def _normalize_name(
401        self,
402        name: str | exp.Identifier,
403        dialect: DialectType = None,
404        is_table: bool = False,
405        normalize: t.Optional[bool] = None,
406    ) -> str:
407        return normalize_name(
408            name,
409            dialect=dialect or self.dialect,
410            is_table=is_table,
411            normalize=self.normalize if normalize is None else normalize,
412        ).name
413
414    def depth(self) -> int:
415        if not self.empty and not self._depth:
416            # The columns themselves are a mapping, but we don't want to include those
417            self._depth = super().depth() - 1
418        return self._depth
419
420    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
421        """
422        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
423
424        Args:
425            schema_type: the type we want to convert.
426            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
427
428        Returns:
429            The resulting expression type.
430        """
431        if schema_type not in self._type_mapping_cache:
432            dialect = dialect or self.dialect
433            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
434
435            try:
436                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
437                self._type_mapping_cache[schema_type] = expression
438            except AttributeError:
439                in_dialect = f" in dialect {dialect}" if dialect else ""
440                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
441
442        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
211    def __init__(
212        self,
213        schema: t.Optional[t.Dict] = None,
214        visible: t.Optional[t.Dict] = None,
215        dialect: DialectType = None,
216        normalize: bool = True,
217    ) -> None:
218        self.dialect = dialect
219        self.visible = {} if visible is None else visible
220        self.normalize = normalize
221        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
222        self._depth = 0
223        schema = {} if schema is None else schema
224
225        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
227    @classmethod
228    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
229        return MappingSchema(
230            schema=mapping_schema.mapping,
231            visible=mapping_schema.visible,
232            dialect=mapping_schema.dialect,
233            normalize=mapping_schema.normalize,
234        )
def copy(self, **kwargs) -> MappingSchema:
236    def copy(self, **kwargs) -> MappingSchema:
237        return MappingSchema(
238            **{  # type: ignore
239                "schema": self.mapping.copy(),
240                "visible": self.visible.copy(),
241                "dialect": self.dialect,
242                "normalize": self.normalize,
243                **kwargs,
244            }
245        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
247    def add_table(
248        self,
249        table: exp.Table | str,
250        column_mapping: t.Optional[ColumnMapping] = None,
251        dialect: DialectType = None,
252        normalize: t.Optional[bool] = None,
253        match_depth: bool = True,
254    ) -> None:
255        """
256        Register or update a table. Updates are only performed if a new column mapping is provided.
257        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
258
259        Args:
260            table: the `Table` expression instance or string representing the table.
261            column_mapping: a column mapping that describes the structure of the table.
262            dialect: the SQL dialect that will be used to parse `table` if it's a string.
263            normalize: whether to normalize identifiers according to the dialect of interest.
264            match_depth: whether to enforce that the table must match the schema's depth or not.
265        """
266        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
267
268        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
269            raise SchemaError(
270                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
271                f"schema's nesting level: {self.depth()}."
272            )
273
274        normalized_column_mapping = {
275            self._normalize_name(key, dialect=dialect, normalize=normalize): value
276            for key, value in ensure_column_mapping(column_mapping).items()
277        }
278
279        schema = self.find(normalized_table, raise_on_missing=False)
280        if schema and not normalized_column_mapping:
281            return
282
283        parts = self.table_parts(normalized_table)
284
285        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
286        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
288    def column_names(
289        self,
290        table: exp.Table | str,
291        only_visible: bool = False,
292        dialect: DialectType = None,
293        normalize: t.Optional[bool] = None,
294    ) -> t.List[str]:
295        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
296
297        schema = self.find(normalized_table)
298        if schema is None:
299            return []
300
301        if not only_visible or not self.visible:
302            return list(schema)
303
304        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
305        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
307    def get_column_type(
308        self,
309        table: exp.Table | str,
310        column: exp.Column | str,
311        dialect: DialectType = None,
312        normalize: t.Optional[bool] = None,
313    ) -> exp.DataType:
314        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
315
316        normalized_column_name = self._normalize_name(
317            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
318        )
319
320        table_schema = self.find(normalized_table, raise_on_missing=False)
321        if table_schema:
322            column_type = table_schema.get(normalized_column_name)
323
324            if isinstance(column_type, exp.DataType):
325                return column_type
326            elif isinstance(column_type, str):
327                return self._to_data_type(column_type, dialect=dialect)
328
329        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
331    def has_column(
332        self,
333        table: exp.Table | str,
334        column: exp.Column | str,
335        dialect: DialectType = None,
336        normalize: t.Optional[bool] = None,
337    ) -> bool:
338        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
339
340        normalized_column_name = self._normalize_name(
341            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
342        )
343
344        table_schema = self.find(normalized_table, raise_on_missing=False)
345        return normalized_column_name in table_schema if table_schema else False

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
414    def depth(self) -> int:
415        if not self.empty and not self._depth:
416            # The columns themselves are a mapping, but we don't want to include those
417            self._depth = super().depth() - 1
418        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
445def normalize_name(
446    identifier: str | exp.Identifier,
447    dialect: DialectType = None,
448    is_table: bool = False,
449    normalize: t.Optional[bool] = True,
450) -> exp.Identifier:
451    if isinstance(identifier, str):
452        identifier = exp.parse_identifier(identifier, dialect=dialect)
453
454    if not normalize:
455        return identifier
456
457    # this is used for normalize_identifier, bigquery has special rules pertaining tables
458    identifier.meta["is_table"] = is_table
459    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
462def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
463    if isinstance(schema, Schema):
464        return schema
465
466    return MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
469def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
470    if mapping is None:
471        return {}
472    elif isinstance(mapping, dict):
473        return mapping
474    elif isinstance(mapping, str):
475        col_name_type_strs = [x.strip() for x in mapping.split(",")]
476        return {
477            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
478            for name_type_str in col_name_type_strs
479        }
480    # Check if mapping looks like a DataFrame StructType
481    elif hasattr(mapping, "simpleString"):
482        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
483    elif isinstance(mapping, list):
484        return {x.strip(): None for x in mapping}
485
486    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
489def flatten_schema(
490    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
491) -> t.List[t.List[str]]:
492    tables = []
493    keys = keys or []
494
495    for k, v in schema.items():
496        if depth >= 2:
497            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
498        elif depth == 1:
499            tables.append(keys + [k])
500
501    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
504def nested_get(
505    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
506) -> t.Optional[t.Any]:
507    """
508    Get a value for a nested dictionary.
509
510    Args:
511        d: the dictionary to search.
512        *path: tuples of (name, key), where:
513            `key` is the key in the dictionary to get.
514            `name` is a string to use in the error if `key` isn't found.
515
516    Returns:
517        The value or None if it doesn't exist.
518    """
519    for name, key in path:
520        d = d.get(key)  # type: ignore
521        if d is None:
522            if raise_on_missing:
523                name = "table" if name == "this" else name
524                raise ValueError(f"Unknown {name}: {key}")
525            return None
526
527    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
530def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
531    """
532    In-place set a value for a nested dictionary
533
534    Example:
535        >>> nested_set({}, ["top_key", "second_key"], "value")
536        {'top_key': {'second_key': 'value'}}
537
538        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
539        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
540
541    Args:
542        d: dictionary to update.
543        keys: the keys that makeup the path to `value`.
544        value: the value to set in the dictionary for the given key path.
545
546    Returns:
547        The (possibly) updated dictionary.
548    """
549    if not keys:
550        return d
551
552    if len(keys) == 1:
553        d[keys[0]] = value
554        return d
555
556    subd = d
557    for key in keys[:-1]:
558        if key not in subd:
559            subd = subd.setdefault(key, {})
560        else:
561            subd = subd[key]
562
563    subd[keys[-1]] = value
564    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.