sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. Defaults to True
            exactly when `schema` is empty.

    Returns:
        The qualified expression (the same object that was passed in).

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before qualification;
        # with one, they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Yield the columns an UNPIVOT introduces.

    This is the name column of its IN field (when that is a Column), followed by every
    column found inside the pivot's value expressions.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    CTEs that belong to a recursive WITH clause are left untouched.
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` into equivalent `ON` equality conditions.

    Unqualified references to a USING column are then replaced with a COALESCE over every
    source participating in that column's join. Such references in the SELECT list keep
    their output name through an alias.

    Returns:
        Mapping of each USING column name to a dict used as an ordered set of source names.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to projection aliases within a SELECT.

    References are replaced in later projections and in WHERE, GROUP BY, HAVING and QUALIFY.
    In GROUP BY, references to literal projections become 1-based positional indices. In
    HAVING/QUALIFY, remaining unqualified columns are also qualified with their table. A
    reference is not inlined when doing so would nest an aggregate inside another aggregate
    (outside a window).
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in GROUP BY with the corresponding projections."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize ORDER BY.

    Positional references become columns naming the projection's alias, and unqualified
    columns inside aggregates are qualified. When the query is grouped, ORDER BY
    expressions that match a projection's expression are replaced by a column that
    references that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer-literal (1-based) positional references against the projections.

    With `alias=True` each position becomes a column naming the projection's alias.
    Otherwise it becomes a copy of the projected expression, except for constants and
    expressions containing EXPLODE/UNNEST, which keep the positional literal. Non-integer
    nodes are passed through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by the integer literal `node`.

    Raises:
        OptimizeError: if the position is outside `1..len(selects)`.
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Positions are 1-based. Reject 0 and negative values explicitly; otherwise Python's
    # negative indexing would silently map them to a projection counted from the end.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT expression refer to the pivot's output, so they
                # are qualified with the pivot's alias. Columns under the PIVOT are qualified
                # with the pivoted source in the loop at the end of this function.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Handles `*` and `table.*`, honoring star EXCEPT/REPLACE modifiers, USING-joined columns
    (emitted once, as a COALESCE) and a single PIVOT/UNPIVOT. Pseudocolumns are dropped. If
    any source's columns are unknown, expansion is abandoned and the projections are left
    untouched.
    """

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star and not isinstance(expression, exp.Dot):
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            # EXCEPT/REPLACE info is keyed by the identity of the table-name objects
            # recorded in _add_except_columns/_add_replace_columns.
            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    # Distinct name so the outer loop's `tables` is not shadowed.
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the star's EXCEPT column names for every table in `tables`, keyed by id(table)."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's REPLACE column->alias mapping for every table, keyed by id(table)."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Projections past the end of the CTE's column-alias list are kept unchanged.

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence: keep projections that have no matching
        # alias rather than silently dropping them from the CTE.
        new_expressions.extend(projections[len(alias_names) :])
        cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches (see get_table / all_columns / _get_all_source_columns).
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        # With schema inference, a lone source of unknown shape is assumed to own the column.
        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up from the query to the ancestor whose alias is the source name.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above may run off the top of the tree; fall back to the bare name then.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                col_alias or col_name
                for (col_name, col_alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source in the scope."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, _ in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source is ambiguous and must not resolve.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column of a sqlglot AST with the source it belongs to.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: The expression to qualify; the same object is returned.
        schema: The database schema, as a mapping or a `Schema`.
        expand_alias_refs: Whether references to projection aliases get expanded.
        expand_stars: Whether star projections get expanded. Most of the optimizer's
            rules rely on this step, so do not set it to False unless you know
            what you're doing!
        infer_schema: Whether to infer the schema when it is missing. When left as
            None, inference is enabled exactly when `schema` is empty.

    Returns:
        The fully qualified expression.

    Notes:
        - Only a single PIVOT or UNPIVOT operator is currently supported.
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for current_scope in traverse_scope(expression):
        column_resolver = Resolver(current_scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)
        using_tables = _expand_using(current_scope, column_resolver)

        # Alias references are expanded before qualification when no schema is
        # known, and after qualification otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(current_scope, column_resolver)

        _qualify_columns(current_scope, column_resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(current_scope, column_resolver)

        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            if expand_stars:
                _expand_stars(current_scope, column_resolver, using_tables, pseudocolumns)
            qualify_outputs(current_scope)

        _expand_group_by(current_scope)
        _expand_order_by(current_scope, column_resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(current_scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in `expression` has been qualified.

    Raises:
        OptimizeError: if a SELECT scope has an external column that is neither part of
            a correlated subquery nor a pivot, or if any column is still unqualified
            (columns produced by an UNPIVOT are exempt).

    Returns:
        The unchanged `expression`.
    """
    unresolved = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        candidates = scope.unqualified_columns

        external = scope.external_columns
        if external and not scope.is_correlated_subquery and not scope.pivots:
            offender = external[0]
            suffix = f" for table: '{offender.table}'" if offender.table else ""
            raise OptimizeError(f"Column '{offender}' could not be resolved{suffix}")

        if candidates and scope.pivots and scope.pivots[0].unpivot:
            # Columns created by the UNPIVOT itself cannot be qualified, whereas columns
            # inside its IN clause can and should be. Filter out only the former.
            produced = set(_unpivot_columns(scope.pivots[0]))
            candidates = [c for c in candidates if c not in produced]

        unresolved.extend(candidates)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Give every projection of a scope an explicit output name.

    An expression argument is first turned into a scope; if that fails, nothing happens.
    Unnamed subquery projections get a `_col_<i>` table alias. Other projections that
    are neither aliases nor stars are wrapped in an alias, using their output name or
    `_col_<i>`. When the scope supplies an outer column name for a position, that name
    overrides the alias. The projection list is only written back for SELECTs.
    """
    if isinstance(scope_or_expression, exp.Expression):
        built = build_scope(scope_or_expression)
        if not isinstance(built, Scope):
            return
        scope = built
    else:
        scope = scope_or_expression

    aliased_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    for index, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # More outer names than projections: nothing left to alias.
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{index}")))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{index}",
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote every identifier in `expression` that needs quoting.

    The dialect's `quote_identifier` is applied to each node via an in-place transform;
    `identify` is forwarded to it unchanged.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)  # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Each projection of a CTE that has a column-alias list is renamed to the matching
    alias: existing aliases are overwritten and other projections are wrapped in a new
    alias. Projections past the end of the alias list are kept unchanged.

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence: keep projections that have no matching
        # alias rather than silently dropping them from the CTE.
        new_expressions.extend(projections[len(alias_names) :])
        cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches, filled on first use by _get_all_source_columns, get_table and all_columns.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.

        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no known columns (or a `*`), attribute the column to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor whose alias matches the registered source name.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the root and leave `node` as None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source in this scope.
            only_visible: Forwarded to `Schema.column_names` for table sources.

        Returns:
            The source's column names. Column aliases declared on the source replace the
            corresponding column names.

        Raises:
            OptimizeError: If `name` is not one of this scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map each selected and lateral source name to its columns, computed once and cached."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source can't be attributed to any single one.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """
    Store the scope and schema to resolve against.

    Args:
        scope: The scope whose columns are resolved.
        schema: The schema consulted for table sources.
        infer_schema: Whether `get_table` may attribute an unknown column to the single
            source that has no known columns.
    """
    self.scope, self.schema = scope, schema
    self._infer_schema = infer_schema
    # Caches that stay None until first needed.
    self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
    self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.

    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has no known columns (or a `*`), attribute the column to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Query):
        # Walk up to the ancestor whose alias matches the registered source name.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run off the root and leave `node` as None.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """All available columns of all sources in this scope (computed once, then cached)."""
    if self._all_columns is None:
        collected: t.Set[str] = set()
        for source_columns in self._get_all_source_columns().values():
            collected.update(source_columns)
        self._all_columns = collected
    return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: The name of a source in this scope.
        only_visible: Forwarded to `Schema.column_names` for table sources.

    Returns:
        The source's column names. Column aliases declared on the source replace the
        corresponding column names.

    Raises:
        OptimizeError: If `name` is not one of this scope's sources.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")
    source = sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    else:
        columns = source.expression.named_selects
        is_values_or_unnest = isinstance(source, Scope) and isinstance(
            source.expression, (exp.Values, exp.Unnest)
        )
        # In BigQuery, fields of a statically defined UNNEST-ed struct can be selected
        # directly, as if the struct were a table.
        if (
            is_values_or_unnest
            and self.schema.dialect == "bigquery"
            and source.expression.is_type(exp.DataType.Type.STRUCT)
        ):
            columns.extend(field.name for field in source.expression.type.expressions)  # type: ignore

    node, _unused = self.scope.selected_sources.get(name) or (None, None)
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    if not column_aliases:
        return columns

    # Aliases declared on the source shadow the underlying column names position by position.
    # Only built when aliases exist, since it is linear in the number of columns.
    return [
        column_alias or column_name
        for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
    ]
Resolve the source columns for a given source `name`.