# sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns produced by an UNPIVOT: the IN-clause name column (if any) followed by the value columns."""
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """Rewrite JOIN ... USING into explicit ON conditions; return a mapping of USING column names to their source tables."""
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    """Expand references to projection aliases in WHERE / GROUP BY / HAVING / QUALIFY clauses."""
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references (e.g. GROUP BY 1) with the corresponding projections."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """Replace positional ORDER BY references and qualify columns inside ORDER BY aggregates."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """Map integer literals in `expressions` to copies of the projections they refer to (or their aliases)."""
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection at the 1-based position given by `node`, raising on out-of-range positions."""
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the EXCEPT column names of a star expression per table id."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the RENAME mappings of a star expression per table id."""
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the REPLACE expressions of a star expression per table id."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Lazily build and cache the mapping of every selected/lateral source to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects require alias references to be expanded before
        # qualification; do it early when the schema can't help us anyway.
        expand_aliases_early = expand_alias_refs and (
            schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION
        )
        if expand_aliases_early:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns, annotator)
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
99def validate_qualify_columns(expression: E) -> E: 100 """Raise an `OptimizeError` if any columns aren't qualified""" 101 all_unqualified_columns = [] 102 for scope in traverse_scope(expression): 103 if isinstance(scope.expression, exp.Select): 104 unqualified_columns = scope.unqualified_columns 105 106 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 107 column = scope.external_columns[0] 108 for_table = f" for table: '{column.table}'" if column.table else "" 109 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 110 111 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 112 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 113 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 114 # this list here to ensure those in the former category will be excluded. 115 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 116 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 117 118 all_unqualified_columns.extend(unqualified_columns) 119 120 if all_unqualified_columns: 121 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 122 123 return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
707def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 708 """Ensure all output columns are aliased""" 709 if isinstance(scope_or_expression, exp.Expression): 710 scope = build_scope(scope_or_expression) 711 if not isinstance(scope, Scope): 712 return 713 else: 714 scope = scope_or_expression 715 716 new_selections = [] 717 for i, (selection, aliased_column) in enumerate( 718 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 719 ): 720 if selection is None: 721 break 722 723 if isinstance(selection, exp.Subquery): 724 if not selection.output_name: 725 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 726 elif not isinstance(selection, exp.Alias) and not selection.is_star: 727 selection = alias( 728 selection, 729 alias=selection.output_name or f"_col_{i}", 730 copy=False, 731 ) 732 if aliased_column: 733 selection.set("alias", exp.to_identifier(aliased_column)) 734 735 new_selections.append(selection) 736 737 if isinstance(scope.expression, exp.Select): 738 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
741def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 742 """Makes sure all identifiers that need to be quoted are quoted.""" 743 return expression.transform( 744 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 745 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
748def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 749 """ 750 Pushes down the CTE alias columns into the projection, 751 752 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 753 754 Example: 755 >>> import sqlglot 756 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 757 >>> pushdown_cte_alias_columns(expression).sql() 758 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 759 760 Args: 761 expression: Expression to pushdown. 762 763 Returns: 764 The expression with the CTE aliases pushed down into the projection. 765 """ 766 for cte in expression.find_all(exp.CTE): 767 if cte.alias_column_names: 768 new_expressions = [] 769 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 770 if isinstance(projection, exp.Alias): 771 projection.set("alias", _alias) 772 else: 773 projection = alias(projection, alias=_alias) 774 new_expressions.append(projection) 775 cte.this.set("expressions", new_expressions) 776 777 return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema
        # Caches populated lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) columns, assume the
            # column belongs to it.
            schemaless_sources = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless_sources) == 1:
                table_name = schemaless_sources[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = set()
            for columns in self._get_all_source_columns().values():
                self._all_columns.update(columns)
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for struct_field in source.expression.type.expressions:  # type: ignore
                            columns.append(struct_field.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    column_alias or column_name
                    for (column_name, column_alias) in itertools.zip_longest(
                        columns, column_aliases
                    )
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Resolve (and memoize) the columns of every selected/lateral source.
        if self._source_columns is None:
            all_sources = itertools.chain(
                self.scope.selected_sources.items(), self.scope.lateral_sources.items()
            )
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, _ in all_sources
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        pairs = list(source_columns.items())
        first_table, first_columns = pairs[0]

        if len(pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {column: first_table for column in first_columns}
        seen_columns = set(unambiguous_columns)

        for table, columns in pairs[1:]:
            unique = set(columns)
            ambiguous = seen_columns & unique
            seen_columns |= unique

            # A column seen in more than one source can't be resolved unambiguously.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique - ambiguous:
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
787 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 788 self.scope = scope 789 self.schema = schema 790 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 791 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 792 self._all_columns: t.Optional[t.Set[str]] = None 793 self._infer_schema = infer_schema 794 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
796 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 797 """ 798 Get the table for a column name. 799 800 Args: 801 column_name: The column name to find the table for. 802 Returns: 803 The table name if it can be found/inferred. 804 """ 805 if self._unambiguous_columns is None: 806 self._unambiguous_columns = self._get_unambiguous_columns( 807 self._get_all_source_columns() 808 ) 809 810 table_name = self._unambiguous_columns.get(column_name) 811 812 if not table_name and self._infer_schema: 813 sources_without_schema = tuple( 814 source 815 for source, columns in self._get_all_source_columns().items() 816 if not columns or "*" in columns 817 ) 818 if len(sources_without_schema) == 1: 819 table_name = sources_without_schema[0] 820 821 if table_name not in self.scope.selected_sources: 822 return exp.to_identifier(table_name) 823 824 node, _ = self.scope.selected_sources.get(table_name) 825 826 if isinstance(node, exp.Query): 827 while node and node.alias != table_name: 828 node = node.parent 829 830 node_alias = node.args.get("alias") 831 if node_alias: 832 return exp.to_identifier(node_alias.this) 833 834 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
836 @property 837 def all_columns(self) -> t.Set[str]: 838 """All available columns of all sources in this scope""" 839 if self._all_columns is None: 840 self._all_columns = { 841 column for columns in self._get_all_source_columns().values() for column in columns 842 } 843 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
845 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 846 """Resolve the source columns for a given source `name`.""" 847 cache_key = (name, only_visible) 848 if cache_key not in self._get_source_columns_cache: 849 if name not in self.scope.sources: 850 raise OptimizeError(f"Unknown table: {name}") 851 852 source = self.scope.sources[name] 853 854 if isinstance(source, exp.Table): 855 columns = self.schema.column_names(source, only_visible) 856 elif isinstance(source, Scope) and isinstance( 857 source.expression, (exp.Values, exp.Unnest) 858 ): 859 columns = source.expression.named_selects 860 861 # in bigquery, unnest structs are automatically scoped as tables, so you can 862 # directly select a struct field in a query. 863 # this handles the case where the unnest is statically defined. 864 if self.schema.dialect == "bigquery": 865 if source.expression.is_type(exp.DataType.Type.STRUCT): 866 for k in source.expression.type.expressions: # type: ignore 867 columns.append(k.name) 868 else: 869 columns = source.expression.named_selects 870 871 node, _ = self.scope.selected_sources.get(name) or (None, None) 872 if isinstance(node, Scope): 873 column_aliases = node.expression.alias_column_names 874 elif isinstance(node, exp.Expression): 875 column_aliases = node.alias_column_names 876 else: 877 column_aliases = [] 878 879 if column_aliases: 880 # If the source's columns are aliased, their aliases shadow the corresponding column names. 881 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 882 columns = [ 883 alias or name 884 for (name, alias) in itertools.zip_longest(columns, column_aliases) 885 ] 886 887 self._get_source_columns_cache[cache_key] = columns 888 889 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.