sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
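A couple of additional illustrative calls (not part of the module's doctests) that sketch how star projections and USING joins are rewritten; the schema and the commented output are illustrative of the expected shape, not verified doctest output.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"t1": {"id": "INT", "a": "INT"}, "t2": {"id": "INT", "b": "INT"}}

# Star expansion: SELECT * becomes an explicit, fully qualified column list.
print(qualify_columns(sqlglot.parse_one("SELECT * FROM t1"), schema).sql())
# roughly: SELECT t1.id AS id, t1.a AS a FROM t1

# USING joins are rewritten to ON conditions, and the shared column is coalesced.
print(qualify_columns(sqlglot.parse_one("SELECT * FROM t1 JOIN t2 USING (id)"), schema).sql())
# roughly: SELECT COALESCE(t1.id, t2.id) AS id, t1.a AS a, t2.b AS b
#          FROM t1 JOIN t2 ON t1.id = t2.id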
def validate_qualify_columns(expression: ~E) -> ~E:
Raise an OptimizeError if any columns aren't qualified.
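A hedged usage sketch (the query and empty schema are illustrative): when two sources could own the same column name and no schema disambiguates it, the column stays unqualified and validation fails.

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

# With an empty schema neither source can claim `col`, so it remains unqualified
# and validation raises; the error message lists the offending columns.
expr = qualify_columns(sqlglot.parse_one("SELECT col FROM t1 CROSS JOIN t2"), schema={})
try:
    validate_qualify_columns(expr)
except OptimizeError as err:
    print(err)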
def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
Ensure all output columns are aliased
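A small hedged sketch of the aliasing behaviour; the commented output is the expected shape, not a verified doctest.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expr = sqlglot.parse_one("SELECT x + 1, y FROM t")
qualify_outputs(expr)  # mutates the expression in place
print(expr.sql())
# expected shape: SELECT x + 1 AS _col_0, y AS y FROM t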
def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
Makes sure all identifiers that need to be quoted are quoted.
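A minimal hedged sketch: with the default identify=True, every identifier receives the target dialect's quoting (the commented output is the expected shape).

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expr = sqlglot.parse_one("SELECT a FROM tbl")
print(quote_identifiers(expr, dialect="snowflake").sql(dialect="snowflake"))
# expected shape: SELECT "a" FROM "tbl"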
def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class Resolver:
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
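A hedged usage sketch (schema and query are illustrative): the resolver is normally constructed per scope by qualify_columns, but it can also be driven directly to inspect column ownership.

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT"}, "t2": {"b": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT a, b FROM t1 CROSS JOIN t2"))
resolver = Resolver(scope, schema)

print(resolver.get_table("a"))            # the source that owns `a`, i.e. t1
print(resolver.get_source_columns("t2"))  # ['b']
print(sorted(resolver.all_columns))       # ['a', 'b']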
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
Resolve the source columns for a given source name.