"""
sqlglot.schema

Schema abstractions used to describe database structure (catalog/db/table/column
nesting), plus a concrete implementation backed by nested mappings.
"""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # The accepted shapes for describing a table's columns, see `ensure_column_mapping`.
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to restrict the returned names to visible columns only.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @abc.abstractmethod
    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether or not to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema:
    """Mixin providing trie-backed table lookup over a nested mapping."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Keys are stored outermost-first (catalog, db, table), but TABLE_PARTS
        # is innermost-first, so the flattened paths are reversed for the trie.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Lazily computed from the mapping's nesting depth and cached.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty qualifier parts of `table`, innermost first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Look up `table`'s schema, resolving partially-qualified names when the
        resolution is unambiguous.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is under-qualified; it can only be resolved if exactly
            # one completion exists in the trie.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a shallow copy, optionally overriding constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # The table is already known and no new columns were given: no-op.
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # String types are parsed lazily and cached by `_to_data_type`.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we'll mutate the expression (i.e. when normalizing).
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if needed and normalize it per `dialect`'s rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema` instance, wrapping dicts in `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported `ColumnMapping` shape into a {column: type} dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # Split each "name: type" entry on the *first* colon only, so data
        # types that themselves contain a colon are preserved intact.
        return {
            parts[0].strip(): parts[1].strip()
            for parts in (entry.split(":", 1) for entry in mapping.split(","))
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema into a list of key paths of length `depth`."""
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        # setdefault returns the existing sub-dict or inserts a fresh one,
        # covering both branches of the original if/else in one call.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
20class Schema(abc.ABC): 21 """Abstract base class for database schemas""" 22 23 dialect: DialectType 24 25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """ 45 46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """ 66 67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 
81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """ 87 88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize) 109 110 @abc.abstractmethod 111 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 112 """ 113 Returns the schema of a given table. 114 115 Args: 116 table: the target table. 117 raise_on_missing: whether or not to raise in case the schema is not found. 118 119 Returns: 120 The schema of the target table. 121 """ 122 123 @property 124 @abc.abstractmethod 125 def supported_table_args(self) -> t.Tuple[str, ...]: 126 """ 127 Table arguments this schema support, e.g. `("this", "db", "catalog")` 128 """ 129 130 @property 131 def empty(self) -> bool: 132 """Returns whether or not the schema is empty.""" 133 return True
Abstract base class for database schemas
25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the returned names to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """
Get the `sqlglot.exp.DataType` type of a column in the schema.

- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.

- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
110 @abc.abstractmethod 111 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 112 """ 113 Returns the schema of a given table. 114 115 Args: 116 table: the target table. 117 raise_on_missing: whether or not to raise in case the schema is not found. 118 119 Returns: 120 The schema of the target table. 121 """
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether or not to raise in case the schema is not found.
Returns:
The schema of the target table.
136class AbstractMappingSchema: 137 def __init__( 138 self, 139 mapping: t.Optional[t.Dict] = None, 140 ) -> None: 141 self.mapping = mapping or {} 142 self.mapping_trie = new_trie( 143 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 144 ) 145 self._supported_table_args: t.Tuple[str, ...] = tuple() 146 147 @property 148 def empty(self) -> bool: 149 return not self.mapping 150 151 def depth(self) -> int: 152 return dict_depth(self.mapping) 153 154 @property 155 def supported_table_args(self) -> t.Tuple[str, ...]: 156 if not self._supported_table_args and self.mapping: 157 depth = self.depth() 158 159 if not depth: # None 160 self._supported_table_args = tuple() 161 elif 1 <= depth <= 3: 162 self._supported_table_args = exp.TABLE_PARTS[:depth] 163 else: 164 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 165 166 return self._supported_table_args 167 168 def table_parts(self, table: exp.Table) -> t.List[str]: 169 if isinstance(table.this, exp.ReadCSV): 170 return [table.this.name] 171 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 172 173 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 174 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 175 value, trie = in_trie(self.mapping_trie, parts) 176 177 if value == TrieResult.FAILED: 178 return None 179 180 if value == TrieResult.PREFIX: 181 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 182 183 if len(possibilities) == 1: 184 parts.extend(possibilities[0]) 185 else: 186 message = ", ".join(".".join(parts) for parts in possibilities) 187 if raise_on_missing: 188 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 189 return None 190 191 return self.nested_get(parts, raise_on_missing=raise_on_missing) 192 193 def nested_get( 194 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 195 ) -> t.Optional[t.Any]: 196 return nested_get( 197 d or 
self.mapping, 198 *zip(self.supported_table_args, reversed(parts)), 199 raise_on_missing=raise_on_missing, 200 )
154 @property 155 def supported_table_args(self) -> t.Tuple[str, ...]: 156 if not self._supported_table_args and self.mapping: 157 depth = self.depth() 158 159 if not depth: # None 160 self._supported_table_args = tuple() 161 elif 1 <= depth <= 3: 162 self._supported_table_args = exp.TABLE_PARTS[:depth] 163 else: 164 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 165 166 return self._supported_table_args
173 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 174 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 175 value, trie = in_trie(self.mapping_trie, parts) 176 177 if value == TrieResult.FAILED: 178 return None 179 180 if value == TrieResult.PREFIX: 181 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 182 183 if len(possibilities) == 1: 184 parts.extend(possibilities[0]) 185 else: 186 message = ", ".join(".".join(parts) for parts in possibilities) 187 if raise_on_missing: 188 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 189 return None 190 191 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Caches parsed string type names -> exp.DataType (see _to_data_type)
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not computed yet"
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Alternate constructor: rebuild a MappingSchema from another instance's state
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow copy; kwargs may override any constructor argument
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Already registered and no new columns provided -> nothing to update
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Parts are reversed so the tuple matches the mapping's nesting order
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally filtered by visibility."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the `exp.DataType` of `column` in `table`, or the "unknown" type if absent."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a string -> parse (and cache) it into a DataType
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` appears in `table`'s schema, False otherwise."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys) builds the (name, key) pairs that nested_get expects
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                # NOTE(review): keys[:-1] drops the last path component from the
                # message — confirm this is intended and not meant to be keys
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize its identifier parts per the dialect."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when normalizing, since the parts are mutated in place below
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to this schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Return the mapping's nesting depth, excluding the column level (cached)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            # NOTE(review): AttributeError is what DataType.build surfaces on a
            # failed parse here — confirm this is the right exception to catch
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    """Build a mapping-backed schema, normalizing identifiers when requested."""
    self.dialect = dialect
    self.normalize = normalize
    self.visible = visible if visible is not None else {}
    # Cache of string type names -> parsed exp.DataType instances
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    # Computed lazily by depth(); 0 means "not yet computed"
    self._depth = 0

    raw_schema = schema if schema is not None else {}
    super().__init__(self._normalize(raw_schema) if self.normalize else raw_schema)
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register a table, or update it when a new column mapping is supplied.

    The table's qualifier count must match the schema's nesting level unless
    `match_depth` is disabled.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(resolved_table.parts) != self.depth():
        raise SchemaError(
            f"Table {resolved_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    columns = {}
    for key, value in ensure_column_mapping(column_mapping).items():
        columns[self._normalize_name(key, dialect=dialect, normalize=normalize)] = value

    # A known table with no new columns means there is nothing to change
    existing = self.find(resolved_table, raise_on_missing=False)
    if existing and not columns:
        return

    parts = self.table_parts(resolved_table)

    # Reversed so the key tuple matches the mapping's nesting order
    nested_set(self.mapping, tuple(reversed(parts)), columns)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """Return the column names of `table`, optionally restricted to visible columns."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    table_schema = self.find(resolved_table)
    if table_schema is None:
        return []

    names = list(table_schema)
    if only_visible and self.visible:
        allowed = self.nested_get(self.table_parts(resolved_table), self.visible) or []
        names = [name for name in names if name in allowed]

    return names
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """Return the `exp.DataType` of `column` in `table`; "unknown" when unresolvable."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    col_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(resolved_table, raise_on_missing=False)
    if not table_schema:
        return exp.DataType.build("unknown")

    col_type = table_schema.get(col_name)
    if isinstance(col_type, exp.DataType):
        return col_type
    if isinstance(col_type, str):
        # String types are parsed (and cached) into DataType instances
        return self._to_data_type(col_type, dialect=dialect)

    return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """Return True if `column` appears in `table`'s schema, False otherwise."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    col_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(resolved_table, raise_on_missing=False)
    if not table_schema:
        return False

    return col_name in table_schema
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect`'s casing rules."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # normalize_identifier consults this flag; BigQuery has special rules for tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce any supported column-mapping representation into a plain dict.

    Args:
        mapping: one of None, a {col: type} dict, a "col1: type1, col2: type2"
            string, a DataFrame StructType (detected via `simpleString`), or a
            list of column names.

    Returns:
        A dict of column name to type string (None for a bare column name).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in (part.strip() for part in mapping.split(",")):
            # Split on the FIRST colon only, so types that themselves contain
            # colons (e.g. "struct<a: int>") are preserved intact. The old
            # split(":")[1] truncated such types and raised IndexError when
            # no colon was present.
            name, sep, col_type = entry.partition(":")
            # No colon -> bare column name with unknown type (mirrors the list form)
            columns[name.strip()] = col_type.strip() if sep else None
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Return every key path of length `depth` through the nested `schema` dict."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, subtree in schema.items():
        if depth == 1:
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(subtree, depth - 1, prefix + [key]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along `path`.

    Args:
        d: the dictionary to search.
        *path: (name, key) tuples, where `key` is the lookup key at each level
            and `name` is the label used in the error message if `key` is missing.
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value at the end of the path, or None if it doesn't exist.
    """
    node = d
    for name, key in path:
        child = node.get(key)  # type: ignore
        if child is None:
            if raise_on_missing:
                # "this" is the internal arg name for a table's own identifier
                label = "table" if name == "this" else name
                raise ValueError(f"Unknown {label}: {key}")
            return None
        node = child

    return node
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    node = d
    for key in keys[:-1]:
        # setdefault both creates a missing intermediate dict and descends into it
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.