Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
 18
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.List[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The list of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether or not `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether or not the schema is empty."""
119        return True
120
121
class AbstractMappingSchema:
    """Trie-backed lookup machinery shared by mapping-based schemas."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed by reversed table paths (table, db, catalog), which
        # lets partially-qualified tables be resolved by suffix.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping holds no tables."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table path parts this schema can resolve, computed once and cached."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if depth:
                if not 1 <= depth <= 3:
                    raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
                self._supported_table_args = exp.TABLE_PARTS[:depth]

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty path parts of `table`, innermost (table name) first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        parts = []
        for part in exp.TABLE_PARTS:
            text = table.text(part)
            if text:
                parts.append(text)
        return parts

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Look up the column mapping for `table`, resolving a partial path when unambiguous."""
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The path only partially qualifies a table; it is usable only when
            # exactly one table matches the prefix.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) != 1:
                if raise_on_missing:
                    message = ", ".join(".".join(parts) for parts in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

            parts.extend(possibilities[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value stored under the reversed `parts` path in `d` or the schema mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
189
190
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema backed by a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided,
            all columns are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Caches string -> exp.DataType conversions done by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not computed yet or empty schema".
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema that shares the state of an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, optionally overriding constructor arguments."""
        init_args: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is
        provided. The added table must have the necessary number of qualifiers in its path
        to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(norm_table.parts) != self.depth():
            raise SchemaError(
                f"Table {norm_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        norm_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is only replaced when new column info is supplied.
        if self.find(norm_table, raise_on_missing=False) and not norm_column_mapping:
            return

        parts = self.table_parts(norm_table)

        nested_set(self.mapping, tuple(reversed(parts)), norm_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names; empty if the table is unknown.
        """
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        table_schema = self.find(norm_table)
        if table_schema is None:
            return []

        if not only_visible or not self.visible:
            return list(table_schema)

        visible = self.nested_get(self.table_parts(norm_table), self.visible) or []
        return [column for column in table_schema if column in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or the "unknown" type when it cannot be resolved.
        """
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        norm_column = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(norm_table, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(norm_column)

        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Check whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        norm_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        norm_column = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(norm_table, raise_on_missing=False)
        return bool(table_schema) and norm_column in table_schema

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalize all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized: t.Dict = {}
        flattened = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened[0])}."
                )

            norm_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized,
                    norm_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its path identifiers."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Only copy the parsed table when we're about to mutate it below.
        norm_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = norm_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    norm_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return norm_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its text."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # Subtract one level: the innermost mapping holds columns, not tables.
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a DataType.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                self._type_mapping_cache[schema_type] = exp.DataType.build(
                    schema_type, dialect=dialect, udt=udt
                )
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
441
442
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it is a string and, when requested, normalize it for `dialect`."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if normalize:
        # normalize_identifier consults this flag; BigQuery applies special
        # rules to table identifiers.
        identifier.meta["is_table"] = is_table
        identifier = Dialect.get_or_raise(dialect).normalize_identifier(identifier)

    return identifier
458
459
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a Schema instance, wrapping plain dicts (or None) in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
465
466
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping representations into a plain dict.

    Args:
        mapping: one of None, a dict, a "name: type, ..." string, a DataFrame
            StructType-like object, or a list of column names.

    Returns:
        A dict of column name to type string (type is None for list input).

    Raises:
        ValueError: if `mapping` is of an unsupported type or a string entry
            has no `name: type` separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in mapping.split(","):
            # Split on the first colon only, so types that themselves contain
            # colons (e.g. "a: struct<x:int>") are not truncated.
            name, sep, column_type = entry.partition(":")
            if not sep:
                raise ValueError(f"Invalid column mapping entry: {entry.strip()}")
            columns[name.strip()] = column_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
485
486
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema mapping into the list of key paths of length `depth`.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels of keys form a path; 0 or less yields no paths.
        keys: path prefix accumulated by recursive calls.

    Returns:
        A list of key paths, each a list of strings.
    """
    prefix = keys or []

    if depth == 1:
        return [prefix + [key] for key in schema]

    paths: t.List[t.List[str]] = []
    if depth >= 2:
        for key, nested in schema.items():
            paths.extend(flatten_schema(nested, depth - 1, prefix + [key]))

    return paths
500
501
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises ValueError or returns None.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the expression-internal name for the table part.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
526
527
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    target = d
    for key in keys[:-1]:
        # setdefault both creates a missing level and descends into an existing one
        target = target.setdefault(key, {})

    target[keys[-1]] = value
    return d
class Schema(abc.ABC):
 20class Schema(abc.ABC):
 21    """Abstract base class for database schemas"""
 22
 23    dialect: DialectType
 24
 25    @abc.abstractmethod
 26    def add_table(
 27        self,
 28        table: exp.Table | str,
 29        column_mapping: t.Optional[ColumnMapping] = None,
 30        dialect: DialectType = None,
 31        normalize: t.Optional[bool] = None,
 32        match_depth: bool = True,
 33    ) -> None:
 34        """
 35        Register or update a table. Some implementing classes may require column information to also be provided.
 36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42            normalize: whether to normalize identifiers according to the dialect of interest.
 43            match_depth: whether to enforce that the table must match the schema's depth or not.
 44        """
 45
 46    @abc.abstractmethod
 47    def column_names(
 48        self,
 49        table: exp.Table | str,
 50        only_visible: bool = False,
 51        dialect: DialectType = None,
 52        normalize: t.Optional[bool] = None,
 53    ) -> t.List[str]:
 54        """
 55        Get the column names for a table.
 56
 57        Args:
 58            table: the `Table` expression instance.
 59            only_visible: whether to include invisible columns.
 60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 61            normalize: whether to normalize identifiers according to the dialect of interest.
 62
 63        Returns:
 64            The list of column names.
 65        """
 66
 67    @abc.abstractmethod
 68    def get_column_type(
 69        self,
 70        table: exp.Table | str,
 71        column: exp.Column | str,
 72        dialect: DialectType = None,
 73        normalize: t.Optional[bool] = None,
 74    ) -> exp.DataType:
 75        """
 76        Get the `sqlglot.exp.DataType` type of a column in the schema.
 77
 78        Args:
 79            table: the source table.
 80            column: the target column.
 81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 82            normalize: whether to normalize identifiers according to the dialect of interest.
 83
 84        Returns:
 85            The resulting column type.
 86        """
 87
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether or not `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)
109
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
114        Table arguments this schema support, e.g. `("this", "db", "catalog")`
115        """
116
117    @property
118    def empty(self) -> bool:
119        """Returns whether or not the schema is empty."""
120        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
25    @abc.abstractmethod
26    def add_table(
27        self,
28        table: exp.Table | str,
29        column_mapping: t.Optional[ColumnMapping] = None,
30        dialect: DialectType = None,
31        normalize: t.Optional[bool] = None,
32        match_depth: bool = True,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
37
38        Args:
39            table: the `Table` expression instance or string representing the table.
40            column_mapping: a column mapping that describes the structure of the table.
41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
42            normalize: whether to normalize identifiers according to the dialect of interest.
43            match_depth: whether to enforce that the table must match the schema's depth or not.
44        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
46    @abc.abstractmethod
47    def column_names(
48        self,
49        table: exp.Table | str,
50        only_visible: bool = False,
51        dialect: DialectType = None,
52        normalize: t.Optional[bool] = None,
53    ) -> t.List[str]:
54        """
55        Get the column names for a table.
56
57        Args:
58            table: the `Table` expression instance.
59            only_visible: whether to include invisible columns.
60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
61            normalize: whether to normalize identifiers according to the dialect of interest.
62
63        Returns:
64            The list of column names.
65        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
67    @abc.abstractmethod
68    def get_column_type(
69        self,
70        table: exp.Table | str,
71        column: exp.Column | str,
72        dialect: DialectType = None,
73        normalize: t.Optional[bool] = None,
74    ) -> exp.DataType:
75        """
76        Get the `sqlglot.exp.DataType` type of a column in the schema.
77
78        Args:
79            table: the source table.
80            column: the target column.
81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
82            normalize: whether to normalize identifiers according to the dialect of interest.
83
84        Returns:
85            The resulting column type.
86        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether or not `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """Base class that resolves tables against a nested ``mapping`` by way of a trie."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are reversed so lookups proceed most-specific-first (table, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return not self.mapping

    def depth(self) -> int:
        # Nesting depth of the mapping: 1 = {table: ...}, 2 = {db: {table: ...}}, etc.
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table arguments this schema supports, e.g. ("this", "db", "catalog")."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty name parts of `table`, ordered per `exp.TABLE_PARTS`."""
        if isinstance(table.this, exp.ReadCSV):
            # READ_CSV(...) "tables" are identified by the file name alone.
            return [table.this.name]
        return [table.text(arg) for arg in exp.TABLE_PARTS if table.text(arg)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Look up the value stored for `table`, or None if it cannot be resolved."""
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        result, subtrie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # The table is under-qualified; only proceed if it resolves unambiguously.
            possibilities = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(p) for p in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value at `parts` (most-specific-first) from `d` or `self.mapping`."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
124    def __init__(
125        self,
126        mapping: t.Optional[t.Dict] = None,
127    ) -> None:
128        self.mapping = mapping or {}
129        self.mapping_trie = new_trie(
130            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
131        )
132        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
def depth(self) -> int:
138    def depth(self) -> int:
139        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
155    def table_parts(self, table: exp.Table) -> t.List[str]:
156        if isinstance(table.this, exp.ReadCSV):
157            return [table.this.name]
158        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[Any]:
160    def find(
161        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
162    ) -> t.Optional[t.Any]:
163        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
164        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
165
166        if value == TrieResult.FAILED:
167            return None
168
169        if value == TrieResult.PREFIX:
170            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
171
172            if len(possibilities) == 1:
173                parts.extend(possibilities[0])
174            else:
175                message = ", ".join(".".join(parts) for parts in possibilities)
176                if raise_on_missing:
177                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
178                return None
179
180        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
182    def nested_get(
183        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
184    ) -> t.Optional[t.Any]:
185        return nested_get(
186            d or self.mapping,
187            *zip(self.supported_table_args, reversed(parts)),
188            raise_on_missing=raise_on_missing,
189        )
class MappingSchema(AbstractMappingSchema, Schema):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, filled lazily by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached schema depth, computed lazily in depth().
        self._depth = 0

        # Normalize all identifiers up front so later lookups are consistent.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Create a new MappingSchema from an existing one's state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Shallow-copy this schema; `kwargs` override individual constructor arguments."""
        init_args = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Nothing to do if the table is already known and no new columns were given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally restricted to the visible ones."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible_cols = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible_cols]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Resolve the DataType of `column` in `table`, falling back to the UNKNOWN type."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Types stored as strings are parsed (and cached) on demand.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` appears in `table`'s schema, False otherwise."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            # Normalize the table path and every column name along it.
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its plain-string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}}
    2. {db: {table: set(*cols)}}}
    3. {catalog: {db: {table: set(*cols)}}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
211    def __init__(
212        self,
213        schema: t.Optional[t.Dict] = None,
214        visible: t.Optional[t.Dict] = None,
215        dialect: DialectType = None,
216        normalize: bool = True,
217    ) -> None:
218        self.dialect = dialect
219        self.visible = visible or {}
220        self.normalize = normalize
221        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
222        self._depth = 0
223
224        super().__init__(self._normalize(schema or {}))
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
226    @classmethod
227    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
228        return MappingSchema(
229            schema=mapping_schema.mapping,
230            visible=mapping_schema.visible,
231            dialect=mapping_schema.dialect,
232            normalize=mapping_schema.normalize,
233        )
def copy(self, **kwargs) -> MappingSchema:
235    def copy(self, **kwargs) -> MappingSchema:
236        return MappingSchema(
237            **{  # type: ignore
238                "schema": self.mapping.copy(),
239                "visible": self.visible.copy(),
240                "dialect": self.dialect,
241                "normalize": self.normalize,
242                **kwargs,
243            }
244        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
246    def add_table(
247        self,
248        table: exp.Table | str,
249        column_mapping: t.Optional[ColumnMapping] = None,
250        dialect: DialectType = None,
251        normalize: t.Optional[bool] = None,
252        match_depth: bool = True,
253    ) -> None:
254        """
255        Register or update a table. Updates are only performed if a new column mapping is provided.
256        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
257
258        Args:
259            table: the `Table` expression instance or string representing the table.
260            column_mapping: a column mapping that describes the structure of the table.
261            dialect: the SQL dialect that will be used to parse `table` if it's a string.
262            normalize: whether to normalize identifiers according to the dialect of interest.
263            match_depth: whether to enforce that the table must match the schema's depth or not.
264        """
265        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
266
267        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
268            raise SchemaError(
269                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
270                f"schema's nesting level: {self.depth()}."
271            )
272
273        normalized_column_mapping = {
274            self._normalize_name(key, dialect=dialect, normalize=normalize): value
275            for key, value in ensure_column_mapping(column_mapping).items()
276        }
277
278        schema = self.find(normalized_table, raise_on_missing=False)
279        if schema and not normalized_column_mapping:
280            return
281
282        parts = self.table_parts(normalized_table)
283
284        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
285        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
287    def column_names(
288        self,
289        table: exp.Table | str,
290        only_visible: bool = False,
291        dialect: DialectType = None,
292        normalize: t.Optional[bool] = None,
293    ) -> t.List[str]:
294        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
295
296        schema = self.find(normalized_table)
297        if schema is None:
298            return []
299
300        if not only_visible or not self.visible:
301            return list(schema)
302
303        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
304        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
306    def get_column_type(
307        self,
308        table: exp.Table | str,
309        column: exp.Column | str,
310        dialect: DialectType = None,
311        normalize: t.Optional[bool] = None,
312    ) -> exp.DataType:
313        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
314
315        normalized_column_name = self._normalize_name(
316            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
317        )
318
319        table_schema = self.find(normalized_table, raise_on_missing=False)
320        if table_schema:
321            column_type = table_schema.get(normalized_column_name)
322
323            if isinstance(column_type, exp.DataType):
324                return column_type
325            elif isinstance(column_type, str):
326                return self._to_data_type(column_type, dialect=dialect)
327
328        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
330    def has_column(
331        self,
332        table: exp.Table | str,
333        column: exp.Column | str,
334        dialect: DialectType = None,
335        normalize: t.Optional[bool] = None,
336    ) -> bool:
337        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
338
339        normalized_column_name = self._normalize_name(
340            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
341        )
342
343        table_schema = self.find(normalized_table, raise_on_missing=False)
344        return normalized_column_name in table_schema if table_schema else False

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
413    def depth(self) -> int:
414        if not self.empty and not self._depth:
415            # The columns themselves are a mapping, but we don't want to include those
416            self._depth = super().depth() - 1
417        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect`.

    When `normalize` is falsy, the identifier is returned as parsed, untouched.
    """
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # normalize_identifier consults this flag; BigQuery has special rules for table names.
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a Schema instance, wrapping dicts (or None) in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping forms into a {column_name: column_type} dict.

    Accepts None (empty mapping), a dict (returned as-is), a "name: type, ..." string,
    a Spark-style StructType, or a list of column names (whose types default to None).

    Raises:
        ValueError: if `mapping` is of an unsupported type (or a string entry lacks a colon).
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so types that themselves contain colons
            # (e.g. "struct<a:int>") are preserved intact instead of being truncated.
            name, data_type = name_type_str.split(":", maxsplit=1)
            columns[name.strip()] = data_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Collect the key paths identifying each table in a nested `schema` dict.

    `depth` is the number of nesting levels left to descend; a path is emitted
    once the remaining depth reaches 1 (the table level).
    """
    keys = keys or []
    tables: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth == 1:
            tables.append(keys + [key])
        elif depth >= 2:
            tables.extend(flatten_schema(value, depth - 1, keys + [key]))

    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises ValueError or just yields None.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression-internal name for the table part.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating as needed) the intermediate dictionaries, then set the leaf.
    subd = d
    for key in keys[:-1]:
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.