sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
20def qualify_columns( 21 expression: exp.Expression, 22 schema: t.Dict | Schema, 23 expand_alias_refs: bool = True, 24 expand_stars: bool = True, 25 infer_schema: t.Optional[bool] = None, 26 allow_partial_qualification: bool = False, 27) -> exp.Expression: 28 """ 29 Rewrite sqlglot AST to have fully qualified columns. 30 31 Example: 32 >>> import sqlglot 33 >>> schema = {"tbl": {"col": "INT"}} 34 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 35 >>> qualify_columns(expression, schema).sql() 36 'SELECT tbl.col AS col FROM tbl' 37 38 Args: 39 expression: Expression to qualify. 40 schema: Database schema. 41 expand_alias_refs: Whether to expand references to aliases. 42 expand_stars: Whether to expand star queries. This is a necessary step 43 for most of the optimizer's rules to work; do not set to False unless you 44 know what you're doing! 45 infer_schema: Whether to infer the schema if missing. 46 allow_partial_qualification: Whether to allow partial qualification. 47 48 Returns: 49 The qualified expression. 
50 51 Notes: 52 - Currently only handles a single PIVOT or UNPIVOT operator 53 """ 54 schema = ensure_schema(schema) 55 annotator = TypeAnnotator(schema) 56 infer_schema = schema.empty if infer_schema is None else infer_schema 57 dialect = Dialect.get_or_raise(schema.dialect) 58 pseudocolumns = dialect.PSEUDOCOLUMNS 59 bigquery = dialect == "bigquery" 60 61 for scope in traverse_scope(expression): 62 scope_expression = scope.expression 63 is_select = isinstance(scope_expression, exp.Select) 64 65 if is_select and scope_expression.args.get("connect"): 66 # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL 67 # pseudocolumn, which doesn't belong to a table, so we change it into an identifier 68 scope_expression.transform( 69 lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n, 70 copy=False, 71 ) 72 scope.clear_cache() 73 74 resolver = Resolver(scope, schema, infer_schema=infer_schema) 75 _pop_table_column_aliases(scope.ctes) 76 _pop_table_column_aliases(scope.derived_tables) 77 using_column_tables = _expand_using(scope, resolver) 78 79 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 80 _expand_alias_refs( 81 scope, 82 resolver, 83 dialect, 84 expand_only_groupby=bigquery, 85 ) 86 87 _convert_columns_to_dots(scope, resolver) 88 _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification) 89 90 if not schema.empty and expand_alias_refs: 91 _expand_alias_refs(scope, resolver, dialect) 92 93 if is_select: 94 if expand_stars: 95 _expand_stars( 96 scope, 97 resolver, 98 using_column_tables, 99 pseudocolumns, 100 annotator, 101 ) 102 qualify_outputs(scope) 103 104 _expand_group_by(scope, dialect) 105 106 # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse) 107 # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT 108 _expand_order_by_and_distinct_on(scope, resolver) 109 110 if bigquery: 111 
annotator.annotate_scope(scope) 112 113 return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
116def validate_qualify_columns(expression: E) -> E: 117 """Raise an `OptimizeError` if any columns aren't qualified""" 118 all_unqualified_columns = [] 119 for scope in traverse_scope(expression): 120 if isinstance(scope.expression, exp.Select): 121 unqualified_columns = scope.unqualified_columns 122 123 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 124 column = scope.external_columns[0] 125 for_table = f" for table: '{column.table}'" if column.table else "" 126 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 127 128 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 129 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 130 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 131 # this list here to ensure those in the former category will be excluded. 132 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 133 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 134 135 all_unqualified_columns.extend(unqualified_columns) 136 137 if all_unqualified_columns: 138 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 139 140 return expression
Raise an `OptimizeError` if any columns aren't qualified
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
759def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 760 """Ensure all output columns are aliased""" 761 if isinstance(scope_or_expression, exp.Expression): 762 scope = build_scope(scope_or_expression) 763 if not isinstance(scope, Scope): 764 return 765 else: 766 scope = scope_or_expression 767 768 new_selections = [] 769 for i, (selection, aliased_column) in enumerate( 770 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 771 ): 772 if selection is None: 773 break 774 775 if isinstance(selection, exp.Subquery): 776 if not selection.output_name: 777 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 778 elif not isinstance(selection, exp.Alias) and not selection.is_star: 779 selection = alias( 780 selection, 781 alias=selection.output_name or f"_col_{i}", 782 copy=False, 783 ) 784 if aliased_column: 785 selection.set("alias", exp.to_identifier(aliased_column)) 786 787 new_selections.append(selection) 788 789 if isinstance(scope.expression, exp.Select): 790 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
793def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 794 """Makes sure all identifiers that need to be quoted are quoted.""" 795 return expression.transform( 796 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 797 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
800def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 801 """ 802 Pushes down the CTE alias columns into the projection, 803 804 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 805 806 Example: 807 >>> import sqlglot 808 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 809 >>> pushdown_cte_alias_columns(expression).sql() 810 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 811 812 Args: 813 expression: Expression to pushdown. 814 815 Returns: 816 The expression with the CTE aliases pushed down into the projection. 817 """ 818 for cte in expression.find_all(exp.CTE): 819 if cte.alias_column_names: 820 new_expressions = [] 821 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 822 if isinstance(projection, exp.Alias): 823 projection.set("alias", _alias) 824 else: 825 projection = alias(projection, alias=_alias) 826 new_expressions.append(projection) 827 cte.this.set("expressions", new_expressions) 828 829 return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
832class Resolver: 833 """ 834 Helper for resolving columns. 835 836 This is a class so we can lazily load some things and easily share them across functions. 837 """ 838 839 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 840 self.scope = scope 841 self.schema = schema 842 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 843 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 844 self._all_columns: t.Optional[t.Set[str]] = None 845 self._infer_schema = infer_schema 846 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} 847 848 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 849 """ 850 Get the table for a column name. 851 852 Args: 853 column_name: The column name to find the table for. 854 Returns: 855 The table name if it can be found/inferred. 856 """ 857 if self._unambiguous_columns is None: 858 self._unambiguous_columns = self._get_unambiguous_columns( 859 self._get_all_source_columns() 860 ) 861 862 table_name = self._unambiguous_columns.get(column_name) 863 864 if not table_name and self._infer_schema: 865 sources_without_schema = tuple( 866 source 867 for source, columns in self._get_all_source_columns().items() 868 if not columns or "*" in columns 869 ) 870 if len(sources_without_schema) == 1: 871 table_name = sources_without_schema[0] 872 873 if table_name not in self.scope.selected_sources: 874 return exp.to_identifier(table_name) 875 876 node, _ = self.scope.selected_sources.get(table_name) 877 878 if isinstance(node, exp.Query): 879 while node and node.alias != table_name: 880 node = node.parent 881 882 node_alias = node.args.get("alias") 883 if node_alias: 884 return exp.to_identifier(node_alias.this) 885 886 return exp.to_identifier(table_name) 887 888 @property 889 def all_columns(self) -> t.Set[str]: 890 """All available columns of all sources in this scope""" 891 if self._all_columns is None: 892 self._all_columns = { 893 column for 
columns in self._get_all_source_columns().values() for column in columns 894 } 895 return self._all_columns 896 897 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 898 """Resolve the source columns for a given source `name`.""" 899 cache_key = (name, only_visible) 900 if cache_key not in self._get_source_columns_cache: 901 if name not in self.scope.sources: 902 raise OptimizeError(f"Unknown table: {name}") 903 904 source = self.scope.sources[name] 905 906 if isinstance(source, exp.Table): 907 columns = self.schema.column_names(source, only_visible) 908 elif isinstance(source, Scope) and isinstance( 909 source.expression, (exp.Values, exp.Unnest) 910 ): 911 columns = source.expression.named_selects 912 913 # in bigquery, unnest structs are automatically scoped as tables, so you can 914 # directly select a struct field in a query. 915 # this handles the case where the unnest is statically defined. 916 if self.schema.dialect == "bigquery": 917 if source.expression.is_type(exp.DataType.Type.STRUCT): 918 for k in source.expression.type.expressions: # type: ignore 919 columns.append(k.name) 920 else: 921 columns = source.expression.named_selects 922 923 node, _ = self.scope.selected_sources.get(name) or (None, None) 924 if isinstance(node, Scope): 925 column_aliases = node.expression.alias_column_names 926 elif isinstance(node, exp.Expression): 927 column_aliases = node.alias_column_names 928 else: 929 column_aliases = [] 930 931 if column_aliases: 932 # If the source's columns are aliased, their aliases shadow the corresponding column names. 933 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 
934 columns = [ 935 alias or name 936 for (name, alias) in itertools.zip_longest(columns, column_aliases) 937 ] 938 939 self._get_source_columns_cache[cache_key] = columns 940 941 return self._get_source_columns_cache[cache_key] 942 943 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 944 if self._source_columns is None: 945 self._source_columns = { 946 source_name: self.get_source_columns(source_name) 947 for source_name, source in itertools.chain( 948 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 949 ) 950 } 951 return self._source_columns 952 953 def _get_unambiguous_columns( 954 self, source_columns: t.Dict[str, t.Sequence[str]] 955 ) -> t.Mapping[str, str]: 956 """ 957 Find all the unambiguous columns in sources. 958 959 Args: 960 source_columns: Mapping of names to source columns. 961 962 Returns: 963 Mapping of column name to source name. 964 """ 965 if not source_columns: 966 return {} 967 968 source_columns_pairs = list(source_columns.items()) 969 970 first_table, first_columns = source_columns_pairs[0] 971 972 if len(source_columns_pairs) == 1: 973 # Performance optimization - avoid copying first_columns if there is only one table. 974 return SingleValuedMapping(first_columns, first_table) 975 976 unambiguous_columns = {col: first_table for col in first_columns} 977 all_columns = set(unambiguous_columns) 978 979 for table, columns in source_columns_pairs[1:]: 980 unique = set(columns) 981 ambiguous = all_columns.intersection(unique) 982 all_columns.update(columns) 983 984 for column in ambiguous: 985 unambiguous_columns.pop(column, None) 986 for column in unique.difference(ambiguous): 987 unambiguous_columns[column] = table 988 989 return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
839 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 840 self.scope = scope 841 self.schema = schema 842 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 843 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 844 self._all_columns: t.Optional[t.Set[str]] = None 845 self._infer_schema = infer_schema 846 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
848 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 849 """ 850 Get the table for a column name. 851 852 Args: 853 column_name: The column name to find the table for. 854 Returns: 855 The table name if it can be found/inferred. 856 """ 857 if self._unambiguous_columns is None: 858 self._unambiguous_columns = self._get_unambiguous_columns( 859 self._get_all_source_columns() 860 ) 861 862 table_name = self._unambiguous_columns.get(column_name) 863 864 if not table_name and self._infer_schema: 865 sources_without_schema = tuple( 866 source 867 for source, columns in self._get_all_source_columns().items() 868 if not columns or "*" in columns 869 ) 870 if len(sources_without_schema) == 1: 871 table_name = sources_without_schema[0] 872 873 if table_name not in self.scope.selected_sources: 874 return exp.to_identifier(table_name) 875 876 node, _ = self.scope.selected_sources.get(table_name) 877 878 if isinstance(node, exp.Query): 879 while node and node.alias != table_name: 880 node = node.parent 881 882 node_alias = node.args.get("alias") 883 if node_alias: 884 return exp.to_identifier(node_alias.this) 885 886 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
888 @property 889 def all_columns(self) -> t.Set[str]: 890 """All available columns of all sources in this scope""" 891 if self._all_columns is None: 892 self._all_columns = { 893 column for columns in self._get_all_source_columns().values() for column in columns 894 } 895 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
897 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 898 """Resolve the source columns for a given source `name`.""" 899 cache_key = (name, only_visible) 900 if cache_key not in self._get_source_columns_cache: 901 if name not in self.scope.sources: 902 raise OptimizeError(f"Unknown table: {name}") 903 904 source = self.scope.sources[name] 905 906 if isinstance(source, exp.Table): 907 columns = self.schema.column_names(source, only_visible) 908 elif isinstance(source, Scope) and isinstance( 909 source.expression, (exp.Values, exp.Unnest) 910 ): 911 columns = source.expression.named_selects 912 913 # in bigquery, unnest structs are automatically scoped as tables, so you can 914 # directly select a struct field in a query. 915 # this handles the case where the unnest is statically defined. 916 if self.schema.dialect == "bigquery": 917 if source.expression.is_type(exp.DataType.Type.STRUCT): 918 for k in source.expression.type.expressions: # type: ignore 919 columns.append(k.name) 920 else: 921 columns = source.expression.named_selects 922 923 node, _ = self.scope.selected_sources.get(name) or (None, None) 924 if isinstance(node, Scope): 925 column_aliases = node.expression.alias_column_names 926 elif isinstance(node, exp.Expression): 927 column_aliases = node.alias_column_names 928 else: 929 column_aliases = [] 930 931 if column_aliases: 932 # If the source's columns are aliased, their aliases shadow the corresponding column names. 933 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 934 columns = [ 935 alias or name 936 for (name, alias) in itertools.zip_longest(columns, column_aliases) 937 ] 938 939 self._get_source_columns_cache[cache_key] = columns 940 941 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.