sqlglot.schema
"""Schema abstractions used to describe tables, databases and catalogs."""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be given as a {name: type} dict, a "name: type, ..."
    # string, a Spark-style StructType, or a list of column names (no types).
    # See `ensure_column_mapping` for how each form is coerced into a dict.
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    # Dialect used to parse / normalize string arguments handed to the schema.
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to restrict the output to only the columns marked as visible.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        # Default implementation: an abstract schema holds no tables.
        return True


class AbstractMappingSchema:
    """Stores a schema as a nested mapping plus a trie for table lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed innermost-first (table, db, catalog), mirroring the
        # order produced by `table_parts`, so partially-qualified lookups work.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily and cached; recomputed only while unset and non-empty.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # depth 0 means there is nothing in the mapping
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # A ReadCSV source is identified only by the file name it reads.
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        # Innermost-first order: table name, then db, then catalog.
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether or not to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is under-qualified; it resolves only if exactly one
            # registered table completes the given prefix.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        # `parts` is innermost-first, while the mapping nests outermost-first
        # (catalog -> db -> table), hence the reversal.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, see `_to_data_type`.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached nesting depth; 0 means "not computed yet" (see `depth`).
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Alternate constructor: rebuild a MappingSchema from another instance.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow-copies the top-level mappings; kwargs override any field.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # No-op when the table is already known and no new columns were given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping nests outermost-first, the trie innermost-first.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter down to the columns recorded as visible for this table.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # String types are parsed lazily and cached.
                return self._to_data_type(column_type, dialect=dialect)

        # Unknown table or column: fall back to the UNKNOWN type.
        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 flattens down to table level, leaving the column dicts intact.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys): each key doubles as the "name" used in error messages.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # copy=normalize: only copy when we may mutate, to avoid clobbering the
        # caller's expression with normalized identifiers.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # Surface parse failures as a schema-level error.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if needed and normalize it per the given dialect."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` unchanged if it's already a Schema, else wrap it in a MappingSchema."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported `ColumnMapping` form into a {column: type} dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # e.g. "a: INT, b: TEXT" -> {"a": "INT", "b": "TEXT"}
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        # Bare column names: types unknown.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested mapping into a list of key paths, `depth` levels deep."""
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                # "this" is an internal arg name; report it as "table" to users.
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    # Walk down the path, creating intermediate dicts as needed.
    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
20class Schema(abc.ABC): 21 """Abstract base class for database schemas""" 22 23 dialect: DialectType 24 25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """ 45 46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """ 66 67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 
81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """ 87 88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize) 109 110 @property 111 @abc.abstractmethod 112 def supported_table_args(self) -> t.Tuple[str, ...]: 113 """ 114 Table arguments this schema support, e.g. `("this", "db", "catalog")` 115 """ 116 117 @property 118 def empty(self) -> bool: 119 """Returns whether or not the schema is empty.""" 120 return True
Abstract base class for database schemas
25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the output to only the columns marked as visible.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """
Get the `sqlglot.exp.DataType` type of a column in the schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
123class AbstractMappingSchema: 124 def __init__( 125 self, 126 mapping: t.Optional[t.Dict] = None, 127 ) -> None: 128 self.mapping = mapping or {} 129 self.mapping_trie = new_trie( 130 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 131 ) 132 self._supported_table_args: t.Tuple[str, ...] = tuple() 133 134 @property 135 def empty(self) -> bool: 136 return not self.mapping 137 138 def depth(self) -> int: 139 return dict_depth(self.mapping) 140 141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args 154 155 def table_parts(self, table: exp.Table) -> t.List[str]: 156 if isinstance(table.this, exp.ReadCSV): 157 return [table.this.name] 158 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 159 160 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 161 """ 162 Returns the schema of a given table. 163 164 Args: 165 table: the target table. 166 raise_on_missing: whether or not to raise in case the schema is not found. 167 168 Returns: 169 The schema of the target table. 
170 """ 171 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 172 value, trie = in_trie(self.mapping_trie, parts) 173 174 if value == TrieResult.FAILED: 175 return None 176 177 if value == TrieResult.PREFIX: 178 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 179 180 if len(possibilities) == 1: 181 parts.extend(possibilities[0]) 182 else: 183 message = ", ".join(".".join(parts) for parts in possibilities) 184 if raise_on_missing: 185 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 186 return None 187 188 return self.nested_get(parts, raise_on_missing=raise_on_missing) 189 190 def nested_get( 191 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 192 ) -> t.Optional[t.Any]: 193 return nested_get( 194 d or self.mapping, 195 *zip(self.supported_table_args, reversed(parts)), 196 raise_on_missing=raise_on_missing, 197 )
141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args
160 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 161 """ 162 Returns the schema of a given table. 163 164 Args: 165 table: the target table. 166 raise_on_missing: whether or not to raise in case the schema is not found. 167 168 Returns: 169 The schema of the target table. 170 """ 171 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 172 value, trie = in_trie(self.mapping_trie, parts) 173 174 if value == TrieResult.FAILED: 175 return None 176 177 if value == TrieResult.PREFIX: 178 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 179 180 if len(possibilities) == 1: 181 parts.extend(possibilities[0]) 182 else: 183 message = ", ".join(".".join(parts) for parts in possibilities) 184 if raise_on_missing: 185 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 186 return None 187 188 return self.nested_get(parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether or not to raise in case the schema is not found.
Returns:
The schema of the target table.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType; populated by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not computed yet".
        self._depth = 0
        schema = {} if schema is None else schema

        # Normalize identifiers eagerly so later lookups can assume normalized keys.
        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Alternate constructor: rebuild a schema from an existing instance
        # (shares the underlying mapping/visible dictionaries).
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow-copies the mappings; kwargs override any constructor argument.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is True and the table's qualifier count
                does not equal the schema's nesting depth.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A table that is already known and carries no new columns is a no-op.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # NOTE(review): parts appear to be ordered most- to least-specific, hence the
        # reversal before nesting catalog -> db -> table — confirm against
        # AbstractMappingSchema.table_parts.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table; returns an empty list if the table is not found.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: if True, restrict the result to columns marked visible.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        # With no visibility info (or when not requested), every column is returned.
        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Returns the "unknown" DataType when the table or column cannot be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw string: parse (and cache) it into a DataType.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` appears in `table`'s schema, False otherwise."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if the schema's nesting is inconsistent (a leaf that is not
                a column mapping at the expected depth).
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 because the innermost level is the column mapping itself.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys) builds the (name, key) pairs nested_get expects.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we are about to mutate the parsed table's identifiers.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its plain-string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Return the schema's nesting depth (qualifiers per table), cached after first use."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a DataType.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # Build failures surface as AttributeError here; re-raise as a SchemaError.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
219 def __init__( 220 self, 221 schema: t.Optional[t.Dict] = None, 222 visible: t.Optional[t.Dict] = None, 223 dialect: DialectType = None, 224 normalize: bool = True, 225 ) -> None: 226 self.dialect = dialect 227 self.visible = {} if visible is None else visible 228 self.normalize = normalize 229 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 230 self._depth = 0 231 schema = {} if schema is None else schema 232 233 super().__init__(self._normalize(schema) if self.normalize else schema)
255 def add_table( 256 self, 257 table: exp.Table | str, 258 column_mapping: t.Optional[ColumnMapping] = None, 259 dialect: DialectType = None, 260 normalize: t.Optional[bool] = None, 261 match_depth: bool = True, 262 ) -> None: 263 """ 264 Register or update a table. Updates are only performed if a new column mapping is provided. 265 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 266 267 Args: 268 table: the `Table` expression instance or string representing the table. 269 column_mapping: a column mapping that describes the structure of the table. 270 dialect: the SQL dialect that will be used to parse `table` if it's a string. 271 normalize: whether to normalize identifiers according to the dialect of interest. 272 match_depth: whether to enforce that the table must match the schema's depth or not. 273 """ 274 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 275 276 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 277 raise SchemaError( 278 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 279 f"schema's nesting level: {self.depth()}." 280 ) 281 282 normalized_column_mapping = { 283 self._normalize_name(key, dialect=dialect, normalize=normalize): value 284 for key, value in ensure_column_mapping(column_mapping).items() 285 } 286 287 schema = self.find(normalized_table, raise_on_missing=False) 288 if schema and not normalized_column_mapping: 289 return 290 291 parts = self.table_parts(normalized_table) 292 293 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 294 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
296 def column_names( 297 self, 298 table: exp.Table | str, 299 only_visible: bool = False, 300 dialect: DialectType = None, 301 normalize: t.Optional[bool] = None, 302 ) -> t.List[str]: 303 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 304 305 schema = self.find(normalized_table) 306 if schema is None: 307 return [] 308 309 if not only_visible or not self.visible: 310 return list(schema) 311 312 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 313 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
315 def get_column_type( 316 self, 317 table: exp.Table | str, 318 column: exp.Column | str, 319 dialect: DialectType = None, 320 normalize: t.Optional[bool] = None, 321 ) -> exp.DataType: 322 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 323 324 normalized_column_name = self._normalize_name( 325 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 326 ) 327 328 table_schema = self.find(normalized_table, raise_on_missing=False) 329 if table_schema: 330 column_type = table_schema.get(normalized_column_name) 331 332 if isinstance(column_type, exp.DataType): 333 return column_type 334 elif isinstance(column_type, str): 335 return self._to_data_type(column_type, dialect=dialect) 336 337 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
339 def has_column( 340 self, 341 table: exp.Table | str, 342 column: exp.Column | str, 343 dialect: DialectType = None, 344 normalize: t.Optional[bool] = None, 345 ) -> bool: 346 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 347 348 normalized_column_name = self._normalize_name( 349 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 350 ) 351 352 table_schema = self.find(normalized_table, raise_on_missing=False) 353 return normalized_column_name in table_schema if table_schema else False
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
453def normalize_name( 454 identifier: str | exp.Identifier, 455 dialect: DialectType = None, 456 is_table: bool = False, 457 normalize: t.Optional[bool] = True, 458) -> exp.Identifier: 459 if isinstance(identifier, str): 460 identifier = exp.parse_identifier(identifier, dialect=dialect) 461 462 if not normalize: 463 return identifier 464 465 # this is used for normalize_identifier, bigquery has special rules pertaining tables 466 identifier.meta["is_table"] = is_table 467 return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
477def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 478 if mapping is None: 479 return {} 480 elif isinstance(mapping, dict): 481 return mapping 482 elif isinstance(mapping, str): 483 col_name_type_strs = [x.strip() for x in mapping.split(",")] 484 return { 485 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 486 for name_type_str in col_name_type_strs 487 } 488 # Check if mapping looks like a DataFrame StructType 489 elif hasattr(mapping, "simpleString"): 490 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 491 elif isinstance(mapping, list): 492 return {x.strip(): None for x in mapping} 493 494 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
497def flatten_schema( 498 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 499) -> t.List[t.List[str]]: 500 tables = [] 501 keys = keys or [] 502 503 for k, v in schema.items(): 504 if depth >= 2: 505 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 506 elif depth == 1: 507 tables.append(keys + [k]) 508 509 return tables
512def nested_get( 513 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 514) -> t.Optional[t.Any]: 515 """ 516 Get a value for a nested dictionary. 517 518 Args: 519 d: the dictionary to search. 520 *path: tuples of (name, key), where: 521 `key` is the key in the dictionary to get. 522 `name` is a string to use in the error if `key` isn't found. 523 524 Returns: 525 The value or None if it doesn't exist. 526 """ 527 for name, key in path: 528 d = d.get(key) # type: ignore 529 if d is None: 530 if raise_on_missing: 531 name = "table" if name == "this" else name 532 raise ValueError(f"Unknown {name}: {key}") 533 return None 534 535 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
538def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 539 """ 540 In-place set a value for a nested dictionary 541 542 Example: 543 >>> nested_set({}, ["top_key", "second_key"], "value") 544 {'top_key': {'second_key': 'value'}} 545 546 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 547 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 548 549 Args: 550 d: dictionary to update. 551 keys: the keys that makeup the path to `value`. 552 value: the value to set in the dictionary for the given key path. 553 554 Returns: 555 The (possibly) updated dictionary. 556 """ 557 if not keys: 558 return d 559 560 if len(keys) == 1: 561 d[keys[0]] = value 562 return d 563 564 subd = d 565 for key in keys[:-1]: 566 if key not in subd: 567 subd = subd.setdefault(key, {}) 568 else: 569 subd = subd[key] 570 571 subd[keys[-1]] = value 572 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.