sqlglot.schema
"""Schema abstractions used to qualify and type SQL expressions.

A `Schema` describes tables and their columns/types. `MappingSchema` is the
concrete implementation backed by a (possibly nested) dict of the shape
``{catalog: {db: {table: {col: type}}}}`` (shallower nestings are allowed too).
"""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # Accepted shapes for describing a table's columns; see ensure_column_mapping.
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    # Dialect used to parse string arguments (tables, columns, types).
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema:
    """Base for schemas backed by a nested dict, with trie-based table lookup."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are table paths in reversed order (table, db, catalog), so
        # lookups can match a partially-qualified table from the inside out.
        # NOTE: the comprehension variable `t` shadows the `typing as t` alias
        # locally; it only lives inside this generator expression.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily computed cache for the supported_table_args property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return not self.mapping

    def depth(self) -> int:
        # Nesting depth of the mapping, e.g. {table: {col: type}} has depth 2.
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed on first access and cached; depends only on the mapping's depth.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty name parts of `table` (this/db/catalog order)."""
        if isinstance(table.this, exp.ReadCSV):
            # A READ_CSV(...) "table" is identified by the CSV file name alone.
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table, or None when it can't be resolved.

        Args:
            table: the target table.
            trie: a prefix trie to search instead of the schema's own trie.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table (a column mapping), or None.
        """
        # Only keep as many qualifiers as the schema's nesting level supports.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is under-qualified; it can only be resolved if the
            # remaining path through the trie is unambiguous.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                # NOTE: the comprehension variable `parts` shadows the outer
                # `parts` list here; each element of `possibilities` is itself
                # a list of name parts.
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Look `parts` up in `d` (default: the schema mapping), outermost key first."""
        return nested_get(
            d or self.mapping,
            # The mapping nests catalog -> db -> table, i.e. the reverse of
            # `parts`, which is in (this, db, catalog) order.
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Caches string -> DataType conversions performed by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached schema depth; 0 means "not computed yet" (see depth()).
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Don't overwrite an existing table entry with an empty column mapping.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Keep the nested mapping and the lookup trie in sync.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Get the column names for a table; see `Schema.column_names`."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Get a column's type, or the "unknown" type if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Types stored as strings are parsed (and cached) on demand.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether or not `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 so the column->type leaf dicts are not flattened away.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys) builds (name, key) pairs for nested_get; here the
            # key itself doubles as the name used in error messages.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we're going to mutate the expression below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its textual name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect` rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` to a `Schema` instance, wrapping dicts in a MappingSchema."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported `ColumnMapping` form into a {column: type} dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # e.g. "a: INT, b: TEXT" -> {"a": "INT", "b": "TEXT"}
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        # Column names only; their types are unknown.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten the first `depth` levels of `schema` into a list of key paths."""
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                # "this" is the internal arg name for a table's own identifier.
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    # Walk (creating intermediate dicts as needed) down to the parent of the leaf.
    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
20class Schema(abc.ABC): 21 """Abstract base class for database schemas""" 22 23 dialect: DialectType 24 25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """ 45 46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """ 66 67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 
81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """ 87 88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize) 109 110 @property 111 @abc.abstractmethod 112 def supported_table_args(self) -> t.Tuple[str, ...]: 113 """ 114 Table arguments this schema support, e.g. `("this", "db", "catalog")` 115 """ 116 117 @property 118 def empty(self) -> bool: 119 """Returns whether or not the schema is empty.""" 120 return True
Abstract base class for database schemas
25 @abc.abstractmethod 26 def add_table( 27 self, 28 table: exp.Table | str, 29 column_mapping: t.Optional[ColumnMapping] = None, 30 dialect: DialectType = None, 31 normalize: t.Optional[bool] = None, 32 match_depth: bool = True, 33 ) -> None: 34 """ 35 Register or update a table. Some implementing classes may require column information to also be provided. 36 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 normalize: whether to normalize identifiers according to the dialect of interest. 43 match_depth: whether to enforce that the table must match the schema's depth or not. 44 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 65 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column | str, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """
Get the `sqlglot.exp.DataType` type of a column in the schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
88 def has_column( 89 self, 90 table: exp.Table | str, 91 column: exp.Column | str, 92 dialect: DialectType = None, 93 normalize: t.Optional[bool] = None, 94 ) -> bool: 95 """ 96 Returns whether or not `column` appears in `table`'s schema. 97 98 Args: 99 table: the source table. 100 column: the target column. 101 dialect: the SQL dialect that will be used to parse `table` if it's a string. 102 normalize: whether to normalize identifiers according to the dialect of interest. 103 104 Returns: 105 True if the column appears in the schema, False otherwise. 106 """ 107 name = column if isinstance(column, str) else column.name 108 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
123class AbstractMappingSchema: 124 def __init__( 125 self, 126 mapping: t.Optional[t.Dict] = None, 127 ) -> None: 128 self.mapping = mapping or {} 129 self.mapping_trie = new_trie( 130 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 131 ) 132 self._supported_table_args: t.Tuple[str, ...] = tuple() 133 134 @property 135 def empty(self) -> bool: 136 return not self.mapping 137 138 def depth(self) -> int: 139 return dict_depth(self.mapping) 140 141 @property 142 def supported_table_args(self) -> t.Tuple[str, ...]: 143 if not self._supported_table_args and self.mapping: 144 depth = self.depth() 145 146 if not depth: # None 147 self._supported_table_args = tuple() 148 elif 1 <= depth <= 3: 149 self._supported_table_args = exp.TABLE_PARTS[:depth] 150 else: 151 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 152 153 return self._supported_table_args 154 155 def table_parts(self, table: exp.Table) -> t.List[str]: 156 if isinstance(table.this, exp.ReadCSV): 157 return [table.this.name] 158 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 159 160 def find( 161 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 162 ) -> t.Optional[t.Any]: 163 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 164 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 165 166 if value == TrieResult.FAILED: 167 return None 168 169 if value == TrieResult.PREFIX: 170 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 171 172 if len(possibilities) == 1: 173 parts.extend(possibilities[0]) 174 else: 175 message = ", ".join(".".join(parts) for parts in possibilities) 176 if raise_on_missing: 177 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 178 return None 179 180 return self.nested_get(parts, raise_on_missing=raise_on_missing) 181 182 def nested_get( 183 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 184 ) -> t.Optional[t.Any]: 185 return nested_get( 186 d or self.mapping, 187 *zip(self.supported_table_args, reversed(parts)), 188 raise_on_missing=raise_on_missing, 189 )
160 def find( 161 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 162 ) -> t.Optional[t.Any]: 163 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 164 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 165 166 if value == TrieResult.FAILED: 167 return None 168 169 if value == TrieResult.PREFIX: 170 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 171 172 if len(possibilities) == 1: 173 parts.extend(possibilities[0]) 174 else: 175 message = ", ".join(".".join(parts) for parts in possibilities) 176 if raise_on_missing: 177 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 178 return None 179 180 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all
            columns are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.normalize = normalize
        self.visible = visible or {}
        # Cache of string type -> parsed exp.DataType, shared across lookups.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0

        # Normalize only after dialect/normalize are set: _normalize uses both.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a fresh MappingSchema mirroring an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Copy this schema; `kwargs` override the constructor arguments."""
        init_args: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the
        schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping: t.Dict[str, t.Any] = {}
        for key, value in ensure_column_mapping(column_mapping).items():
            normalized_column_mapping[
                self._normalize_name(key, dialect=dialect, normalize=normalize)
            ] = value

        # Already registered and no new columns supplied: nothing to update.
        if self.find(normalized_table, raise_on_missing=False) and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names registered for `table`, optionally filtered by visibility."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            return [col for col in schema if col in visible]

        return list(schema)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Look up the type of `column` in `table`, defaulting to the "unknown" type."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(normalized_column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` is present in `table`'s registered schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return bool(table_schema and normalized_column_name in table_schema)

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys) produces the (name, key) pairs nested_get expects.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            table_path = [self._normalize_name(part, is_table=True) for part in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    table_path + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when the expression's identifiers are about to be mutated.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                part = normalized_table.args.get(arg)
                if isinstance(part, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to the schema's dialect/settings."""
        effective_dialect = dialect or self.dialect
        effective_normalize = self.normalize if normalize is None else normalize
        return normalize_name(
            name,
            dialect=effective_dialect,
            is_table=is_table,
            normalize=effective_normalize,
        ).name

    def depth(self) -> int:
        if self._depth == 0 and not self.empty:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                self._type_mapping_cache[schema_type] = exp.DataType.build(
                    schema_type, dialect=dialect, udt=udt
                )
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
211 def __init__( 212 self, 213 schema: t.Optional[t.Dict] = None, 214 visible: t.Optional[t.Dict] = None, 215 dialect: DialectType = None, 216 normalize: bool = True, 217 ) -> None: 218 self.dialect = dialect 219 self.visible = visible or {} 220 self.normalize = normalize 221 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 222 self._depth = 0 223 224 super().__init__(self._normalize(schema or {}))
246 def add_table( 247 self, 248 table: exp.Table | str, 249 column_mapping: t.Optional[ColumnMapping] = None, 250 dialect: DialectType = None, 251 normalize: t.Optional[bool] = None, 252 match_depth: bool = True, 253 ) -> None: 254 """ 255 Register or update a table. Updates are only performed if a new column mapping is provided. 256 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 257 258 Args: 259 table: the `Table` expression instance or string representing the table. 260 column_mapping: a column mapping that describes the structure of the table. 261 dialect: the SQL dialect that will be used to parse `table` if it's a string. 262 normalize: whether to normalize identifiers according to the dialect of interest. 263 match_depth: whether to enforce that the table must match the schema's depth or not. 264 """ 265 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 266 267 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 268 raise SchemaError( 269 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 270 f"schema's nesting level: {self.depth()}." 271 ) 272 273 normalized_column_mapping = { 274 self._normalize_name(key, dialect=dialect, normalize=normalize): value 275 for key, value in ensure_column_mapping(column_mapping).items() 276 } 277 278 schema = self.find(normalized_table, raise_on_missing=False) 279 if schema and not normalized_column_mapping: 280 return 281 282 parts = self.table_parts(normalized_table) 283 284 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 285 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
287 def column_names( 288 self, 289 table: exp.Table | str, 290 only_visible: bool = False, 291 dialect: DialectType = None, 292 normalize: t.Optional[bool] = None, 293 ) -> t.List[str]: 294 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 295 296 schema = self.find(normalized_table) 297 if schema is None: 298 return [] 299 300 if not only_visible or not self.visible: 301 return list(schema) 302 303 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 304 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
306 def get_column_type( 307 self, 308 table: exp.Table | str, 309 column: exp.Column | str, 310 dialect: DialectType = None, 311 normalize: t.Optional[bool] = None, 312 ) -> exp.DataType: 313 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 314 315 normalized_column_name = self._normalize_name( 316 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 317 ) 318 319 table_schema = self.find(normalized_table, raise_on_missing=False) 320 if table_schema: 321 column_type = table_schema.get(normalized_column_name) 322 323 if isinstance(column_type, exp.DataType): 324 return column_type 325 elif isinstance(column_type, str): 326 return self._to_data_type(column_type, dialect=dialect) 327 328 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
330 def has_column( 331 self, 332 table: exp.Table | str, 333 column: exp.Column | str, 334 dialect: DialectType = None, 335 normalize: t.Optional[bool] = None, 336 ) -> bool: 337 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 338 339 normalized_column_name = self._normalize_name( 340 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 341 ) 342 343 table_schema = self.find(normalized_table, raise_on_missing=False) 344 return normalized_column_name in table_schema if table_schema else False
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
444def normalize_name( 445 identifier: str | exp.Identifier, 446 dialect: DialectType = None, 447 is_table: bool = False, 448 normalize: t.Optional[bool] = True, 449) -> exp.Identifier: 450 if isinstance(identifier, str): 451 identifier = exp.parse_identifier(identifier, dialect=dialect) 452 453 if not normalize: 454 return identifier 455 456 # this is used for normalize_identifier, bigquery has special rules pertaining tables 457 identifier.meta["is_table"] = is_table 458 return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
468def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 469 if mapping is None: 470 return {} 471 elif isinstance(mapping, dict): 472 return mapping 473 elif isinstance(mapping, str): 474 col_name_type_strs = [x.strip() for x in mapping.split(",")] 475 return { 476 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 477 for name_type_str in col_name_type_strs 478 } 479 # Check if mapping looks like a DataFrame StructType 480 elif hasattr(mapping, "simpleString"): 481 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 482 elif isinstance(mapping, list): 483 return {x.strip(): None for x in mapping} 484 485 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
488def flatten_schema( 489 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 490) -> t.List[t.List[str]]: 491 tables = [] 492 keys = keys or [] 493 494 for k, v in schema.items(): 495 if depth >= 2: 496 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 497 elif depth == 1: 498 tables.append(keys + [k]) 499 500 return tables
503def nested_get( 504 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 505) -> t.Optional[t.Any]: 506 """ 507 Get a value for a nested dictionary. 508 509 Args: 510 d: the dictionary to search. 511 *path: tuples of (name, key), where: 512 `key` is the key in the dictionary to get. 513 `name` is a string to use in the error if `key` isn't found. 514 515 Returns: 516 The value or None if it doesn't exist. 517 """ 518 for name, key in path: 519 d = d.get(key) # type: ignore 520 if d is None: 521 if raise_on_missing: 522 name = "table" if name == "this" else name 523 raise ValueError(f"Unknown {name}: {key}") 524 return None 525 526 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get
  and `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
529def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 530 """ 531 In-place set a value for a nested dictionary 532 533 Example: 534 >>> nested_set({}, ["top_key", "second_key"], "value") 535 {'top_key': {'second_key': 'value'}} 536 537 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 538 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 539 540 Args: 541 d: dictionary to update. 542 keys: the keys that makeup the path to `value`. 543 value: the value to set in the dictionary for the given key path. 544 545 Returns: 546 The (possibly) updated dictionary. 547 """ 548 if not keys: 549 return d 550 551 if len(keys) == 1: 552 d[keys[0]] = value 553 return d 554 555 subd = d 556 for key in keys[:-1]: 557 if key not in subd: 558 subd = subd.setdefault(key, {}) 559 else: 560 subd = subd[key] 561 562 subd[keys[-1]] = value 563 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}

>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.