sqlglot.schema
from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a {name: type} dict, a "name: type, ..." string,
    # a DataFrame StructType, or a plain list of column names (types unknown).
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema:
    """Base implementation of a schema backed by a nested dict mapping.

    The mapping's keys are table-path parts (catalog/db/table); a trie built
    over the *reversed* parts supports partial-qualification lookups.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are reversed so lookups can start from the least-qualified
        # part (the table name) and extend towards catalog/db.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily computed cache for the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the underlying mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the mapping (number of qualifier levels)."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table expression args supported by this schema, derived from its depth."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty name parts (this/db/catalog) of `table`."""
        if isinstance(table.this, exp.ReadCSV):
            # A READ_CSV(...) "table" is identified solely by its file name.
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Looks up `table` in the schema, resolving partial qualification via the trie.

        Returns the table's column mapping, or None if it can't be (unambiguously) found
        and `raise_on_missing` is False.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The given parts only partially qualify a table; accept the match
            # only if exactly one schema entry completes them.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Fetches the value stored under `parts` (table-path order) in `d` or the mapping."""
        return nested_get(
            d or self.mapping,
            # The mapping nests catalog -> db -> table, i.e. reversed part order.
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, since parsing is costly.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed cache for `depth()`; 0 means "not computed yet".
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing MappingSchema."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a shallow copy of this schema, optionally overriding constructor args."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Don't overwrite an existing entry with an empty mapping.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Keep the lookup trie in sync with the mutated mapping.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """See `Schema.column_names`."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """See `Schema.get_column_type`. Returns the "unknown" type if unresolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """See `Schema.has_column`."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` (if needed) and normalizes its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Only copy when we're going to mutate the parsed expression below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Returns the normalized text of a single identifier."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Nesting depth excluding the column level (cached after first computation)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parses `identifier` (if needed) and normalizes it per the dialect's casing rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerces `schema` into a `Schema` instance, wrapping dicts in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerces any supported `ColumnMapping` form into a {column: type} dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # e.g. "a: INT, b: TEXT" -> {"a": "INT", "b": "TEXT"}
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        # A bare list of column names carries no type information.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flattens a nested schema into a list of key paths, each `depth` keys long."""
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
20class Schema(abc.ABC): 21 """Abstract base class for database schemas""" 22 23 dialect: DialectType 24 25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """ 45 46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """ 66 67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 
81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """ 87 88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize) 109 110 @property 111 @abc.abstractmethod 112 def supported_table_args(self) -> t.Tuple[str, ...]: 113 """ 114 Table arguments this schema support, e.g. `("this", "db", "catalog")` 115 """ 116 117 @property 118 def empty(self) -> bool: 119 """Returns whether or not the schema is empty.""" 120 return True
Abstract base class for database schemas
25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
123class AbstractMappingSchema: 124 def __init__( 125 self, 126 mapping: t.Optional[t.Dict] = None, 127 ) -> None: 128 self.mapping = mapping or {} 129 self.mapping_trie = new_trie( 130 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 131 ) 132 self._supported_table_args: t.Tuple[str, ...] = tuple() 133 134 @property 135 def empty(self) -> bool: 136 return not self.mapping 137 138 def depth(self) -> int: 139 return dict_depth(self.mapping) 140 141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args 154 155 def table_parts(self, table: exp.Table) -> t.List[str]: 156 if isinstance(table.this, exp.ReadCSV): 157 return [table.this.name] 158 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 159 160 def find( 161 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 162 ) -> t.Optional[t.Any]: 163 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 164 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 165 166 if value == TrieResult.FAILED: 167 return None 168 169 if value == TrieResult.PREFIX: 170 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 171 172 if len(possibilities) == 1: 173 parts.extend(possibilities[0]) 174 else: 175 message = ", ".join(".".join(parts) for parts in possibilities) 176 if raise_on_missing: 177 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 178 return None 179 180 return self.nested_get(parts, raise_on_missing=raise_on_missing) 181 182 def nested_get( 183 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 184 ) -> t.Optional[t.Any]: 185 return nested_get( 186 d or self.mapping, 187 *zip(self.supported_table_args, reversed(parts)), 188 raise_on_missing=raise_on_missing, 189 )
141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args
160 def find( 161 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 162 ) -> t.Optional[t.Any]: 163 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 164 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 165 166 if value == TrieResult.FAILED: 167 return None 168 169 if value == TrieResult.PREFIX: 170 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 171 172 if len(possibilities) == 1: 173 parts.extend(possibilities[0]) 174 else: 175 message = ", ".join(".".join(parts) for parts in possibilities) 176 if raise_on_missing: 177 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 178 return None 179 180 return self.nested_get(parts, raise_on_missing=raise_on_missing)
192class MappingSchema(AbstractMappingSchema, Schema): 193 """ 194 Schema based on a nested mapping. 195 196 Args: 197 schema: Mapping in one of the following forms: 198 1. {table: {col: type}} 199 2. {db: {table: {col: type}}} 200 3. {catalog: {db: {table: {col: type}}}} 201 4. None - Tables will be added later 202 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 203 are assumed to be visible. The nesting should mirror that of the schema: 204 1. {table: set(*cols)}} 205 2. {db: {table: set(*cols)}}} 206 3. {catalog: {db: {table: set(*cols)}}}} 207 dialect: The dialect to be used for custom type mappings & parsing string arguments. 208 normalize: Whether to normalize identifier names according to the given dialect or not. 209 """ 210 211 def __init__( 212 self, 213 schema: t.Optional[t.Dict] = None, 214 visible: t.Optional[t.Dict] = None, 215 dialect: DialectType = None, 216 normalize: bool = True, 217 ) -> None: 218 self.dialect = dialect 219 self.visible = {} if visible is None else visible 220 self.normalize = normalize 221 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 222 self._depth = 0 223 schema = {} if schema is None else schema 224 225 super().__init__(self._normalize(schema) if self.normalize else schema) 226 227 @classmethod 228 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 229 return MappingSchema( 230 schema=mapping_schema.mapping, 231 visible=mapping_schema.visible, 232 dialect=mapping_schema.dialect, 233 normalize=mapping_schema.normalize, 234 ) 235 236 def copy(self, **kwargs) -> MappingSchema: 237 return MappingSchema( 238 **{ # type: ignore 239 "schema": self.mapping.copy(), 240 "visible": self.visible.copy(), 241 "dialect": self.dialect, 242 "normalize": self.normalize, 243 **kwargs, 244 } 245 ) 246 247 def add_table( 248 self, 249 table: exp.Table | str, 250 column_mapping: t.Optional[ColumnMapping] = None, 251 dialect: DialectType = None, 252 normalize: 
t.Optional[bool] = None, 253 match_depth: bool = True, 254 ) -> None: 255 """ 256 Register or update a table. Updates are only performed if a new column mapping is provided. 257 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 258 259 Args: 260 table: the `Table` expression instance or string representing the table. 261 column_mapping: a column mapping that describes the structure of the table. 262 dialect: the SQL dialect that will be used to parse `table` if it's a string. 263 normalize: whether to normalize identifiers according to the dialect of interest. 264 match_depth: whether to enforce that the table must match the schema's depth or not. 265 """ 266 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 267 268 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 269 raise SchemaError( 270 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 271 f"schema's nesting level: {self.depth()}." 
272 ) 273 274 normalized_column_mapping = { 275 self._normalize_name(key, dialect=dialect, normalize=normalize): value 276 for key, value in ensure_column_mapping(column_mapping).items() 277 } 278 279 schema = self.find(normalized_table, raise_on_missing=False) 280 if schema and not normalized_column_mapping: 281 return 282 283 parts = self.table_parts(normalized_table) 284 285 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 286 new_trie([parts], self.mapping_trie) 287 288 def column_names( 289 self, 290 table: exp.Table | str, 291 only_visible: bool = False, 292 dialect: DialectType = None, 293 normalize: t.Optional[bool] = None, 294 ) -> t.List[str]: 295 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 296 297 schema = self.find(normalized_table) 298 if schema is None: 299 return [] 300 301 if not only_visible or not self.visible: 302 return list(schema) 303 304 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 305 return [col for col in schema if col in visible] 306 307 def get_column_type( 308 self, 309 table: exp.Table | str, 310 column: exp.Column | str, 311 dialect: DialectType = None, 312 normalize: t.Optional[bool] = None, 313 ) -> exp.DataType: 314 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 315 316 normalized_column_name = self._normalize_name( 317 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 318 ) 319 320 table_schema = self.find(normalized_table, raise_on_missing=False) 321 if table_schema: 322 column_type = table_schema.get(normalized_column_name) 323 324 if isinstance(column_type, exp.DataType): 325 return column_type 326 elif isinstance(column_type, str): 327 return self._to_data_type(column_type, dialect=dialect) 328 329 return exp.DataType.build("unknown") 330 331 def has_column( 332 self, 333 table: exp.Table | str, 334 column: exp.Column | str, 335 dialect: 
DialectType = None, 336 normalize: t.Optional[bool] = None, 337 ) -> bool: 338 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 339 340 normalized_column_name = self._normalize_name( 341 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 342 ) 343 344 table_schema = self.find(normalized_table, raise_on_missing=False) 345 return normalized_column_name in table_schema if table_schema else False 346 347 def _normalize(self, schema: t.Dict) -> t.Dict: 348 """ 349 Normalizes all identifiers in the schema. 350 351 Args: 352 schema: the schema to normalize. 353 354 Returns: 355 The normalized schema mapping. 356 """ 357 normalized_mapping: t.Dict = {} 358 flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) 359 360 for keys in flattened_schema: 361 columns = nested_get(schema, *zip(keys, keys)) 362 363 if not isinstance(columns, dict): 364 raise SchemaError( 365 f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}." 
366 ) 367 368 normalized_keys = [self._normalize_name(key, is_table=True) for key in keys] 369 for column_name, column_type in columns.items(): 370 nested_set( 371 normalized_mapping, 372 normalized_keys + [self._normalize_name(column_name)], 373 column_type, 374 ) 375 376 return normalized_mapping 377 378 def _normalize_table( 379 self, 380 table: exp.Table | str, 381 dialect: DialectType = None, 382 normalize: t.Optional[bool] = None, 383 ) -> exp.Table: 384 dialect = dialect or self.dialect 385 normalize = self.normalize if normalize is None else normalize 386 387 normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize) 388 389 if normalize: 390 for arg in exp.TABLE_PARTS: 391 value = normalized_table.args.get(arg) 392 if isinstance(value, exp.Identifier): 393 normalized_table.set( 394 arg, 395 normalize_name(value, dialect=dialect, is_table=True, normalize=normalize), 396 ) 397 398 return normalized_table 399 400 def _normalize_name( 401 self, 402 name: str | exp.Identifier, 403 dialect: DialectType = None, 404 is_table: bool = False, 405 normalize: t.Optional[bool] = None, 406 ) -> str: 407 return normalize_name( 408 name, 409 dialect=dialect or self.dialect, 410 is_table=is_table, 411 normalize=self.normalize if normalize is None else normalize, 412 ).name 413 414 def depth(self) -> int: 415 if not self.empty and not self._depth: 416 # The columns themselves are a mapping, but we don't want to include those 417 self._depth = super().depth() - 1 418 return self._depth 419 420 def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 421 """ 422 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 423 424 Args: 425 schema_type: the type we want to convert. 426 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 427 428 Returns: 429 The resulting expression type. 
430 """ 431 if schema_type not in self._type_mapping_cache: 432 dialect = dialect or self.dialect 433 udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES 434 435 try: 436 expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt) 437 self._type_mapping_cache[schema_type] = expression 438 except AttributeError: 439 in_dialect = f" in dialect {dialect}" if dialect else "" 440 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 441 442 return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(cols)}
- {db: {table: set(cols)}}
- {catalog: {db: {table: set(cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    """Initialize the mapping schema, normalizing identifiers when requested.

    Args:
        schema: nested mapping of tables to column types (see class docs), or None.
        visible: optional nested mapping of visible columns, mirroring `schema`.
        dialect: dialect used for custom type mappings and parsing string arguments.
        normalize: whether to normalize identifier names according to `dialect`.
    """
    self.dialect = dialect
    self.visible = {} if visible is None else visible
    self.normalize = normalize
    # Cache of raw type string -> parsed exp.DataType, filled lazily by _to_data_type.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    # Computed lazily by depth(); 0 means "not computed yet".
    self._depth = 0

    mapping = schema or {}
    super().__init__(self._normalize(mapping) if self.normalize else mapping)
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.
    The added table must have the necessary number of qualifiers in its path to match the
    schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    # A schema of depth N only accepts tables qualified with exactly N parts.
    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_column_mapping = {
        self._normalize_name(key, dialect=dialect, normalize=normalize): value
        for key, value in ensure_column_mapping(column_mapping).items()
    }

    # Nothing to do: the table is already registered and no new columns were supplied.
    if self.find(normalized_table, raise_on_missing=False) and not normalized_column_mapping:
        return

    parts = self.table_parts(normalized_table)
    nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """Return the column names registered for `table`, or an empty list if unknown.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to restrict the result using the visibility mapping.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    schema = self.find(normalized_table)
    if schema is None:
        return []

    # Without a visibility mapping (or when not requested), every column is returned.
    if not only_visible or not self.visible:
        return list(schema)

    visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
    return [name for name in schema if name in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """Resolve the `exp.DataType` of `column` in `table`.

    Falls back to the UNKNOWN data type when the table or column is not
    present in the schema.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
    normalized_column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    if table_schema:
        column_type = table_schema.get(normalized_column_name)

        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            # String types are parsed lazily (and cached) by _to_data_type.
            return self._to_data_type(column_type, dialect=dialect)

    return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """Return True if `column` appears in the registered schema of `table`.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
    normalized_column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    if not table_schema:
        return False
    return normalized_column_name in table_schema
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it is a string and normalize it per `dialect`.

    Args:
        identifier: the identifier (or raw string) to normalize.
        dialect: the dialect whose normalization rules should apply.
        is_table: whether the identifier names a table (some dialects treat
            table identifiers specially).
        normalize: when falsy, the parsed identifier is returned untouched.
    """
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # normalize_identifier has table-specific rules in some dialects (e.g. BigQuery),
    # so the identifier is tagged before dispatching to the dialect.
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported column-mapping representation into a plain dict.

    Accepts a dict (returned as-is), a ``"name: type, ..."`` string, a
    DataFrame StructType (duck-typed via its ``simpleString`` method, so
    pyspark need not be imported), a list of column names (types become None),
    or None (empty mapping).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        entries = (entry.strip() for entry in mapping.split(","))
        pairs = (entry.split(":") for entry in entries)
        return {parts[0].strip(): parts[1].strip() for parts in pairs}
    # Duck-typed check for a DataFrame StructType.
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}
    if isinstance(mapping, list):
        return {name.strip(): None for name in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Collect the key paths that identify tables in a nested schema mapping.

    Each returned path has `depth` components; the level below the table level
    (the column mapping itself) is never descended into. A depth of 0 or less
    yields no paths.

    Args:
        schema: the nested schema mapping to walk.
        depth: how many nesting levels identify a table.
        keys: prefix of path components accumulated by recursive calls.
    """
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, subtree in schema.items():
        if depth == 1:
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(subtree, depth - 1, prefix + [key]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises ValueError instead of
            returning None.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the AST arg name; report it as "table" for readability.
        name = "table" if name == "this" else name
        raise ValueError(f"Unknown {name}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk down to the parent of the final key, creating dicts as needed.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.