sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for i, join in enumerate(joins):
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                lhs = exp.func("coalesce", *[exp.column(identifier, table=t) for t in ordered[:-1]])

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.func("coalesce", *coalesce_args)

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns(
    expression: sqlglot.expressions.Expression,
    schema: Union[Dict, sqlglot.schema.Schema],
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: Optional[bool] = None
) -> sqlglot.expressions.Expression:
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
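As a rough illustration of the star expansion and qualification this function performs, here is a sketch with an invented two-table schema (t1, t2); the printed SQL reflects the intended shape of the output rather than a guaranteed byte-for-byte result:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical schema used only for illustration
schema = {"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}}

expression = sqlglot.parse_one("SELECT * FROM t1 JOIN t2 ON t1.b = t2.b")

# The star is expanded into one aliased projection per visible source column,
# and every column reference is prefixed with the table it resolves to.
print(qualify_columns(expression, schema).sql())
# e.g. SELECT t1.a AS a, t1.b AS b, t2.b AS b, t2.c AS c FROM t1 JOIN t2 ON t1.b = t2.b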
def
validate_qualify_columns(expression: ~E) -> ~E:
Raise an OptimizeError if any columns aren't qualified.
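A sketch of how this check is typically chained after qualify_columns; the schema and query are invented, and the exact error wording may vary between versions:

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

# `a` exists in both sources, so it cannot be qualified unambiguously
schema = {"t1": {"a": "INT"}, "t2": {"a": "INT"}}
expression = sqlglot.parse_one("SELECT a FROM t1, t2")

qualified = qualify_columns(expression, schema)

try:
    validate_qualify_columns(qualified)
except OptimizeError as error:
    print(error)  # the error names the column(s) that could not be qualified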
def
qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
Ensure all output columns are aliased
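An illustrative sketch (the table and column names are made up): projections that already carry an alias are left untouched, plain columns are aliased to their own name, and anonymous expressions fall back to a positional _col_<i> alias:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x.a, x.a + 1, b AS c FROM x")

# Mutates the expression in place and returns None
qualify_outputs(expression)
print(expression.sql())
# e.g. SELECT x.a AS a, x.a + 1 AS _col_1, b AS c FROM x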
def
quote_identifiers(
    expression: ~E,
    dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None,
    identify: bool = True
) -> ~E:
Makes sure all identifiers that need to be quoted are quoted.
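A small usage sketch; the dialect choice is arbitrary, and with identify=True every identifier is quoted according to that dialect's rules:

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT a AS my_col FROM tbl")

quoted = quote_identifiers(expression, dialect="snowflake", identify=True)
print(quoted.sql(dialect="snowflake"))
# e.g. SELECT "a" AS "my_col" FROM "tbl"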
def
pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
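A minimal sketch of driving a Resolver by hand, outside of qualify_columns; the schema and query are invented, and ensure_schema is used because the constructor expects a Schema object rather than a plain dict:

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT"}, "t2": {"b": "INT"}})
expression = sqlglot.parse_one("SELECT a, b FROM t1 JOIN t2 ON t1.a = t2.b")

scope = build_scope(expression)
resolver = Resolver(scope, schema)

print(resolver.get_table("a"))            # e.g. t1 (an exp.Identifier)
print(resolver.all_columns)               # e.g. {'a', 'b'}
print(resolver.get_source_columns("t2"))  # e.g. ['b']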
Resolver(
    scope: sqlglot.optimizer.scope.Scope,
    schema: sqlglot.schema.Schema,
    infer_schema: bool = True
)
def
get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
Resolve the source columns for a given source name.
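A short sketch contrasting the two main resolution paths, with invented names: columns of a physical table come from the schema, while columns of a derived table come from its own projections:

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"tbl": {"id": "INT", "name": "TEXT"}})

# Columns of a physical table are looked up in the schema...
scope = build_scope(sqlglot.parse_one("SELECT id FROM tbl"))
print(Resolver(scope, schema).get_source_columns("tbl"))  # e.g. ['id', 'name']

# ...while columns of a derived table come from its projections.
scope = build_scope(sqlglot.parse_one("SELECT d.id FROM (SELECT id, name FROM tbl) AS d"))
print(Resolver(scope, schema).get_source_columns("d"))  # e.g. ['id', 'name']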