sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. It is modified in place.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, the
            schema is inferred only if `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before column qualification when no schema is
        # available, and after it otherwise.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Raises:
        OptimizeError: if a SELECT scope has an external column that is neither a correlated
            reference nor under a pivot, or if any unqualified columns remain anywhere.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Return the columns introduced by an UNPIVOT.

    These are the column that is the subject of its `IN` field (if that subject is a column),
    followed by every column found in its value expressions.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in the scope into an explicit `ON` equality condition.

    Unqualified references to a USING column elsewhere in the scope are replaced with a
    COALESCE over every table joined on that column (aliased to the column name when the
    reference is a projection).

    Returns:
        Mapping of each USING column name to a dict whose keys are the joined table names,
        in join order (the values are always None).

    Raises:
        OptimizeError: if a USING column cannot be located in the joined sources.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in selection order) that exposes each column name, among the
        # sources already visible to this join.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Fallback used when a USING column cannot be located in the previous sources
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a single SELECT.

    Projections are processed in order, and each projection is processed before its own alias
    is registered, so a projection only sees the aliases of earlier projections. The WHERE,
    GROUP BY, HAVING and QUALIFY clauses are then processed with every alias known. In GROUP BY,
    a reference to an alias of a literal becomes the projection's 1-based position. In HAVING
    and QUALIFY, unqualified columns are also qualified through `resolver` when they do not
    refer to an alias, or when expanding the alias would nest one aggregate inside another.
    Does nothing unless the scope's expression is a SELECT.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an aggregate alias inside another aggregate would nest aggregates
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the referenced projection."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional ORDER BY references into the referenced projection's alias, and qualify
    unqualified columns that appear inside aggregate functions.

    When the query also has a GROUP BY, an ordered expression equal to a projection's underlying
    expression is replaced with a column reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace 1-based positional references (integer literals) with the projection they point to.

    With `alias=True`, a reference becomes a column named after the projection's alias.
    Otherwise it becomes a copy of the projection's expression, unless that expression is itself
    a literal, in which case the positional reference is kept. Other nodes pass through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the aliased projection referenced by the 1-based positional literal `node`.

    Raises:
        OptimizeError: if the position does not refer to an existing projection.
    """
    position = int(node.this)

    # Without this check, position 0 (or a negative position) would silently select a
    # projection counted from the end of the list because of Python's negative indexing.
    if position < 1:
        raise OptimizeError(f"Unknown output column: {node.name}")

    try:
        return scope.expression.selects[position - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: if a column is qualified with a known source that does not expose it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    `SELECT *` expands over every selected source and `SELECT tbl.*` over `tbl` only. EXCEPT and
    REPLACE modifiers are honoured, USING-join columns are emitted once as a COALESCE, and names
    in `pseudocolumns` (compared upper-cased) are skipped. If any expanded source has unknown
    columns (no columns, or a "*" entry), the projection list is left unchanged.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = {c.output_name for c in pivot.find_all(exp.Column)}

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # The source's columns are unknown, so the stars cannot be expanded; bail out
            # before the projection list is overwritten.
            if not columns or "*" in columns:
                return

            # EXCEPT/REPLACE entries are keyed by the id() of the same table-name objects
            # that were registered above.
            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    # Distinct name so the `tables` list driving the outer loop is not shadowed
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the EXCEPT column names of a star for each of `tables`, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record a star's REPLACE mapping (column name -> alias) for each of `tables`, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Unaliased, non-star projections are aliased with their output name (or `_col_<i>`), and
    unnamed subquery projections get a `_col_<i>` table alias. Names from the scope's outer
    column list, when present, override these aliases position by position.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier rather than a bare string, as other alias writes in
                # this module do (e.g. `exp.to_identifier` in `qualify_outputs`).
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence: keep any projections that have no column alias
        # instead of silently dropping them from the CTE.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches (see _get_all_source_columns, get_table, all_columns)
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # With schema inference, a column that matches no known source is attributed to
            # the single source whose columns are unknown, if there is exactly one.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        # The parent walk above can run past the root without finding a matching alias
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for (column_name, column_alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (computed once, then cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it, and exposes it only once.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Seed with *every* column of the first source (duplicates included), so that a name
        # that is already ambiguous inside the first source stays ambiguous when a later
        # source exposes it too.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Each scope of `expression` is visited in turn. Column references are given an explicit
    source, and (optionally) alias references and star projections are expanded. The input
    expression itself is modified and returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, inference
            is enabled exactly when the schema is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        scope_resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column-alias lists from CTEs and derived tables, then rewrite USING joins.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_columns = _expand_using(scope, scope_resolver)

        # Without a schema, alias references are expanded *before* qualification...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, scope_resolver)

        _qualify_columns(scope, scope_resolver)

        # ...and with a schema, *after* it.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, scope_resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            if expand_stars:
                _expand_stars(scope, scope_resolver, using_columns, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, scope_resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. A scope fails immediately if it has an external
    column while being neither a correlated subquery nor pivoted. Otherwise, its
    unqualified columns (minus those an UNPIVOT introduces) are collected and reported
    together at the end.

    Returns:
        `expression`, unchanged, when every column is qualified.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            offender = scope.external_columns[0]
            for_table = f" for table: '{offender.table}'" if offender.table else ""
            raise OptimizeError(f"Column '{offender}' could not be resolved{for_table}")

        if unqualified and scope.pivots and scope.pivots[0].unpivot:
            # Columns created by the UNPIVOT itself can never be qualified, whereas columns
            # under its IN clause can and should be, so only the former are filtered out.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            unqualified = [col for col in unqualified if col not in produced_by_unpivot]

        ambiguous.extend(unqualified)

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A `Scope`, or an expression from which a scope is built with
            `build_scope`. Nothing happens if no scope can be built.

    Unaliased, non-star projections are wrapped in an alias named after their output name
    (or `_col_<position>` when they have none). Subquery projections without an output name
    get a `_col_<position>` table alias. Names in the scope's outer column list, when
    present, then override each projection's alias position by position.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    projections = scope.expression.selects
    outer_names = scope.outer_column_list

    aliased_projections = []
    for position, (projection, outer_name) in enumerate(
        itertools.zip_longest(projections, outer_names)
    ):
        # zip_longest pads with None once the projections are exhausted
        if projection is None:
            break

        fallback_name = f"_col_{position}"

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_projections.append(projection)

    scope.expression.set("expressions", aliased_projections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform. It is transformed in place (`copy=False`).
        dialect: The dialect whose `quote_identifier` function is applied via `transform`.
        identify: Passed as the `identify` keyword to `expression.transform` together
            with the quoting function.

    Returns:
        The transformed expression.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection. Projections
        beyond the CTE's alias list are kept unchanged.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier rather than a bare string, matching how the
                # else-branch builds its alias node.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence: keep any projections that have no column alias
        # instead of silently dropping them from the CTE.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (see _get_all_source_columns, get_table, all_columns).
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the single source whose columns are unknown (empty or "*"), if any.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias we resolved to.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the root; fall back to the plain name in that case.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The source name as registered in `self.scope.sources`.
            only_visible: Forwarded to `Schema.column_names` for table sources.

        Returns:
            The source's column names, with any column aliases applied positionally.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                col_alias or col_name
                for col_name, col_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources,
        including duplicates inside a single source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track *every* column seen so far (including duplicated ones), not only the unique ones,
        # so that a later single occurrence of an already-duplicated column is still ambiguous.
        seen_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this table's columns: a column duplicated here also
            # makes an earlier occurrence ambiguous.
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """Bind the resolver to a scope and a schema; derived lookups start empty and are filled lazily."""
    self.scope, self.schema = scope, schema
    # Caches that stay None until first requested.
    self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
    # Whether a lone schema-less source may be assumed to own unresolved columns.
    self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # Fall back to the single source whose columns are unknown (empty or "*"), if any.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the ancestor that carries the alias we resolved to.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run off the root; fall back to the plain name in that case.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """Union of the column names of every source in this scope; computed once, then cached."""
    if self._all_columns is None:
        collected: t.Set[str] = set()
        for source_cols in self._get_all_source_columns().values():
            collected.update(source_cols)
        self._all_columns = collected
    return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Key into `self.scope.sources`.
        only_visible: Forwarded to `Schema.column_names` when the source is a table.

    Returns:
        The column names, with positional column aliases (if any) taking precedence.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    node = (self.scope.selected_sources.get(name) or (None, None))[0]
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    if not column_aliases:
        return columns

    # Aliases shadow the underlying names position by position; only pay for this
    # pairing when aliases are actually present.
    return [
        col_alias or col_name
        for col_name, col_alias in itertools.zip_longest(columns, column_aliases)
    ]
Resolve the source columns for a given source `name`.