sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded before qualification;
        # otherwise they are expanded after it.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: Expression whose scopes are checked.

    Returns:
        The same expression, when every column is qualified.

    Raises:
        OptimizeError: If an external column of a non-correlated, non-pivot scope
            cannot be resolved, or if unqualified columns remain.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns referenced by an UNPIVOT.

    Yields the column on the left of the `field` IN clause (if it is a column),
    followed by every column found under the pivot's `expressions`.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` into explicit `ON` equality conditions.

    Unqualified references to a USING column elsewhere in the scope are replaced
    with a COALESCE over the tables that share that column.

    Returns:
        Mapping of each USING column name to a dict whose (insertion-ordered) keys
        are the names of the sources joined on that column.

    Raises:
        OptimizeError: If a USING column can't be matched on both sides of the join.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in `ordered` order of appearance) that exposes each column name
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when both sides have known column lists
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT scope.

    References are processed in the projections themselves (only aliases defined
    by earlier projections are visible), then in WHERE, GROUP BY, HAVING and
    QUALIFY. Non-SELECT scopes are left untouched.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when expanding the alias would nest one aggregate inside another
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in GROUP BY with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and qualify columns under aggregates.

    When the query also has a GROUP BY, ordered expressions that match a
    projection are replaced by a reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positional references against the scope's projections.

    Args:
        scope: Scope whose projections are referenced.
        expressions: Candidate expressions; non-integer nodes are passed through.
        alias: If True, a positional reference becomes a column named after the
            projection's alias; otherwise it becomes a copy of the projected
            expression (or stays as-is when that expression is a literal).

    Returns:
        The list of resolved expressions, in input order.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by `node`.

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`. Position 0 is
            rejected explicitly; `selects[0 - 1]` would otherwise silently pick the
            last projection.
    """
    selects = scope.expression.selects
    position = int(node.this)

    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = {c.output_name for c in pivot.find_all(exp.Column)}

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Without a known column list the star can't be expanded; leave the
            # projections untouched.
            if not columns or "*" in columns:
                return

            # NOTE: keyed by id() of the very same objects held in `tables`, matching
            # what _add_except_columns / _add_replace_columns stored above.
            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in the star's `except` arg for each table, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's `replace` arg as `{e.this.name: e.alias}` for each table, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built.
            Nothing happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # The outer column list is longer than the projection list
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter input: keep any trailing projections that have
        # no corresponding alias column instead of dropping them from the query.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has an unknown column list, assume the column is its
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When left as
            None, it defaults to whether the schema is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for current in traverse_scope(expression):
        column_resolver = Resolver(current, schema, infer_schema=infer_schema)

        # Table-level column alias lists are dropped before resolving columns.
        _pop_table_column_aliases(current.ctes)
        _pop_table_column_aliases(current.derived_tables)
        using_column_tables = _expand_using(current, column_resolver)

        # Alias references are expanded before qualification when the schema is
        # empty, and after it otherwise.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(current, column_resolver)

        _qualify_columns(current, column_resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(current, column_resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(current.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(current, column_resolver, using_column_tables, pseudocolumns)
            qualify_outputs(current)

        _expand_group_by(current)
        _expand_order_by(current, column_resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: Expression whose SELECT scopes are inspected.

    Returns:
        The unchanged expression when validation passes.

    Raises:
        OptimizeError: When an external column of a scope that is neither a
            correlated subquery nor a pivot can't be resolved, or when any
            unqualified columns remain after all scopes were visited.
    """
    leftovers = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        candidates = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            unresolved = scope.external_columns[0]
            for_table = ""
            if unresolved.table:
                for_table = f" for table: '{unresolved.table}'"
            raise OptimizeError(f"Column '{unresolved}' could not be resolved{for_table}")

        if candidates and scope.pivots and scope.pivots[0].unpivot:
            # Columns newly produced by the UNPIVOT cannot be qualified, whereas columns
            # under its IN clause can and should be. Filter out the former so that only
            # genuinely unqualified columns are reported.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            candidates = [c for c in candidates if c not in produced_by_unpivot]

        leftovers.extend(candidates)

    if leftovers:
        raise OptimizeError(f"Ambiguous columns: {leftovers}")

    return expression
Raise an OptimizeError
if any columns aren't qualified
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: Either a scope, or an expression from which a scope
            is built. If building does not yield a `Scope`, nothing is changed.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    aliased = []
    for position, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # More outer column names than projections: nothing left to alias.
            break

        fallback_name = f"_col_{position}"

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        # A name from the outer column list overrides whatever alias is set.
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    scope.expression.set("expressions", aliased)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression to transform; it is transformed with `copy=False`.
        dialect: Dialect whose `quote_identifier` is applied to the expression's nodes.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    When a CTE declares fewer alias columns than it has projections, only the leading
    projections are aliased; the remaining ones are kept unchanged.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter input: keep any trailing projections that have
        # no corresponding alias column instead of dropping them from the query.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches; `None` means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # A source with no known columns (or a star) could contain any column. If there
            # is exactly one such source, attribute the column to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that actually carries the alias `table_name`
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not one of the scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated inside a single source,
        # so that a later source exposing the same name is recognized as a conflict.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this source's columns (not only its unique ones): a name
            # duplicated here still conflicts with the same name from an earlier source.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
585 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 586 self.scope = scope 587 self.schema = schema 588 self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None 589 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 590 self._all_columns: t.Optional[t.Set[str]] = None 591 self._infer_schema = infer_schema
593 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 594 """ 595 Get the table for a column name. 596 597 Args: 598 column_name: The column name to find the table for. 599 Returns: 600 The table name if it can be found/inferred. 601 """ 602 if self._unambiguous_columns is None: 603 self._unambiguous_columns = self._get_unambiguous_columns( 604 self._get_all_source_columns() 605 ) 606 607 table_name = self._unambiguous_columns.get(column_name) 608 609 if not table_name and self._infer_schema: 610 sources_without_schema = tuple( 611 source 612 for source, columns in self._get_all_source_columns().items() 613 if not columns or "*" in columns 614 ) 615 if len(sources_without_schema) == 1: 616 table_name = sources_without_schema[0] 617 618 if table_name not in self.scope.selected_sources: 619 return exp.to_identifier(table_name) 620 621 node, _ = self.scope.selected_sources.get(table_name) 622 623 if isinstance(node, exp.Subqueryable): 624 while node and node.alias != table_name: 625 node = node.parent 626 627 node_alias = node.args.get("alias") 628 if node_alias: 629 return exp.to_identifier(node_alias.this) 630 631 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
633 @property 634 def all_columns(self) -> t.Set[str]: 635 """All available columns of all sources in this scope""" 636 if self._all_columns is None: 637 self._all_columns = { 638 column for columns in self._get_all_source_columns().values() for column in columns 639 } 640 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
642 def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]: 643 """Resolve the source columns for a given source `name`.""" 644 if name not in self.scope.sources: 645 raise OptimizeError(f"Unknown table: {name}") 646 647 source = self.scope.sources[name] 648 649 if isinstance(source, exp.Table): 650 columns = self.schema.column_names(source, only_visible) 651 elif isinstance(source, Scope) and isinstance(source.expression, exp.Values): 652 columns = source.expression.alias_column_names 653 else: 654 columns = source.expression.named_selects 655 656 node, _ = self.scope.selected_sources.get(name) or (None, None) 657 if isinstance(node, Scope): 658 column_aliases = node.expression.alias_column_names 659 elif isinstance(node, exp.Expression): 660 column_aliases = node.alias_column_names 661 else: 662 column_aliases = [] 663 664 # If the source's columns are aliased, their aliases shadow the corresponding column names 665 return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
Resolve the source columns for a given source `name`.