# sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # Default: only infer the schema when the caller provided an empty one
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    # traverse_scope yields scopes innermost-first, so nested queries are
    # qualified before the queries that contain them
    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects (and schemaless input) need alias references expanded before
        # column qualification; others are handled after it, below
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression
113
114
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved: t.List[exp.Column] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            first = scope.external_columns[0]
            for_table = f" for table: '{first.table}'" if first.table else ""
            raise OptimizeError(f"Column '{first}' could not be resolved{for_table}")

        if ambiguous and scope.pivots and scope.pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. We recompute
            # this list here to ensure those in the former category will be excluded.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            ambiguous = [c for c in ambiguous if c not in produced_by_unpivot]

        unresolved.extend(ambiguous)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression
140
141
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name column (if present) followed by all value columns."""
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        yield field.this

    for expression in unpivot.expressions:
        yield from expression.find_all(exp.Column)
150
151
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived in derived_tables:
        parent = derived.parent
        # Recursive CTEs depend on their declared column list, so it must be kept
        if isinstance(parent, exp.With) and parent.recursive:
            continue

        alias_node = derived.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
164
165
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every JOIN ... USING (...) in `scope` into an explicit ON condition.

    Returns:
        Mapping of each USING column name to an ordered "set" of source names
        (a dict whose values are all None — only key order matters), which star
        expansion later uses to emit COALESCE projections for shared columns.
    """
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Remember the first source that exposes each column name
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that aren't themselves joins come first, in selection order
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we have enough schema info to know the join is impossible
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For chained USING joins, the LHS must coalesce over every earlier
                # source that provides this column
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Replace remaining unqualified references to USING columns with COALESCE
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
262
263
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps alias name -> (projected expression, 1-based position in the select list)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't inline an aggregate-backed alias inside another aggregate,
                # unless the reference actually sits inside a window function
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # Literal projections are referenced by ordinal position instead
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then simplify if the
                    # parentheses turn out to be redundant
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
360
361
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Expand positional references (GROUP BY 1, 2, ...) into the projections they denote."""
    select = scope.expression
    group = select.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    select.set("group", group)
370
371
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """Expand positional references in ORDER BY and DISTINCT ON into real expressions."""
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # For DISTINCT, the expressions live under its ON clause
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an Ordered node; unwrap before expanding
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before the original node is swapped out
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            # With GROUP BY present, refer to projections via their output aliases
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
407
408
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """Replace 1-based integer literals that reference projections by position.

    If `alias` is True, each positional reference becomes a column named after the
    projection's alias; otherwise it becomes a copy of the projected expression.
    Non-integer nodes pass through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None  # computed lazily, and only for BigQuery

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional literal when expansion would be invalid
                # (constants, exploding expressions) or ambiguous
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
453
454
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the projection that the 1-based positional literal `node` refers to.

    Raises:
        OptimizeError: If the position is out of range for the scope's select list.
    """
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError as e:
        # Chain the original IndexError so tracebacks show the real cause
        raise OptimizeError(f"Unknown output column: {node.name}") from e
460
461
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only treat the "table" part as a struct when it doesn't resolve to a known
        # source here, nor (for correlated subqueries) in the parent scope
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
501
502
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Column is already qualified; validate it against the source's known columns
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Also qualify columns appearing inside the PIVOT expressions themselves
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
537
538
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns the list of aliased field selections, or an empty list when the
    expansion is not possible (non-struct column, anonymous or ambiguous fields).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build <table>.<struct path>.<field> as a Column with nested fields
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
591
592
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The following three dicts are keyed by id(table) and carry the star's
    # EXCEPT / REPLACE / RENAME modifiers for that table
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    # NOTE: only a single PIVOT/UNPIVOT is handled (see qualify_columns docstring)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare * — expand over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star: table.*
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # Struct star: table.struct_col.*
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Not enough schema info to expand — leave the query untouched
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns are projected once, as a COALESCE over all sources
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
714
715
716def _add_except_columns(
717    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
718) -> None:
719    except_ = expression.args.get("except")
720
721    if not except_:
722        return
723
724    columns = {e.name for e in except_}
725
726    for table in tables:
727        except_columns[id(table)] = columns
728
729
730def _add_rename_columns(
731    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
732) -> None:
733    rename = expression.args.get("rename")
734
735    if not rename:
736        return
737
738    columns = {e.this.name: e.alias for e in rename}
739
740    for table in tables:
741        rename_columns[id(table)] = columns
742
743
744def _add_replace_columns(
745    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
746) -> None:
747    replace = expression.args.get("replace")
748
749    if not replace:
750        return
751
752    columns = {e.alias: e for e in replace}
753
754    for table in tables:
755        replace_columns[id(table)] = columns
756
757
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        built = build_scope(scope_or_expression)
        if not isinstance(built, Scope):
            return
        scope = built
    else:
        scope = scope_or_expression

    selections = []
    paired = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for pos, (projection, outer_name) in enumerate(paired):
        # outer_columns may be longer than the select list; stop at its end
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{pos}")))
        elif not isinstance(projection, exp.Alias) and not projection.is_star:
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{pos}",
                copy=False,
            )

        # An outer column alias (e.g. from a CTE's column list) takes precedence
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        selections.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", selections)
790
791
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
797
798
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = []
        for name, projection in zip(column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                # Already aliased: overwrite the alias with the CTE column name
                projection.set("alias", name)
            else:
                projection = alias(projection, alias=name)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
829
830
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; each is filled on first use by the
        # corresponding method/property below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fallback: if exactly one source has an unknown column set (empty
            # or containing "*"), assume the column comes from that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # NOTE(review): table_name may be None here; presumably
            # exp.to_identifier(None) returns None in that case — confirm.
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up from the query to the ancestor that carries this alias.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            # Prefer the source's declared alias over the resolved table name.
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Concrete table: ask the schema for its column names.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Any other scope (e.g. a derived table): use its projections.
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # NOTE: the comprehension targets below intentionally shadow the
                # outer `name` parameter and the module-level `alias` helper.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps every selected and lateral source name to its resolved columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns already seen in a previous source are ambiguous.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27) -> exp.Expression:
 28    """
 29    Rewrite sqlglot AST to have fully qualified columns.
 30
 31    Example:
 32        >>> import sqlglot
 33        >>> schema = {"tbl": {"col": "INT"}}
 34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 35        >>> qualify_columns(expression, schema).sql()
 36        'SELECT tbl.col AS col FROM tbl'
 37
 38    Args:
 39        expression: Expression to qualify.
 40        schema: Database schema.
 41        expand_alias_refs: Whether to expand references to aliases.
 42        expand_stars: Whether to expand star queries. This is a necessary step
 43            for most of the optimizer's rules to work; do not set to False unless you
 44            know what you're doing!
 45        infer_schema: Whether to infer the schema if missing.
 46        allow_partial_qualification: Whether to allow partial qualification.
 47
 48    Returns:
 49        The qualified expression.
 50
 51    Notes:
 52        - Currently only handles a single PIVOT or UNPIVOT operator
 53    """
 54    schema = ensure_schema(schema)
 55    annotator = TypeAnnotator(schema)
 56    infer_schema = schema.empty if infer_schema is None else infer_schema
 57    dialect = Dialect.get_or_raise(schema.dialect)
 58    pseudocolumns = dialect.PSEUDOCOLUMNS
 59    bigquery = dialect == "bigquery"
 60
 61    for scope in traverse_scope(expression):
 62        scope_expression = scope.expression
 63        is_select = isinstance(scope_expression, exp.Select)
 64
 65        if is_select and scope_expression.args.get("connect"):
 66            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 67            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 68            scope_expression.transform(
 69                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 70                copy=False,
 71            )
 72            scope.clear_cache()
 73
 74        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 75        _pop_table_column_aliases(scope.ctes)
 76        _pop_table_column_aliases(scope.derived_tables)
 77        using_column_tables = _expand_using(scope, resolver)
 78
 79        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 80            _expand_alias_refs(
 81                scope,
 82                resolver,
 83                dialect,
 84                expand_only_groupby=bigquery,
 85            )
 86
 87        _convert_columns_to_dots(scope, resolver)
 88        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 89
 90        if not schema.empty and expand_alias_refs:
 91            _expand_alias_refs(scope, resolver, dialect)
 92
 93        if is_select:
 94            if expand_stars:
 95                _expand_stars(
 96                    scope,
 97                    resolver,
 98                    using_column_tables,
 99                    pseudocolumns,
100                    annotator,
101                )
102            qualify_outputs(scope)
103
104        _expand_group_by(scope, dialect)
105
106        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
107        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
108        _expand_order_by_and_distinct_on(scope, resolver)
109
110        if bigquery:
111            annotator.annotate_scope(scope)
112
113    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
116def validate_qualify_columns(expression: E) -> E:
117    """Raise an `OptimizeError` if any columns aren't qualified"""
118    all_unqualified_columns = []
119    for scope in traverse_scope(expression):
120        if isinstance(scope.expression, exp.Select):
121            unqualified_columns = scope.unqualified_columns
122
123            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
124                column = scope.external_columns[0]
125                for_table = f" for table: '{column.table}'" if column.table else ""
126                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
127
128            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
129                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
130                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
131                # this list here to ensure those in the former category will be excluded.
132                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
133                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
134
135            all_unqualified_columns.extend(unqualified_columns)
136
137    if all_unqualified_columns:
138        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
139
140    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
759def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
760    """Ensure all output columns are aliased"""
761    if isinstance(scope_or_expression, exp.Expression):
762        scope = build_scope(scope_or_expression)
763        if not isinstance(scope, Scope):
764            return
765    else:
766        scope = scope_or_expression
767
768    new_selections = []
769    for i, (selection, aliased_column) in enumerate(
770        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
771    ):
772        if selection is None:
773            break
774
775        if isinstance(selection, exp.Subquery):
776            if not selection.output_name:
777                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
778        elif not isinstance(selection, exp.Alias) and not selection.is_star:
779            selection = alias(
780                selection,
781                alias=selection.output_name or f"_col_{i}",
782                copy=False,
783            )
784        if aliased_column:
785            selection.set("alias", exp.to_identifier(aliased_column))
786
787        new_selections.append(selection)
788
789    if isinstance(scope.expression, exp.Select):
790        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
793def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
794    """Makes sure all identifiers that need to be quoted are quoted."""
795    return expression.transform(
796        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
797    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
800def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
801    """
802    Pushes down the CTE alias columns into the projection,
803
804    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
805
806    Example:
807        >>> import sqlglot
808        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
809        >>> pushdown_cte_alias_columns(expression).sql()
810        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
811
812    Args:
813        expression: Expression to pushdown.
814
815    Returns:
816        The expression with the CTE aliases pushed down into the projection.
817    """
818    for cte in expression.find_all(exp.CTE):
819        if cte.alias_column_names:
820            new_expressions = []
821            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
822                if isinstance(projection, exp.Alias):
823                    projection.set("alias", _alias)
824                else:
825                    projection = alias(projection, alias=_alias)
826                new_expressions.append(projection)
827            cte.this.set("expressions", new_expressions)
828
829    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
832class Resolver:
833    """
834    Helper for resolving columns.
835
836    This is a class so we can lazily load some things and easily share them across functions.
837    """
838
839    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
840        self.scope = scope
841        self.schema = schema
842        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
843        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
844        self._all_columns: t.Optional[t.Set[str]] = None
845        self._infer_schema = infer_schema
846        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
847
848    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
849        """
850        Get the table for a column name.
851
852        Args:
853            column_name: The column name to find the table for.
854        Returns:
855            The table name if it can be found/inferred.
856        """
857        if self._unambiguous_columns is None:
858            self._unambiguous_columns = self._get_unambiguous_columns(
859                self._get_all_source_columns()
860            )
861
862        table_name = self._unambiguous_columns.get(column_name)
863
864        if not table_name and self._infer_schema:
865            sources_without_schema = tuple(
866                source
867                for source, columns in self._get_all_source_columns().items()
868                if not columns or "*" in columns
869            )
870            if len(sources_without_schema) == 1:
871                table_name = sources_without_schema[0]
872
873        if table_name not in self.scope.selected_sources:
874            return exp.to_identifier(table_name)
875
876        node, _ = self.scope.selected_sources.get(table_name)
877
878        if isinstance(node, exp.Query):
879            while node and node.alias != table_name:
880                node = node.parent
881
882        node_alias = node.args.get("alias")
883        if node_alias:
884            return exp.to_identifier(node_alias.this)
885
886        return exp.to_identifier(table_name)
887
888    @property
889    def all_columns(self) -> t.Set[str]:
890        """All available columns of all sources in this scope"""
891        if self._all_columns is None:
892            self._all_columns = {
893                column for columns in self._get_all_source_columns().values() for column in columns
894            }
895        return self._all_columns
896
897    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
898        """Resolve the source columns for a given source `name`."""
899        cache_key = (name, only_visible)
900        if cache_key not in self._get_source_columns_cache:
901            if name not in self.scope.sources:
902                raise OptimizeError(f"Unknown table: {name}")
903
904            source = self.scope.sources[name]
905
906            if isinstance(source, exp.Table):
907                columns = self.schema.column_names(source, only_visible)
908            elif isinstance(source, Scope) and isinstance(
909                source.expression, (exp.Values, exp.Unnest)
910            ):
911                columns = source.expression.named_selects
912
913                # in bigquery, unnest structs are automatically scoped as tables, so you can
914                # directly select a struct field in a query.
915                # this handles the case where the unnest is statically defined.
916                if self.schema.dialect == "bigquery":
917                    if source.expression.is_type(exp.DataType.Type.STRUCT):
918                        for k in source.expression.type.expressions:  # type: ignore
919                            columns.append(k.name)
920            else:
921                columns = source.expression.named_selects
922
923            node, _ = self.scope.selected_sources.get(name) or (None, None)
924            if isinstance(node, Scope):
925                column_aliases = node.expression.alias_column_names
926            elif isinstance(node, exp.Expression):
927                column_aliases = node.alias_column_names
928            else:
929                column_aliases = []
930
931            if column_aliases:
932                # If the source's columns are aliased, their aliases shadow the corresponding column names.
933                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
934                columns = [
935                    alias or name
936                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
937                ]
938
939            self._get_source_columns_cache[cache_key] = columns
940
941        return self._get_source_columns_cache[cache_key]
942
943    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
944        if self._source_columns is None:
945            self._source_columns = {
946                source_name: self.get_source_columns(source_name)
947                for source_name, source in itertools.chain(
948                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
949                )
950            }
951        return self._source_columns
952
953    def _get_unambiguous_columns(
954        self, source_columns: t.Dict[str, t.Sequence[str]]
955    ) -> t.Mapping[str, str]:
956        """
957        Find all the unambiguous columns in sources.
958
959        Args:
960            source_columns: Mapping of names to source columns.
961
962        Returns:
963            Mapping of column name to source name.
964        """
965        if not source_columns:
966            return {}
967
968        source_columns_pairs = list(source_columns.items())
969
970        first_table, first_columns = source_columns_pairs[0]
971
972        if len(source_columns_pairs) == 1:
973            # Performance optimization - avoid copying first_columns if there is only one table.
974            return SingleValuedMapping(first_columns, first_table)
975
976        unambiguous_columns = {col: first_table for col in first_columns}
977        all_columns = set(unambiguous_columns)
978
979        for table, columns in source_columns_pairs[1:]:
980            unique = set(columns)
981            ambiguous = all_columns.intersection(unique)
982            all_columns.update(columns)
983
984            for column in ambiguous:
985                unambiguous_columns.pop(column, None)
986            for column in unique.difference(ambiguous):
987                unambiguous_columns[column] = table
988
989        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
839    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
840        self.scope = scope
841        self.schema = schema
842        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
843        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
844        self._all_columns: t.Optional[t.Set[str]] = None
845        self._infer_schema = infer_schema
846        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
848    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
849        """
850        Get the table for a column name.
851
852        Args:
853            column_name: The column name to find the table for.
854        Returns:
855            The table name if it can be found/inferred.
856        """
857        if self._unambiguous_columns is None:
858            self._unambiguous_columns = self._get_unambiguous_columns(
859                self._get_all_source_columns()
860            )
861
862        table_name = self._unambiguous_columns.get(column_name)
863
864        if not table_name and self._infer_schema:
865            sources_without_schema = tuple(
866                source
867                for source, columns in self._get_all_source_columns().items()
868                if not columns or "*" in columns
869            )
870            if len(sources_without_schema) == 1:
871                table_name = sources_without_schema[0]
872
873        if table_name not in self.scope.selected_sources:
874            return exp.to_identifier(table_name)
875
876        node, _ = self.scope.selected_sources.get(table_name)
877
878        if isinstance(node, exp.Query):
879            while node and node.alias != table_name:
880                node = node.parent
881
882        node_alias = node.args.get("alias")
883        if node_alias:
884            return exp.to_identifier(node_alias.this)
885
886        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
888    @property
889    def all_columns(self) -> t.Set[str]:
890        """All available columns of all sources in this scope"""
891        if self._all_columns is None:
892            self._all_columns = {
893                column for columns in self._get_all_source_columns().values() for column in columns
894            }
895        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
897    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
898        """Resolve the source columns for a given source `name`."""
899        cache_key = (name, only_visible)
900        if cache_key not in self._get_source_columns_cache:
901            if name not in self.scope.sources:
902                raise OptimizeError(f"Unknown table: {name}")
903
904            source = self.scope.sources[name]
905
906            if isinstance(source, exp.Table):
907                columns = self.schema.column_names(source, only_visible)
908            elif isinstance(source, Scope) and isinstance(
909                source.expression, (exp.Values, exp.Unnest)
910            ):
911                columns = source.expression.named_selects
912
913                # in bigquery, unnest structs are automatically scoped as tables, so you can
914                # directly select a struct field in a query.
915                # this handles the case where the unnest is statically defined.
916                if self.schema.dialect == "bigquery":
917                    if source.expression.is_type(exp.DataType.Type.STRUCT):
918                        for k in source.expression.type.expressions:  # type: ignore
919                            columns.append(k.name)
920            else:
921                columns = source.expression.named_selects
922
923            node, _ = self.scope.selected_sources.get(name) or (None, None)
924            if isinstance(node, Scope):
925                column_aliases = node.expression.alias_column_names
926            elif isinstance(node, exp.Expression):
927                column_aliases = node.alias_column_names
928            else:
929                column_aliases = []
930
931            if column_aliases:
932                # If the source's columns are aliased, their aliases shadow the corresponding column names.
933                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
934                columns = [
935                    alias or name
936                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
937                ]
938
939            self._get_source_columns_cache[cache_key] = columns
940
941        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.