Module: sqlglot.optimizer.qualify_columns
  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25    allow_partial_qualification: bool = False,
 26) -> exp.Expression:
 27    """
 28    Rewrite sqlglot AST to have fully qualified columns.
 29
 30    Example:
 31        >>> import sqlglot
 32        >>> schema = {"tbl": {"col": "INT"}}
 33        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 34        >>> qualify_columns(expression, schema).sql()
 35        'SELECT tbl.col AS col FROM tbl'
 36
 37    Args:
 38        expression: Expression to qualify.
 39        schema: Database schema.
 40        expand_alias_refs: Whether to expand references to aliases.
 41        expand_stars: Whether to expand star queries. This is a necessary step
 42            for most of the optimizer's rules to work; do not set to False unless you
 43            know what you're doing!
 44        infer_schema: Whether to infer the schema if missing.
 45        allow_partial_qualification: Whether to allow partial qualification.
 46
 47    Returns:
 48        The qualified expression.
 49
 50    Notes:
 51        - Currently only handles a single PIVOT or UNPIVOT operator
 52    """
 53    schema = ensure_schema(schema)
 54    annotator = TypeAnnotator(schema)
 55    infer_schema = schema.empty if infer_schema is None else infer_schema
 56    dialect = Dialect.get_or_raise(schema.dialect)
 57    pseudocolumns = dialect.PSEUDOCOLUMNS
 58
 59    for scope in traverse_scope(expression):
 60        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 61        _pop_table_column_aliases(scope.ctes)
 62        _pop_table_column_aliases(scope.derived_tables)
 63        using_column_tables = _expand_using(scope, resolver)
 64
 65        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 66            _expand_alias_refs(
 67                scope,
 68                resolver,
 69                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
 70            )
 71
 72        _convert_columns_to_dots(scope, resolver)
 73        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 74
 75        if not schema.empty and expand_alias_refs:
 76            _expand_alias_refs(scope, resolver)
 77
 78        if not isinstance(scope.expression, exp.UDTF):
 79            if expand_stars:
 80                _expand_stars(
 81                    scope,
 82                    resolver,
 83                    using_column_tables,
 84                    pseudocolumns,
 85                    annotator,
 86                )
 87            qualify_outputs(scope)
 88
 89        _expand_group_by(scope, dialect)
 90        _expand_order_by(scope, resolver)
 91
 92        if dialect == "bigquery":
 93            annotator.annotate_scope(scope)
 94
 95    return expression
 96
 97
 98def validate_qualify_columns(expression: E) -> E:
 99    """Raise an `OptimizeError` if any columns aren't qualified"""
100    all_unqualified_columns = []
101    for scope in traverse_scope(expression):
102        if isinstance(scope.expression, exp.Select):
103            unqualified_columns = scope.unqualified_columns
104
105            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
106                column = scope.external_columns[0]
107                for_table = f" for table: '{column.table}'" if column.table else ""
108                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
109
110            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
111                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
112                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
113                # this list here to ensure those in the former category will be excluded.
114                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
115                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
116
117            all_unqualified_columns.extend(unqualified_columns)
118
119    if all_unqualified_columns:
120        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
121
122    return expression
123
124
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name column (if any), then every column in its value expressions."""
    field = unpivot.args.get("field")
    name_columns = (
        [field.this]
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
        else []
    )
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_columns, value_columns)
133
134
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        if isinstance(parent, exp.With) and parent.recursive:
            # Recursive CTEs rely on their column aliases; leave them alone.
            continue
        alias_node = derived_table.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
147
148
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses into equivalent explicit `ON` conditions.

    Returns:
        Mapping of each USING column name to an ordered "set" (a dict whose values
        are all None — only key order matters) of the source names that expose it.
        This is later used to COALESCE references to those columns.
    """
    # Maps a column name to the first source (in FROM/JOIN order) that provides it.
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that aren't introduced by a JOIN (e.g. the FROM clause tables).
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently added source acts as the left-hand side of this join.
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we actually know the columns of both sides;
                # a "*" entry means the schema is unknown, so we can't validate.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins with several USING columns: the LHS must COALESCE
                # over every earlier source that exposes the column.
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition.
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Rewrite unqualified references to USING columns as COALESCEs over
        # every source that exposes them.
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
237
238
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    """
    Expand references to projection aliases in WHERE/GROUP BY/HAVING/QUALIFY,
    either by substituting the aliased expression or by qualifying the column.

    Args:
        scope: The scope to rewrite in place.
        resolver: Used to resolve unqualified column names to source tables.
        expand_only_groupby: If True, only expand alias references inside GROUP BY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps each projection alias to (aliased expression, 1-based projection index).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when substituting the alias would nest an aggregate inside
            # another (non-windowed) aggregate; in that case we prefer to
            # qualify the column with a table instead of expanding the alias.
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # e.g. GROUP BY <alias of literal> -> GROUP BY <projection index>
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize the substituted expression to preserve precedence,
                    # then drop the parens again if they turn out to be redundant.
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()
313
314
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Expand positional references (e.g. GROUP BY 1) into the projections they point to."""
    group = scope.expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
323
324
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional ORDER BY references and qualify columns inside ordered aggregates.

    If the query also has a GROUP BY, ORDER BY expressions that match a projection
    are additionally rewritten to reference that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        # Qualify any bare columns that appear inside aggregates in the ORDER BY.
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Map each projected expression to a column referencing its output name.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
355
356
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals that reference projections by 1-based position.

    Args:
        scope: The scope whose projections are referenced.
        expressions: The GROUP BY / ORDER BY expressions to expand.
        dialect: Target dialect; BigQuery gets extra ambiguity handling.
        alias: If True, substitute the projection's alias rather than the
            projected expression itself.

    Returns:
        The expanded list; non-positional nodes pass through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    # Computed lazily, only for BigQuery, and only once.
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional reference when expanding it would be
                # problematic: constants, exploding expressions, or (BigQuery)
                # ambiguous identifiers.
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
401
402
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the projection at the 1-based position in `node`, asserting it is aliased."""
    index = int(node.this) - 1
    try:
        selection = scope.expression.selects[index]
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")
    return selection.assert_is(exp.Alias)
408
409
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only treat the column as a struct lookup when its "table" part does not
        # refer to a known source, here or (for correlated subqueries) in the
        # parent scope.
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
449
450
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Column is already qualified; validate it exists in its source
            # (unless partial qualification is allowed or the schema is unknown).
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Also qualify columns appearing inside the PIVOT clause itself.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
485
486
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build a column of the form <table>.<struct>...<field>, aliased to the field name.
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
539
540
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The following three dicts are keyed by id(table) for each table a star expands over.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    # USING columns already emitted as a single COALESCE projection.
    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            # UNPIVOT outputs its name/value columns.
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                # Columns listed in the IN clause are consumed by the UNPIVOT.
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            # PIVOT consumes the columns it references.
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `*`: expands over every selected source.
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. `tbl.*`.
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # Struct star, e.g. `tbl.struct_col.*` (BigQuery).
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star projection; keep as-is.
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Source columns are unknown; abort expansion entirely so the
                # original selections are preserved.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is projected once, as a COALESCE over all
                    # the sources that expose it.
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
662
663
664def _add_except_columns(
665    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
666) -> None:
667    except_ = expression.args.get("except")
668
669    if not except_:
670        return
671
672    columns = {e.name for e in except_}
673
674    for table in tables:
675        except_columns[id(table)] = columns
676
677
678def _add_rename_columns(
679    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
680) -> None:
681    rename = expression.args.get("rename")
682
683    if not rename:
684        return
685
686    columns = {e.this.name: e.alias for e in rename}
687
688    for table in tables:
689        rename_columns[id(table)] = columns
690
691
692def _add_replace_columns(
693    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
694) -> None:
695    replace = expression.args.get("replace")
696
697    if not replace:
698        return
699
700    columns = {e.alias: e for e in replace}
701
702    for table in tables:
703        replace_columns[id(table)] = columns
704
705
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    paired = itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    for i, (selection, outer_alias) in enumerate(paired):
        if selection is None:
            # More outer columns than projections; nothing left to alias.
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not selection.is_star and not isinstance(selection, exp.Alias):
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )

        if outer_alias:
            # Outer column aliases (e.g. from a table alias column list) win.
            selection.set("alias", exp.to_identifier(outer_alias))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
738
739
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
745
746
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = []
        for name, projection in zip(alias_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", name)
            else:
                projection = alias(projection, alias=name)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
777
778
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; see the corresponding accessors below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Caches get_source_columns results, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If the column is unknown/ambiguous but exactly one source has no
            # schema information (no columns, or a star), assume it lives there.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        # Bug fix: the parent walk above can exhaust all ancestors without
        # finding a matching alias, leaving `node` as None; previously this
        # raised an AttributeError on `node.args`. Fall through to the plain
        # table name in that case.
        if node:
            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # (Loop variables renamed so they no longer shadow the `name` parameter
                # and the module-level `alias` helper.)
                columns = [
                    column_alias or column
                    for (column, column_alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Build (once) and return the {source name: columns} mapping for this scope."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column that appears in more than one source is no longer unambiguous.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    dialect = Dialect.get_or_raise(schema.dialect)

    if infer_schema is None:
        infer_schema = schema.empty

    # These conditions don't change from scope to scope, so decide them once up front.
    expand_refs_early = expand_alias_refs and (
        schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION
    )
    expand_refs_late = expand_alias_refs and not schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        if expand_refs_early:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if expand_refs_late:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope, resolver, using_column_tables, dialect.PSEUDOCOLUMNS, annotator
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 99def validate_qualify_columns(expression: E) -> E:
100    """Raise an `OptimizeError` if any columns aren't qualified"""
101    all_unqualified_columns = []
102    for scope in traverse_scope(expression):
103        if isinstance(scope.expression, exp.Select):
104            unqualified_columns = scope.unqualified_columns
105
106            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
107                column = scope.external_columns[0]
108                for_table = f" for table: '{column.table}'" if column.table else ""
109                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
110
111            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
112                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
113                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
114                # this list here to ensure those in the former category will be excluded.
115                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
116                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
117
118            all_unqualified_columns.extend(unqualified_columns)
119
120    if all_unqualified_columns:
121        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
122
123    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        maybe_scope = build_scope(scope_or_expression)
        if not isinstance(maybe_scope, Scope):
            return
        scope = maybe_scope
    else:
        scope = scope_or_expression

    qualified = []

    for idx, (projection, outer_name) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        # zip_longest pads the shorter side with None: once projections run
        # out, only outer column names remain and there is nothing to alias.
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set(
                    "alias", exp.TableAlias(this=exp.to_identifier(f"_col_{idx}"))
                )
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{idx}",
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        qualified.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", qualified)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
741def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
742    """Makes sure all identifiers that need to be quoted are quoted."""
743    return expression.transform(
744        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
745    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
748def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
749    """
750    Pushes down the CTE alias columns into the projection,
751
752    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
753
754    Example:
755        >>> import sqlglot
756        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
757        >>> pushdown_cte_alias_columns(expression).sql()
758        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
759
760    Args:
761        expression: Expression to pushdown.
762
763    Returns:
764        The expression with the CTE aliases pushed down into the projection.
765    """
766    for cte in expression.find_all(exp.CTE):
767        if cte.alias_column_names:
768            new_expressions = []
769            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
770                if isinstance(projection, exp.Alias):
771                    projection.set("alias", _alias)
772                else:
773                    projection = alias(projection, alias=_alias)
774                new_expressions.append(projection)
775            cte.this.set("expressions", new_expressions)
776
777    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
780class Resolver:
781    """
782    Helper for resolving columns.
783
784    This is a class so we can lazily load some things and easily share them across functions.
785    """
786
787    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
788        self.scope = scope
789        self.schema = schema
790        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
791        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
792        self._all_columns: t.Optional[t.Set[str]] = None
793        self._infer_schema = infer_schema
794        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
795
796    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
797        """
798        Get the table for a column name.
799
800        Args:
801            column_name: The column name to find the table for.
802        Returns:
803            The table name if it can be found/inferred.
804        """
805        if self._unambiguous_columns is None:
806            self._unambiguous_columns = self._get_unambiguous_columns(
807                self._get_all_source_columns()
808            )
809
810        table_name = self._unambiguous_columns.get(column_name)
811
812        if not table_name and self._infer_schema:
813            sources_without_schema = tuple(
814                source
815                for source, columns in self._get_all_source_columns().items()
816                if not columns or "*" in columns
817            )
818            if len(sources_without_schema) == 1:
819                table_name = sources_without_schema[0]
820
821        if table_name not in self.scope.selected_sources:
822            return exp.to_identifier(table_name)
823
824        node, _ = self.scope.selected_sources.get(table_name)
825
826        if isinstance(node, exp.Query):
827            while node and node.alias != table_name:
828                node = node.parent
829
830        node_alias = node.args.get("alias")
831        if node_alias:
832            return exp.to_identifier(node_alias.this)
833
834        return exp.to_identifier(table_name)
835
836    @property
837    def all_columns(self) -> t.Set[str]:
838        """All available columns of all sources in this scope"""
839        if self._all_columns is None:
840            self._all_columns = {
841                column for columns in self._get_all_source_columns().values() for column in columns
842            }
843        return self._all_columns
844
845    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
846        """Resolve the source columns for a given source `name`."""
847        cache_key = (name, only_visible)
848        if cache_key not in self._get_source_columns_cache:
849            if name not in self.scope.sources:
850                raise OptimizeError(f"Unknown table: {name}")
851
852            source = self.scope.sources[name]
853
854            if isinstance(source, exp.Table):
855                columns = self.schema.column_names(source, only_visible)
856            elif isinstance(source, Scope) and isinstance(
857                source.expression, (exp.Values, exp.Unnest)
858            ):
859                columns = source.expression.named_selects
860
861                # in bigquery, unnest structs are automatically scoped as tables, so you can
862                # directly select a struct field in a query.
863                # this handles the case where the unnest is statically defined.
864                if self.schema.dialect == "bigquery":
865                    if source.expression.is_type(exp.DataType.Type.STRUCT):
866                        for k in source.expression.type.expressions:  # type: ignore
867                            columns.append(k.name)
868            else:
869                columns = source.expression.named_selects
870
871            node, _ = self.scope.selected_sources.get(name) or (None, None)
872            if isinstance(node, Scope):
873                column_aliases = node.expression.alias_column_names
874            elif isinstance(node, exp.Expression):
875                column_aliases = node.alias_column_names
876            else:
877                column_aliases = []
878
879            if column_aliases:
880                # If the source's columns are aliased, their aliases shadow the corresponding column names.
881                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
882                columns = [
883                    alias or name
884                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
885                ]
886
887            self._get_source_columns_cache[cache_key] = columns
888
889        return self._get_source_columns_cache[cache_key]
890
891    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
892        if self._source_columns is None:
893            self._source_columns = {
894                source_name: self.get_source_columns(source_name)
895                for source_name, source in itertools.chain(
896                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
897                )
898            }
899        return self._source_columns
900
901    def _get_unambiguous_columns(
902        self, source_columns: t.Dict[str, t.Sequence[str]]
903    ) -> t.Mapping[str, str]:
904        """
905        Find all the unambiguous columns in sources.
906
907        Args:
908            source_columns: Mapping of names to source columns.
909
910        Returns:
911            Mapping of column name to source name.
912        """
913        if not source_columns:
914            return {}
915
916        source_columns_pairs = list(source_columns.items())
917
918        first_table, first_columns = source_columns_pairs[0]
919
920        if len(source_columns_pairs) == 1:
921            # Performance optimization - avoid copying first_columns if there is only one table.
922            return SingleValuedMapping(first_columns, first_table)
923
924        unambiguous_columns = {col: first_table for col in first_columns}
925        all_columns = set(unambiguous_columns)
926
927        for table, columns in source_columns_pairs[1:]:
928            unique = set(columns)
929            ambiguous = all_columns.intersection(unique)
930            all_columns.update(columns)
931
932            for column in ambiguous:
933                unambiguous_columns.pop(column, None)
934            for column in unique.difference(ambiguous):
935                unambiguous_columns[column] = table
936
937        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
787    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
788        self.scope = scope
789        self.schema = schema
790        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
791        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
792        self._all_columns: t.Optional[t.Set[str]] = None
793        self._infer_schema = infer_schema
794        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
796    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
797        """
798        Get the table for a column name.
799
800        Args:
801            column_name: The column name to find the table for.
802        Returns:
803            The table name if it can be found/inferred.
804        """
805        if self._unambiguous_columns is None:
806            self._unambiguous_columns = self._get_unambiguous_columns(
807                self._get_all_source_columns()
808            )
809
810        table_name = self._unambiguous_columns.get(column_name)
811
812        if not table_name and self._infer_schema:
813            sources_without_schema = tuple(
814                source
815                for source, columns in self._get_all_source_columns().items()
816                if not columns or "*" in columns
817            )
818            if len(sources_without_schema) == 1:
819                table_name = sources_without_schema[0]
820
821        if table_name not in self.scope.selected_sources:
822            return exp.to_identifier(table_name)
823
824        node, _ = self.scope.selected_sources.get(table_name)
825
826        if isinstance(node, exp.Query):
827            while node and node.alias != table_name:
828                node = node.parent
829
830        node_alias = node.args.get("alias")
831        if node_alias:
832            return exp.to_identifier(node_alias.this)
833
834        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
836    @property
837    def all_columns(self) -> t.Set[str]:
838        """All available columns of all sources in this scope"""
839        if self._all_columns is None:
840            self._all_columns = {
841                column for columns in self._get_all_source_columns().values() for column in columns
842            }
843        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
845    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
846        """Resolve the source columns for a given source `name`."""
847        cache_key = (name, only_visible)
848        if cache_key not in self._get_source_columns_cache:
849            if name not in self.scope.sources:
850                raise OptimizeError(f"Unknown table: {name}")
851
852            source = self.scope.sources[name]
853
854            if isinstance(source, exp.Table):
855                columns = self.schema.column_names(source, only_visible)
856            elif isinstance(source, Scope) and isinstance(
857                source.expression, (exp.Values, exp.Unnest)
858            ):
859                columns = source.expression.named_selects
860
861                # in bigquery, unnest structs are automatically scoped as tables, so you can
862                # directly select a struct field in a query.
863                # this handles the case where the unnest is statically defined.
864                if self.schema.dialect == "bigquery":
865                    if source.expression.is_type(exp.DataType.Type.STRUCT):
866                        for k in source.expression.type.expressions:  # type: ignore
867                            columns.append(k.name)
868            else:
869                columns = source.expression.named_selects
870
871            node, _ = self.scope.selected_sources.get(name) or (None, None)
872            if isinstance(node, Scope):
873                column_aliases = node.expression.alias_column_names
874            elif isinstance(node, exp.Expression):
875                column_aliases = node.alias_column_names
876            else:
877                column_aliases = []
878
879            if column_aliases:
880                # If the source's columns are aliased, their aliases shadow the corresponding column names.
881                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
882                columns = [
883                    alias or name
884                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
885                ]
886
887            self._get_source_columns_cache[cache_key] = columns
888
889        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.