sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.errors import OptimizeError 9from sqlglot.helper import seq_get, SingleValuedMapping 10from sqlglot.optimizer.annotate_types import annotate_types 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope 12from sqlglot.optimizer.simplify import simplify_parens 13from sqlglot.schema import Schema, ensure_schema 14 15if t.TYPE_CHECKING: 16 from sqlglot._typing import E 17 18 19def qualify_columns( 20 expression: exp.Expression, 21 schema: t.Dict | Schema, 22 expand_alias_refs: bool = True, 23 expand_stars: bool = True, 24 infer_schema: t.Optional[bool] = None, 25) -> exp.Expression: 26 """ 27 Rewrite sqlglot AST to have fully qualified columns. 28 29 Example: 30 >>> import sqlglot 31 >>> schema = {"tbl": {"col": "INT"}} 32 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 33 >>> qualify_columns(expression, schema).sql() 34 'SELECT tbl.col AS col FROM tbl' 35 36 Args: 37 expression: Expression to qualify. 38 schema: Database schema. 39 expand_alias_refs: Whether to expand references to aliases. 40 expand_stars: Whether to expand star queries. This is a necessary step 41 for most of the optimizer's rules to work; do not set to False unless you 42 know what you're doing! 43 infer_schema: Whether to infer the schema if missing. 44 45 Returns: 46 The qualified expression. 
47 48 Notes: 49 - Currently only handles a single PIVOT or UNPIVOT operator 50 """ 51 schema = ensure_schema(schema) 52 infer_schema = schema.empty if infer_schema is None else infer_schema 53 pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS 54 55 for scope in traverse_scope(expression): 56 resolver = Resolver(scope, schema, infer_schema=infer_schema) 57 _pop_table_column_aliases(scope.ctes) 58 _pop_table_column_aliases(scope.derived_tables) 59 using_column_tables = _expand_using(scope, resolver) 60 61 if schema.empty and expand_alias_refs: 62 _expand_alias_refs(scope, resolver) 63 64 _qualify_columns(scope, resolver) 65 66 if not schema.empty and expand_alias_refs: 67 _expand_alias_refs(scope, resolver) 68 69 if not isinstance(scope.expression, exp.UDTF): 70 if expand_stars: 71 _expand_stars(scope, resolver, using_column_tables, pseudocolumns) 72 qualify_outputs(scope) 73 74 _expand_group_by(scope) 75 _expand_order_by(scope, resolver) 76 77 return expression 78 79 80def validate_qualify_columns(expression: E) -> E: 81 """Raise an `OptimizeError` if any columns aren't qualified""" 82 all_unqualified_columns = [] 83 for scope in traverse_scope(expression): 84 if isinstance(scope.expression, exp.Select): 85 unqualified_columns = scope.unqualified_columns 86 87 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 88 column = scope.external_columns[0] 89 for_table = f" for table: '{column.table}'" if column.table else "" 90 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 91 92 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 93 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 94 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 95 # this list here to ensure those in the former category will be excluded. 
96 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 97 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 98 99 all_unqualified_columns.extend(unqualified_columns) 100 101 if all_unqualified_columns: 102 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 103 104 return expression 105 106 107def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: 108 name_column = [] 109 field = unpivot.args.get("field") 110 if isinstance(field, exp.In) and isinstance(field.this, exp.Column): 111 name_column.append(field.this) 112 113 value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) 114 return itertools.chain(name_column, value_columns) 115 116 117def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: 118 """ 119 Remove table column aliases. 120 121 For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) 122 """ 123 for derived_table in derived_tables: 124 if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: 125 continue 126 table_alias = derived_table.args.get("alias") 127 if table_alias: 128 table_alias.args.pop("columns", None) 129 130 131def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: 132 joins = list(scope.find_all(exp.Join)) 133 names = {join.alias_or_name for join in joins} 134 ordered = [key for key in scope.selected_sources if key not in names] 135 136 # Mapping of automatically joined column names to an ordered set of source names (dict). 
137 column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} 138 139 for join in joins: 140 using = join.args.get("using") 141 142 if not using: 143 continue 144 145 join_table = join.alias_or_name 146 147 columns = {} 148 149 for source_name in scope.selected_sources: 150 if source_name in ordered: 151 for column_name in resolver.get_source_columns(source_name): 152 if column_name not in columns: 153 columns[column_name] = source_name 154 155 source_table = ordered[-1] 156 ordered.append(join_table) 157 join_columns = resolver.get_source_columns(join_table) 158 conditions = [] 159 160 for identifier in using: 161 identifier = identifier.name 162 table = columns.get(identifier) 163 164 if not table or identifier not in join_columns: 165 if (columns and "*" not in columns) and join_columns: 166 raise OptimizeError(f"Cannot automatically join: {identifier}") 167 168 table = table or source_table 169 conditions.append( 170 exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table)) 171 ) 172 173 # Set all values in the dict to None, because we only care about the key ordering 174 tables = column_tables.setdefault(identifier, {}) 175 if table not in tables: 176 tables[table] = None 177 if join_table not in tables: 178 tables[join_table] = None 179 180 join.args.pop("using") 181 join.set("on", exp.and_(*conditions, copy=False)) 182 183 if column_tables: 184 for column in scope.columns: 185 if not column.table and column.name in column_tables: 186 tables = column_tables[column.name] 187 coalesce = [exp.column(column.name, table=table) for table in tables] 188 replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) 189 190 # Ensure selects keep their output name 191 if isinstance(column.parent, exp.Select): 192 replacement = alias(replacement, alias=column.name, copy=False) 193 194 scope.replace(column, replacement) 195 196 return column_tables 197 198 199def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: 200 expression = 
scope.expression 201 202 if not isinstance(expression, exp.Select): 203 return 204 205 alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} 206 207 def replace_columns( 208 node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False 209 ) -> None: 210 if not node: 211 return 212 213 for column in walk_in_scope(node, prune=lambda node: node.is_star): 214 if not isinstance(column, exp.Column): 215 continue 216 217 table = resolver.get_table(column.name) if resolve_table and not column.table else None 218 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 219 double_agg = ( 220 ( 221 alias_expr.find(exp.AggFunc) 222 and ( 223 column.find_ancestor(exp.AggFunc) 224 and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) 225 ) 226 ) 227 if alias_expr 228 else False 229 ) 230 231 if table and (not alias_expr or double_agg): 232 column.set("table", table) 233 elif not column.table and alias_expr and not double_agg: 234 if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): 235 if literal_index: 236 column.replace(exp.Literal.number(i)) 237 else: 238 column = column.replace(exp.paren(alias_expr)) 239 simplified = simplify_parens(column) 240 if simplified is not column: 241 column.replace(simplified) 242 243 for i, projection in enumerate(scope.expression.selects): 244 replace_columns(projection) 245 246 if isinstance(projection, exp.Alias): 247 alias_to_expression[projection.alias] = (projection.this, i + 1) 248 249 replace_columns(expression.args.get("where")) 250 replace_columns(expression.args.get("group"), literal_index=True) 251 replace_columns(expression.args.get("having"), resolve_table=True) 252 replace_columns(expression.args.get("qualify"), resolve_table=True) 253 254 scope.clear_cache() 255 256 257def _expand_group_by(scope: Scope) -> None: 258 expression = scope.expression 259 group = expression.args.get("group") 260 if not group: 261 return 262 263 
group.set("expressions", _expand_positional_references(scope, group.expressions)) 264 expression.set("group", group) 265 266 267def _expand_order_by(scope: Scope, resolver: Resolver) -> None: 268 order = scope.expression.args.get("order") 269 if not order: 270 return 271 272 ordereds = order.expressions 273 for ordered, new_expression in zip( 274 ordereds, 275 _expand_positional_references(scope, (o.this for o in ordereds), alias=True), 276 ): 277 for agg in ordered.find_all(exp.AggFunc): 278 for col in agg.find_all(exp.Column): 279 if not col.table: 280 col.set("table", resolver.get_table(col.name)) 281 282 ordered.set("this", new_expression) 283 284 if scope.expression.args.get("group"): 285 selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} 286 287 for ordered in ordereds: 288 ordered = ordered.this 289 290 ordered.replace( 291 exp.to_identifier(_select_by_pos(scope, ordered).alias) 292 if ordered.is_int 293 else selects.get(ordered, ordered) 294 ) 295 296 297def _expand_positional_references( 298 scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False 299) -> t.List[exp.Expression]: 300 new_nodes: t.List[exp.Expression] = [] 301 for node in expressions: 302 if node.is_int: 303 select = _select_by_pos(scope, t.cast(exp.Literal, node)) 304 305 if alias: 306 new_nodes.append(exp.column(select.args["alias"].copy())) 307 else: 308 select = select.this 309 310 if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest): 311 new_nodes.append(node) 312 else: 313 new_nodes.append(select.copy()) 314 else: 315 new_nodes.append(node) 316 317 return new_nodes 318 319 320def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: 321 try: 322 return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) 323 except IndexError: 324 raise OptimizeError(f"Unknown output column: {node.name}") 325 326 327def _qualify_columns(scope: Scope, resolver: Resolver) -> None: 328 """Disambiguate columns, 
ensuring each column specifies a source""" 329 for column in scope.columns: 330 column_table = column.table 331 column_name = column.name 332 333 if column_table and column_table in scope.sources: 334 source_columns = resolver.get_source_columns(column_table) 335 if source_columns and column_name not in source_columns and "*" not in source_columns: 336 raise OptimizeError(f"Unknown column: {column_name}") 337 338 if not column_table: 339 if scope.pivots and not column.find_ancestor(exp.Pivot): 340 # If the column is under the Pivot expression, we need to qualify it 341 # using the name of the pivoted source instead of the pivot's alias 342 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 343 continue 344 345 column_table = resolver.get_table(column_name) 346 347 # column_table can be a '' because bigquery unnest has no table alias 348 if column_table: 349 column.set("table", column_table) 350 elif column_table not in scope.sources and ( 351 not scope.parent 352 or column_table not in scope.parent.sources 353 or not scope.is_correlated_subquery 354 ): 355 # structs are used like tables (e.g. 
"struct"."field"), so they need to be qualified 356 # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) 357 358 root, *parts = column.parts 359 360 if root.name in scope.sources: 361 # struct is already qualified, but we still need to change the AST representation 362 column_table = root 363 root, *parts = parts 364 else: 365 column_table = resolver.get_table(root.name) 366 367 if column_table: 368 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 369 370 for pivot in scope.pivots: 371 for column in pivot.find_all(exp.Column): 372 if not column.table and column.name in resolver.all_columns: 373 column_table = resolver.get_table(column.name) 374 if column_table: 375 column.set("table", column_table) 376 377 378def _expand_stars( 379 scope: Scope, 380 resolver: Resolver, 381 using_column_tables: t.Dict[str, t.Any], 382 pseudocolumns: t.Set[str], 383) -> None: 384 """Expand stars to lists of column selections""" 385 386 new_selections = [] 387 except_columns: t.Dict[int, t.Set[str]] = {} 388 replace_columns: t.Dict[int, t.Dict[str, str]] = {} 389 coalesced_columns = set() 390 391 pivot_output_columns = None 392 pivot_exclude_columns = None 393 394 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 395 if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: 396 if pivot.unpivot: 397 pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] 398 399 field = pivot.args.get("field") 400 if isinstance(field, exp.In): 401 pivot_exclude_columns = { 402 c.output_name for e in field.expressions for c in e.find_all(exp.Column) 403 } 404 else: 405 pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) 406 407 pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] 408 if not pivot_output_columns: 409 pivot_output_columns = [c.alias_or_name for c in pivot.expressions] 410 411 for expression in scope.expression.selects: 412 if 
isinstance(expression, exp.Star): 413 tables = list(scope.selected_sources) 414 _add_except_columns(expression, tables, except_columns) 415 _add_replace_columns(expression, tables, replace_columns) 416 elif expression.is_star and not isinstance(expression, exp.Dot): 417 tables = [expression.table] 418 _add_except_columns(expression.this, tables, except_columns) 419 _add_replace_columns(expression.this, tables, replace_columns) 420 else: 421 new_selections.append(expression) 422 continue 423 424 for table in tables: 425 if table not in scope.sources: 426 raise OptimizeError(f"Unknown table: {table}") 427 428 columns = resolver.get_source_columns(table, only_visible=True) 429 columns = columns or scope.outer_columns 430 431 if pseudocolumns: 432 columns = [name for name in columns if name.upper() not in pseudocolumns] 433 434 if not columns or "*" in columns: 435 return 436 437 table_id = id(table) 438 columns_to_exclude = except_columns.get(table_id) or set() 439 440 if pivot: 441 if pivot_output_columns and pivot_exclude_columns: 442 pivot_columns = [c for c in columns if c not in pivot_exclude_columns] 443 pivot_columns.extend(pivot_output_columns) 444 else: 445 pivot_columns = pivot.alias_column_names 446 447 if pivot_columns: 448 new_selections.extend( 449 alias(exp.column(name, table=pivot.alias), name, copy=False) 450 for name in pivot_columns 451 if name not in columns_to_exclude 452 ) 453 continue 454 455 for name in columns: 456 if name in columns_to_exclude or name in coalesced_columns: 457 continue 458 if name in using_column_tables and table in using_column_tables[name]: 459 coalesced_columns.add(name) 460 tables = using_column_tables[name] 461 coalesce = [exp.column(name, table=table) for table in tables] 462 463 new_selections.append( 464 alias( 465 exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), 466 alias=name, 467 copy=False, 468 ) 469 ) 470 else: 471 alias_ = replace_columns.get(table_id, {}).get(name, name) 472 column = exp.column(name, 
table=table) 473 new_selections.append( 474 alias(column, alias_, copy=False) if alias_ != name else column 475 ) 476 477 # Ensures we don't overwrite the initial selections with an empty list 478 if new_selections and isinstance(scope.expression, exp.Select): 479 scope.expression.set("expressions", new_selections) 480 481 482def _add_except_columns( 483 expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] 484) -> None: 485 except_ = expression.args.get("except") 486 487 if not except_: 488 return 489 490 columns = {e.name for e in except_} 491 492 for table in tables: 493 except_columns[id(table)] = columns 494 495 496def _add_replace_columns( 497 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] 498) -> None: 499 replace = expression.args.get("replace") 500 501 if not replace: 502 return 503 504 columns = {e.this.name: e.alias for e in replace} 505 506 for table in tables: 507 replace_columns[id(table)] = columns 508 509 510def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 511 """Ensure all output columns are aliased""" 512 if isinstance(scope_or_expression, exp.Expression): 513 scope = build_scope(scope_or_expression) 514 if not isinstance(scope, Scope): 515 return 516 else: 517 scope = scope_or_expression 518 519 new_selections = [] 520 for i, (selection, aliased_column) in enumerate( 521 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 522 ): 523 if selection is None: 524 break 525 526 if isinstance(selection, exp.Subquery): 527 if not selection.output_name: 528 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 529 elif not isinstance(selection, exp.Alias) and not selection.is_star: 530 selection = alias( 531 selection, 532 alias=selection.output_name or f"_col_{i}", 533 copy=False, 534 ) 535 if aliased_column: 536 selection.set("alias", exp.to_identifier(aliased_column)) 537 538 new_selections.append(selection) 539 540 if 
isinstance(scope.expression, exp.Select): 541 scope.expression.set("expressions", new_selections) 542 543 544def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 545 """Makes sure all identifiers that need to be quoted are quoted.""" 546 return expression.transform( 547 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 548 ) # type: ignore 549 550 551def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 552 """ 553 Pushes down the CTE alias columns into the projection, 554 555 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 556 557 Example: 558 >>> import sqlglot 559 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 560 >>> pushdown_cte_alias_columns(expression).sql() 561 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 562 563 Args: 564 expression: Expression to pushdown. 565 566 Returns: 567 The expression with the CTE aliases pushed down into the projection. 568 """ 569 for cte in expression.find_all(exp.CTE): 570 if cte.alias_column_names: 571 new_expressions = [] 572 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 573 if isinstance(projection, exp.Alias): 574 projection.set("alias", _alias) 575 else: 576 projection = alias(projection, alias=_alias) 577 new_expressions.append(projection) 578 cte.this.set("expressions", new_expressions) 579 580 return expression 581 582 583class Resolver: 584 """ 585 Helper for resolving columns. 586 587 This is a class so we can lazily load some things and easily share them across functions. 
588 """ 589 590 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 591 self.scope = scope 592 self.schema = schema 593 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 594 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 595 self._all_columns: t.Optional[t.Set[str]] = None 596 self._infer_schema = infer_schema 597 598 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 599 """ 600 Get the table for a column name. 601 602 Args: 603 column_name: The column name to find the table for. 604 Returns: 605 The table name if it can be found/inferred. 606 """ 607 if self._unambiguous_columns is None: 608 self._unambiguous_columns = self._get_unambiguous_columns( 609 self._get_all_source_columns() 610 ) 611 612 table_name = self._unambiguous_columns.get(column_name) 613 614 if not table_name and self._infer_schema: 615 sources_without_schema = tuple( 616 source 617 for source, columns in self._get_all_source_columns().items() 618 if not columns or "*" in columns 619 ) 620 if len(sources_without_schema) == 1: 621 table_name = sources_without_schema[0] 622 623 if table_name not in self.scope.selected_sources: 624 return exp.to_identifier(table_name) 625 626 node, _ = self.scope.selected_sources.get(table_name) 627 628 if isinstance(node, exp.Query): 629 while node and node.alias != table_name: 630 node = node.parent 631 632 node_alias = node.args.get("alias") 633 if node_alias: 634 return exp.to_identifier(node_alias.this) 635 636 return exp.to_identifier(table_name) 637 638 @property 639 def all_columns(self) -> t.Set[str]: 640 """All available columns of all sources in this scope""" 641 if self._all_columns is None: 642 self._all_columns = { 643 column for columns in self._get_all_source_columns().values() for column in columns 644 } 645 return self._all_columns 646 647 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 648 """Resolve the source columns for a 
given source `name`.""" 649 if name not in self.scope.sources: 650 raise OptimizeError(f"Unknown table: {name}") 651 652 source = self.scope.sources[name] 653 654 if isinstance(source, exp.Table): 655 columns = self.schema.column_names(source, only_visible) 656 elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)): 657 columns = source.expression.named_selects 658 659 # in bigquery, unnest structs are automatically scoped as tables, so you can 660 # directly select a struct field in a query. 661 # this handles the case where the unnest is statically defined. 662 if self.schema.dialect == "bigquery": 663 expression = source.expression 664 annotate_types(expression) 665 666 if expression.is_type(exp.DataType.Type.STRUCT): 667 for k in expression.type.expressions: # type: ignore 668 columns.append(k.name) 669 else: 670 columns = source.expression.named_selects 671 672 node, _ = self.scope.selected_sources.get(name) or (None, None) 673 if isinstance(node, Scope): 674 column_aliases = node.expression.alias_column_names 675 elif isinstance(node, exp.Expression): 676 column_aliases = node.alias_column_names 677 else: 678 column_aliases = [] 679 680 if column_aliases: 681 # If the source's columns are aliased, their aliases shadow the corresponding column names. 682 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 
683 return [ 684 alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases) 685 ] 686 return columns 687 688 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 689 if self._source_columns is None: 690 self._source_columns = { 691 source_name: self.get_source_columns(source_name) 692 for source_name, source in itertools.chain( 693 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 694 ) 695 } 696 return self._source_columns 697 698 def _get_unambiguous_columns( 699 self, source_columns: t.Dict[str, t.Sequence[str]] 700 ) -> t.Mapping[str, str]: 701 """ 702 Find all the unambiguous columns in sources. 703 704 Args: 705 source_columns: Mapping of names to source columns. 706 707 Returns: 708 Mapping of column name to source name. 709 """ 710 if not source_columns: 711 return {} 712 713 source_columns_pairs = list(source_columns.items()) 714 715 first_table, first_columns = source_columns_pairs[0] 716 717 if len(source_columns_pairs) == 1: 718 # Performance optimization - avoid copying first_columns if there is only one table. 719 return SingleValuedMapping(first_columns, first_table) 720 721 unambiguous_columns = {col: first_table for col in first_columns} 722 all_columns = set(unambiguous_columns) 723 724 for table, columns in source_columns_pairs[1:]: 725 unique = set(columns) 726 ambiguous = all_columns.intersection(unique) 727 all_columns.update(columns) 728 729 for column in ambiguous: 730 unambiguous_columns.pop(column, None) 731 for column in unique.difference(ambiguous): 732 unambiguous_columns[column] = table 733 734 return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of the expression is processed in turn: table column aliases are
    dropped, USING joins are expanded, columns are qualified, and (optionally) alias
    references and stars are expanded before GROUP BY / ORDER BY are rewritten.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. Defaults to
            `schema.empty` when left as None.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for current in traverse_scope(expression):
        column_resolver = Resolver(current, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(current.ctes)
        _pop_table_column_aliases(current.derived_tables)
        joined_column_tables = _expand_using(current, column_resolver)

        # Alias references are expanded before qualification when the schema is empty...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(current, column_resolver)

        _qualify_columns(current, column_resolver)

        # ...and after it otherwise.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(current, column_resolver)

        # UDTF scopes get neither star expansion nor output aliasing
        is_udtf = isinstance(current.expression, exp.UDTF)
        if not is_udtf and expand_stars:
            _expand_stars(current, column_resolver, joined_column_tables, pseudocolumns)
        if not is_udtf:
            qualify_outputs(current)

        _expand_group_by(current)
        _expand_order_by(current, column_resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: The expression whose SELECT scopes are checked.

    Returns:
        The same expression, when every column is qualified.

    Raises:
        OptimizeError: If a scope that is neither a correlated subquery nor pivoted has an
            external column, or if any SELECT scope still has unqualified columns.
    """
    offenders = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if pending and scope.pivots and scope.pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. Filter out the
            # former so that only the latter are reported.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            pending = [col for col in pending if col not in produced_by_unpivot]

        offenders.extend(pending)

    if offenders:
        raise OptimizeError(f"Ambiguous columns: {offenders}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If no
            scope can be built from the expression, nothing happens.

    Unaliased subquery projections without an output name get a `_col_<i>` table alias;
    other unaliased, non-star projections are aliased by their output name, falling
    back to `_col_<i>`. Names from the scope's outer columns override the alias at
    the same position.
    """
    if not isinstance(scope_or_expression, exp.Expression):
        scope = scope_or_expression
    else:
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    projections = scope.expression.selects
    outer_names = list(scope.outer_columns)

    qualified = []
    for position, projection in enumerate(projections):
        # Outer column names are matched to projections by position; extra names are ignored
        outer_name = outer_names[position] if position < len(outer_names) else None

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                fallback = exp.to_identifier(f"_col_{position}")
                projection.set("alias", exp.TableAlias(this=fallback))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{position}",
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        qualified.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", qualified)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform; it is transformed with `copy=False`.
        dialect: The dialect whose `quote_identifier` method is applied.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Alias columns are matched to projections by position. Projections that have no
    corresponding alias column are left unchanged, and surplus alias columns are ignored.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        new_expressions = []
        for index, projection in enumerate(cte.this.expressions):
            # Pairing with `zip` would stop at the shorter sequence and drop every projection
            # past the last alias column, so trailing projections are carried over as-is.
            if index < len(column_aliases):
                _alias = column_aliases[index]
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
            new_expressions.append(projection)
        cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches below start empty and are filled the first time they are needed
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # A single source with unknown columns (none listed, or a star) is the only
            # place the column can come from
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        # The walk above may run past the root, leaving `node` as None
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source in this scope.
            only_visible: Forwarded to the schema when the source is a table.

        Returns:
            The column names of the source, with column aliases taking precedence.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                expression = source.expression
                annotate_types(expression)

                if expression.is_type(exp.DataType.Type.STRUCT):
                    for k in expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map each selected and lateral source name to its columns, computed once and cached."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        # Every column seen so far, so a name shared by three or more sources stays ambiguous
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
591 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 592 self.scope = scope 593 self.schema = schema 594 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 595 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 596 self._all_columns: t.Optional[t.Set[str]] = None 597 self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # A single source with unknown columns (none listed, or a star) is the only
        # place the column can come from
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Query):
        while node and node.alias != table_name:
            node = node.parent

    # The walk above may run past the root, leaving `node` as None
    node_alias = node.args.get("alias") if node else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
639 @property 640 def all_columns(self) -> t.Set[str]: 641 """All available columns of all sources in this scope""" 642 if self._all_columns is None: 643 self._all_columns = { 644 column for columns in self._get_all_source_columns().values() for column in columns 645 } 646 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: The name of a source in this scope.
        only_visible: Forwarded to the schema when the source is a table.

    Returns:
        The column names of the source, with column aliases taking precedence.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
        udtf = source.expression
        columns = udtf.named_selects

        # in bigquery, unnest structs are automatically scoped as tables, so you can
        # directly select a struct field in a query.
        # this handles the case where the unnest is statically defined.
        if self.schema.dialect == "bigquery":
            annotate_types(udtf)

            if udtf.is_type(exp.DataType.Type.STRUCT):
                columns.extend(field.name for field in udtf.type.expressions)  # type: ignore
    else:
        columns = source.expression.named_selects

    node, _ = self.scope.selected_sources.get(name) or (None, None)
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    if not column_aliases:
        return columns

    # If the source's columns are aliased, their aliases shadow the corresponding column names.
    # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
    return [
        col_alias or col_name
        for col_name, col_alias in itertools.zip_longest(columns, column_aliases)
    ]
Resolve the source columns for a given source `name`.