Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
 18
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    # Default dialect used to parse string arguments (e.g. table names) when no
    # explicit dialect is passed to a method.
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        # Membership test against the (possibly normalized) column names of the table.
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        # Base implementation reports empty; concrete subclasses override this.
        return True
120
121
class AbstractMappingSchema:
    """Stores a nested table mapping and resolves tables against it via a trie."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed by reversed qualifier paths, i.e. (table, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping contains no tables."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table expression arguments supported at this nesting depth, computed lazily."""
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self.depth()

        if not depth:  # None
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = exp.TABLE_PARTS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty qualifier parts of `table`, most-specific first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        parts = []
        for part in exp.TABLE_PARTS:
            text = table.text(part)
            if text:
                parts.append(text)
        return parts

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # A partial match: acceptable only if it resolves to exactly one table.
            candidates = flatten_schema(trie)

            if len(candidates) != 1:
                if raise_on_missing:
                    message = ", ".join(".".join(candidate) for candidate in candidates)
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

            parts.extend(candidates[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Look `parts` up in `d` (or this schema's mapping), from catalog down to table."""
        source = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return nested_get(source, *path, raise_on_missing=raise_on_missing)
197
198
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Lazily-populated cache of raw type strings -> parsed DataType expressions.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached nesting depth; 0 means "not yet computed" (see depth()).
        self._depth = 0
        schema = {} if schema is None else schema

        # Normalize every identifier up front so subsequent lookups are consistent.
        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new MappingSchema from an existing one.

        Note: the underlying mapping/visible dicts are shared, not copied.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is True and the table's qualifier count
                differs from the (non-empty) schema's nesting level.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Existing tables are only updated when new column information is supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Keep both lookup structures in sync: the nested mapping (keyed
        # catalog -> db -> table) and the trie (keyed table -> db -> catalog).
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Get the column names for `table`, optionally restricted to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter against the visibility mapping, preserving schema order.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Get the `DataType` of `column` in `table`, or an "unknown" type if not found."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Type stored as a raw string: parse (and cache) it.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema (False if the table is unknown)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if the schema's nesting is inconsistent (a path that is too
                shallow or too deep relative to the first flattened entry).
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # `nested_get` takes (name, key) pairs; here the key doubles as its own label.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                # Path too shallow: the leaf is not a column mapping.
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                # Path too deep: the "columns" level still contains nested dicts.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we are going to mutate the expression below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its text, falling back to this schema's defaults."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the column->type level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if `schema_type` cannot be parsed into a type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
455
456
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it for `dialect` when requested."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if normalize:
        # This is used for normalize_identifier; BigQuery has special rules pertaining to tables.
        identifier.meta["is_table"] = is_table
        identifier = Dialect.get_or_raise(dialect).normalize_identifier(identifier)

    return identifier
472
473
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` as-is if it's already a Schema, otherwise wrap it in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
479
480
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce any supported column-mapping representation into a dictionary.

    Args:
        mapping: one of None, a dict, a "name: type, ..." string, a DataFrame-like
            StructType, or a list of column names.

    Returns:
        A dict of column names to type strings (values are None for a list input).

    Raises:
        ValueError: if `mapping` is of an unsupported type or a string entry
            lacks a ":" separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # Split each entry on the FIRST colon only, so types that themselves
        # contain colons (e.g. "map<int:int>") are preserved intact. The old
        # `split(":")[1]` truncated such types at the second colon.
        return {
            name.strip(): data_type.strip()
            for name, data_type in (entry.split(":", 1) for entry in mapping.split(","))
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
499
500
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Return every qualifier path in `schema` down to (but excluding) the column level."""
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        if remaining == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif remaining >= 2:
            paths.extend(flatten_schema(value, remaining - 1, prefix + [key]))

    return paths
515
516
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal name of the table part; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
541
542
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *ancestors, last = keys

    # Walk (and create, where missing) the intermediate dictionaries.
    node = d
    for key in ancestors:
        node = node.setdefault(key, {})

    node[last] = value
    return d
class Schema(abc.ABC):
 20class Schema(abc.ABC):
 21    """Abstract base class for database schemas"""
 22
 23    dialect: DialectType
 24
 25    @abc.abstractmethod
 26    def add_table(
 27        self,
 28        table: exp.Table | str,
 29        column_mapping: t.Optional[ColumnMapping] = None,
 30        dialect: DialectType = None,
 31        normalize: t.Optional[bool] = None,
 32        match_depth: bool = True,
 33    ) -> None:
 34        """
 35        Register or update a table. Some implementing classes may require column information to also be provided.
 36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42            normalize: whether to normalize identifiers according to the dialect of interest.
 43            match_depth: whether to enforce that the table must match the schema's depth or not.
 44        """
 45
 46    @abc.abstractmethod
 47    def column_names(
 48        self,
 49        table: exp.Table | str,
 50        only_visible: bool = False,
 51        dialect: DialectType = None,
 52        normalize: t.Optional[bool] = None,
 53    ) -> t.Sequence[str]:
 54        """
 55        Get the column names for a table.
 56
 57        Args:
 58            table: the `Table` expression instance.
 59            only_visible: whether to include invisible columns.
 60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 61            normalize: whether to normalize identifiers according to the dialect of interest.
 62
 63        Returns:
 64            The sequence of column names.
 65        """
 66
 67    @abc.abstractmethod
 68    def get_column_type(
 69        self,
 70        table: exp.Table | str,
 71        column: exp.Column | str,
 72        dialect: DialectType = None,
 73        normalize: t.Optional[bool] = None,
 74    ) -> exp.DataType:
 75        """
 76        Get the `sqlglot.exp.DataType` type of a column in the schema.
 77
 78        Args:
 79            table: the source table.
 80            column: the target column.
 81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 82            normalize: whether to normalize identifiers according to the dialect of interest.
 83
 84        Returns:
 85            The resulting column type.
 86        """
 87
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)
109
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
114        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
115        """
116
117    @property
118    def empty(self) -> bool:
119        """Returns whether the schema is empty."""
120        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
25    @abc.abstractmethod
26    def add_table(
27        self,
28        table: exp.Table | str,
29        column_mapping: t.Optional[ColumnMapping] = None,
30        dialect: DialectType = None,
31        normalize: t.Optional[bool] = None,
32        match_depth: bool = True,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
37
38        Args:
39            table: the `Table` expression instance or string representing the table.
40            column_mapping: a column mapping that describes the structure of the table.
41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
42            normalize: whether to normalize identifiers according to the dialect of interest.
43            match_depth: whether to enforce that the table must match the schema's depth or not.
44        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
46    @abc.abstractmethod
47    def column_names(
48        self,
49        table: exp.Table | str,
50        only_visible: bool = False,
51        dialect: DialectType = None,
52        normalize: t.Optional[bool] = None,
53    ) -> t.Sequence[str]:
54        """
55        Get the column names for a table.
56
57        Args:
58            table: the `Table` expression instance.
59            only_visible: whether to include invisible columns.
60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
61            normalize: whether to normalize identifiers according to the dialect of interest.
62
63        Returns:
64            The sequence of column names.
65        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
67    @abc.abstractmethod
68    def get_column_type(
69        self,
70        table: exp.Table | str,
71        column: exp.Column | str,
72        dialect: DialectType = None,
73        normalize: t.Optional[bool] = None,
74    ) -> exp.DataType:
75        """
76        Get the `sqlglot.exp.DataType` type of a column in the schema.
77
78        Args:
79            table: the source table.
80            column: the target column.
81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
82            normalize: whether to normalize identifiers according to the dialect of interest.
83
84        Returns:
85            The resulting column type.
86        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
114        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
115        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
117    @property
118    def empty(self) -> bool:
119        """Returns whether the schema is empty."""
120        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """Maps table paths to schemas using a nested mapping backed by a trie.

    Args:
        mapping: a nested dict of the form ``{catalog?: {db?: {table: ...}}}``.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}

        # The trie is keyed by *reversed* table parts, i.e. (table, db, catalog),
        # so lookups can match progressively more qualified table names.
        # The loop variable is named `keys` (not `t`) so it doesn't shadow the
        # module-level `typing` alias.
        self.mapping_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping contains no tables."""
        return not self.mapping

    def depth(self) -> int:
        """The nesting depth of the mapping, e.g. 2 for ``{db: {table: ...}}``."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table arguments supported by this schema, e.g. `("this", "db", "catalog")`."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty parts of `table`'s path, ordered (this, db, catalog)."""
        if isinstance(table.this, exp.ReadCSV):
            # READ_CSV(...) pseudo-tables are keyed by the CSV file name alone.
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The given path is only a prefix of the stored paths; it's usable
            # only if it identifies exactly one table.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Look up `parts` (ordered like `table_parts`) in `d`, defaulting to `self.mapping`."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
124    def __init__(
125        self,
126        mapping: t.Optional[t.Dict] = None,
127    ) -> None:
128        self.mapping = mapping or {}
129        self.mapping_trie = new_trie(
130            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
131        )
132        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
134    @property
135    def empty(self) -> bool:
136        return not self.mapping
def depth(self) -> int:
138    def depth(self) -> int:
139        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
141    @property
142    def supported_table_args(self) -> t.Tuple[str, ...]:
143        if not self._supported_table_args and self.mapping:
144            depth = self.depth()
145
146            if not depth:  # None
147                self._supported_table_args = tuple()
148            elif 1 <= depth <= 3:
149                self._supported_table_args = exp.TABLE_PARTS[:depth]
150            else:
151                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
152
153        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
155    def table_parts(self, table: exp.Table) -> t.List[str]:
156        if isinstance(table.this, exp.ReadCSV):
157            return [table.this.name]
158        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True) -> Optional[Any]:
160    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
161        """
162        Returns the schema of a given table.
163
164        Args:
165            table: the target table.
166            raise_on_missing: whether to raise in case the schema is not found.
167
168        Returns:
169            The schema of the target table.
170        """
171        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
172        value, trie = in_trie(self.mapping_trie, parts)
173
174        if value == TrieResult.FAILED:
175            return None
176
177        if value == TrieResult.PREFIX:
178            possibilities = flatten_schema(trie)
179
180            if len(possibilities) == 1:
181                parts.extend(possibilities[0])
182            else:
183                message = ", ".join(".".join(parts) for parts in possibilities)
184                if raise_on_missing:
185                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
186                return None
187
188        return self.nested_get(parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
Returns:

The schema of the target table.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
190    def nested_get(
191        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
192    ) -> t.Optional[t.Any]:
193        return nested_get(
194            d or self.mapping,
195            *zip(self.supported_table_args, reversed(parts)),
196            raise_on_missing=raise_on_missing,
197        )
class MappingSchema(AbstractMappingSchema, Schema):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of type string -> parsed exp.DataType, to avoid re-parsing types.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed nesting depth of the mapping, excluding the column level.
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` sharing `mapping_schema`'s mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # The table is already registered and no new columns were given: nothing to do.
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Register the new path in the existing trie in place.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally restricted to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return `column`'s type in `table`'s schema, or an "unknown" type if not found."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return whether `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                # A nested dict at the column level means the table path is deeper
                # than the rest of the schema.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each part of its path."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its text."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """The nesting depth of the mapping, excluding the column level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError as e:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                # Chain the original error so the underlying parse failure is preserved.
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from e

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
219    def __init__(
220        self,
221        schema: t.Optional[t.Dict] = None,
222        visible: t.Optional[t.Dict] = None,
223        dialect: DialectType = None,
224        normalize: bool = True,
225    ) -> None:
226        self.dialect = dialect
227        self.visible = {} if visible is None else visible
228        self.normalize = normalize
229        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
230        self._depth = 0
231        schema = {} if schema is None else schema
232
233        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
235    @classmethod
236    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
237        return MappingSchema(
238            schema=mapping_schema.mapping,
239            visible=mapping_schema.visible,
240            dialect=mapping_schema.dialect,
241            normalize=mapping_schema.normalize,
242        )
def copy(self, **kwargs) -> MappingSchema:
244    def copy(self, **kwargs) -> MappingSchema:
245        return MappingSchema(
246            **{  # type: ignore
247                "schema": self.mapping.copy(),
248                "visible": self.visible.copy(),
249                "dialect": self.dialect,
250                "normalize": self.normalize,
251                **kwargs,
252            }
253        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
255    def add_table(
256        self,
257        table: exp.Table | str,
258        column_mapping: t.Optional[ColumnMapping] = None,
259        dialect: DialectType = None,
260        normalize: t.Optional[bool] = None,
261        match_depth: bool = True,
262    ) -> None:
263        """
264        Register or update a table. Updates are only performed if a new column mapping is provided.
265        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
266
267        Args:
268            table: the `Table` expression instance or string representing the table.
269            column_mapping: a column mapping that describes the structure of the table.
270            dialect: the SQL dialect that will be used to parse `table` if it's a string.
271            normalize: whether to normalize identifiers according to the dialect of interest.
272            match_depth: whether to enforce that the table must match the schema's depth or not.
273        """
274        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
275
276        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
277            raise SchemaError(
278                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
279                f"schema's nesting level: {self.depth()}."
280            )
281
282        normalized_column_mapping = {
283            self._normalize_name(key, dialect=dialect, normalize=normalize): value
284            for key, value in ensure_column_mapping(column_mapping).items()
285        }
286
287        schema = self.find(normalized_table, raise_on_missing=False)
288        if schema and not normalized_column_mapping:
289            return
290
291        parts = self.table_parts(normalized_table)
292
293        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
294        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
296    def column_names(
297        self,
298        table: exp.Table | str,
299        only_visible: bool = False,
300        dialect: DialectType = None,
301        normalize: t.Optional[bool] = None,
302    ) -> t.List[str]:
303        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
304
305        schema = self.find(normalized_table)
306        if schema is None:
307            return []
308
309        if not only_visible or not self.visible:
310            return list(schema)
311
312        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
313        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
315    def get_column_type(
316        self,
317        table: exp.Table | str,
318        column: exp.Column | str,
319        dialect: DialectType = None,
320        normalize: t.Optional[bool] = None,
321    ) -> exp.DataType:
322        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
323
324        normalized_column_name = self._normalize_name(
325            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
326        )
327
328        table_schema = self.find(normalized_table, raise_on_missing=False)
329        if table_schema:
330            column_type = table_schema.get(normalized_column_name)
331
332            if isinstance(column_type, exp.DataType):
333                return column_type
334            elif isinstance(column_type, str):
335                return self._to_data_type(column_type, dialect=dialect)
336
337        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
339    def has_column(
340        self,
341        table: exp.Table | str,
342        column: exp.Column | str,
343        dialect: DialectType = None,
344        normalize: t.Optional[bool] = None,
345    ) -> bool:
346        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
347
348        normalized_column_name = self._normalize_name(
349            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
350        )
351
352        table_schema = self.find(normalized_table, raise_on_missing=False)
353        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
427    def depth(self) -> int:
428        if not self.empty and not self._depth:
429            # The columns themselves are a mapping, but we don't want to include those
430            self._depth = super().depth() - 1
431        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize its casing per `dialect`."""
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return parsed

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    parsed.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(parsed)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` unchanged if it's already a `Schema`, else wrap it in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported column-mapping representation into a {column: type} dict."""
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # Parse "col1: type1, col2: type2" into {"col1": "type1", "col2": "type2"}.
        columns: t.Dict = {}
        for entry in mapping.split(","):
            fields = entry.strip().split(":")
            columns[fields[0].strip()] = fields[1].strip()
        return columns
    # Duck-typed check for a DataFrame StructType.
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}
    if isinstance(mapping, list):
        return {entry.strip(): None for entry in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: Optional[int] = None, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Return every table path in `schema` as a list of keys, descending `depth` levels."""
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if remaining == 1 or not isinstance(value, dict):
            # Leaf level reached (or the value can't be descended into).
            paths.append(prefix + [key])
        elif remaining >= 2:
            paths.extend(flatten_schema(value, remaining - 1, prefix + [key]))

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d

    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        # Nothing to set without a path.
        return d

    # Walk (creating intermediate dicts as needed) down to the parent of the leaf.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.