sqlglot.schema
from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a nested dict, a "col: type, ..." string, a
    # Spark-style StructType, or a list of column names (types unknown).
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    # Base for schemas backed by a nested dict; builds a trie over the
    # (reversed) table paths so partial qualifiers can still resolve a table.
    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily from the mapping's depth and cached; depth 1-3 maps
        # to ("this",), ("db", "this"), ("catalog", "db", "this") respectively.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty name parts (this/db/catalog) of `table`."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table was under-qualified; accept it only if the prefix
            # resolves unambiguously to a single table in the trie.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Looks up `parts` (ordered this, db, catalog) in `d` or the mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of type-string -> parsed exp.DataType, keyed on the raw string.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Shallow-copies this schema, with `kwargs` overriding constructor args."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Keep the existing entry when the table is already registered and no
        # new column information was supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or an "unknown" type if it can't be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # The leaf under a full table path must be a flat {column: type}
            # dict, i.e. neither shallower nor deeper than the schema's depth.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalizes a single identifier and returns its name as a string."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parses `identifier` if needed and normalizes it per `dialect` rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Returns `schema` as-is if it already is a Schema, else wraps it in a MappingSchema."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerces any supported `ColumnMapping` form into a {column: type} dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # "col_a: TYPE_A, col_b: TYPE_B" -> {"col_a": "TYPE_A", "col_b": "TYPE_B"}
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        # Only the column names are known; their types are left as None.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flattens a nested schema dict into a list of key paths, one per table."""
    tables = []
    keys = keys or []
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
20class Schema(abc.ABC): 21 """Abstract base class for database schemas""" 22 23 dialect: DialectType 24 25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """ 45 46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.Sequence[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The sequence of column names. 65 """ 66 67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 
81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """ 87 88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize) 109 110 @property 111 @abc.abstractmethod 112 def supported_table_args(self) -> t.Tuple[str, ...]: 113 """ 114 Table arguments this schema support, e.g. `("this", "db", "catalog")` 115 """ 116 117 @property 118 def empty(self) -> bool: 119 """Returns whether the schema is empty.""" 120 return True
Abstract base class for database schemas
25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.Sequence[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The sequence of column names. 65 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
123class AbstractMappingSchema: 124 def __init__( 125 self, 126 mapping: t.Optional[t.Dict] = None, 127 ) -> None: 128 self.mapping = mapping or {} 129 self.mapping_trie = new_trie( 130 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 131 ) 132 self._supported_table_args: t.Tuple[str, ...] = tuple() 133 134 @property 135 def empty(self) -> bool: 136 return not self.mapping 137 138 def depth(self) -> int: 139 return dict_depth(self.mapping) 140 141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args 154 155 def table_parts(self, table: exp.Table) -> t.List[str]: 156 if isinstance(table.this, exp.ReadCSV): 157 return [table.this.name] 158 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 159 160 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 161 """ 162 Returns the schema of a given table. 163 164 Args: 165 table: the target table. 166 raise_on_missing: whether to raise in case the schema is not found. 167 168 Returns: 169 The schema of the target table. 
170 """ 171 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 172 value, trie = in_trie(self.mapping_trie, parts) 173 174 if value == TrieResult.FAILED: 175 return None 176 177 if value == TrieResult.PREFIX: 178 possibilities = flatten_schema(trie) 179 180 if len(possibilities) == 1: 181 parts.extend(possibilities[0]) 182 else: 183 message = ", ".join(".".join(parts) for parts in possibilities) 184 if raise_on_missing: 185 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 186 return None 187 188 return self.nested_get(parts, raise_on_missing=raise_on_missing) 189 190 def nested_get( 191 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 192 ) -> t.Optional[t.Any]: 193 return nested_get( 194 d or self.mapping, 195 *zip(self.supported_table_args, reversed(parts)), 196 raise_on_missing=raise_on_missing, 197 )
141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args
160 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 161 """ 162 Returns the schema of a given table. 163 164 Args: 165 table: the target table. 166 raise_on_missing: whether to raise in case the schema is not found. 167 168 Returns: 169 The schema of the target table. 170 """ 171 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 172 value, trie = in_trie(self.mapping_trie, parts) 173 174 if value == TrieResult.FAILED: 175 return None 176 177 if value == TrieResult.PREFIX: 178 possibilities = flatten_schema(trie) 179 180 if len(possibilities) == 1: 181 parts.extend(possibilities[0]) 182 else: 183 message = ", ".join(".".join(parts) for parts in possibilities) 184 if raise_on_missing: 185 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 186 return None 187 188 return self.nested_get(parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
Returns:
The schema of the target table.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, filled lazily by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # 0 means "not computed yet"; the real depth is computed lazily in depth().
        self._depth = 0
        schema = {} if schema is None else schema

        # Identifiers are normalized once, up front, when normalization is enabled.
        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema that shares the given schema's mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a shallow copy of this schema; `kwargs` override the copied constructor args."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # No-op when the table already exists and no new column information was supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # NOTE(review): parts appear to be table-first, hence the reversal to get the
        # mapping's outermost-first (catalog -> db -> table) nesting — confirm via table_parts.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally restricted to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or the UNKNOWN type when not found."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Types stored as strings are parsed (and cached) on first access.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return whether `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # A leaf that isn't a dict means the table's path is too deep for this schema.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            # A dict-of-dicts leaf means the table's path is too shallow.
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when normalizing, so the caller's expression isn't mutated.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    """Initialize the mapping schema, normalizing identifiers up front when requested."""
    self.dialect = dialect
    self.normalize = normalize
    self.visible = visible if visible is not None else {}
    # Lazily-filled cache of string type -> parsed exp.DataType.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    # 0 means "not computed yet"; depth() fills this in on demand.
    self._depth = 0

    raw_schema = schema if schema is not None else {}
    super().__init__(self._normalize(raw_schema) if self.normalize else raw_schema)
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_column_mapping = {
        self._normalize_name(key, dialect=dialect, normalize=normalize): value
        for key, value in ensure_column_mapping(column_mapping).items()
    }

    # No-op when the table already exists and no new column information was supplied.
    schema = self.find(normalized_table, raise_on_missing=False)
    if schema and not normalized_column_mapping:
        return

    parts = self.table_parts(normalized_table)

    # NOTE(review): parts appear to be table-first, hence the reversal for the
    # outermost-first (catalog -> db -> table) mapping — confirm via table_parts.
    nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """Return the column names of `table`; unknown tables yield an empty list."""
    resolved = self._normalize_table(table, dialect=dialect, normalize=normalize)

    columns = self.find(resolved)
    if columns is None:
        return []

    if only_visible and self.visible:
        # Filter down to the columns registered as visible for this table.
        visible = self.nested_get(self.table_parts(resolved), self.visible) or []
        return [name for name in columns if name in visible]

    return list(columns)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """Return the type of `column` in `table`, or the UNKNOWN type when not found."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    columns = self.find(resolved_table, raise_on_missing=False)
    if columns:
        found = columns.get(column_name)

        if isinstance(found, exp.DataType):
            return found
        if isinstance(found, str):
            # String types are parsed (and cached) on first access.
            return self._to_data_type(found, dialect=dialect)

    return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """Return True if `column` is present in `table`'s schema."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    columns = self.find(resolved_table, raise_on_missing=False)
    return bool(columns) and column_name in columns
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` when given as a string and normalize it per `dialect`'s rules."""
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return parsed

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    parsed.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(parsed)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported column-mapping representation into a plain dict.

    Accepts None, a dict, a "col: type, ..." string, a PySpark-style StructType
    (anything exposing `simpleString`), or a list of column names (mapped to None).
    Raises ValueError for anything else.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # "col1: type1, col2: type2" -> {"col1": "type1", "col2": "type2"}
        entries = (entry.strip().split(":") for entry in mapping.split(","))
        return {parts[0].strip(): parts[1].strip() for parts in entries}
    # Duck-typed check for a DataFrame StructType.
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}
    if isinstance(mapping, list):
        return {entry.strip(): None for entry in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema mapping into the list of key paths that identify tables.

    `depth` is the number of nesting levels above the column mappings; it is derived
    from the schema itself when not given. `keys` is the path prefix accumulated
    during recursion.
    """
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth

    flattened = []
    for name, value in schema.items():
        # At the last level (or on a non-dict leaf), the path names a table.
        if remaining == 1 or not isinstance(value, dict):
            flattened.append(prefix + [name])
        elif remaining >= 2:
            flattened.extend(flatten_schema(value, remaining - 1, prefix + [name]))

    return flattened
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression-internal arg name; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating dicts as needed) down to the parent of the final key.
    target = d
    for key in keys[:-1]:
        target = target.setdefault(key, {})

    target[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that makeup the path to
value
. - value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.