sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    annotator = TypeAnnotator(schema)
 53    infer_schema = schema.empty if infer_schema is None else infer_schema
 54    dialect = Dialect.get_or_raise(schema.dialect)
 55    pseudocolumns = dialect.PSEUDOCOLUMNS
 56
 57    for scope in traverse_scope(expression):
 58        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 59        _pop_table_column_aliases(scope.ctes)
 60        _pop_table_column_aliases(scope.derived_tables)
 61        using_column_tables = _expand_using(scope, resolver)
 62
 63        if schema.empty and expand_alias_refs:
 64            _expand_alias_refs(scope, resolver)
 65
 66        _convert_columns_to_dots(scope, resolver)
 67        _qualify_columns(scope, resolver)
 68
 69        if not schema.empty and expand_alias_refs:
 70            _expand_alias_refs(scope, resolver)
 71
 72        if not isinstance(scope.expression, exp.UDTF):
 73            if expand_stars:
 74                _expand_stars(
 75                    scope,
 76                    resolver,
 77                    using_column_tables,
 78                    pseudocolumns,
 79                    annotator,
 80                )
 81            qualify_outputs(scope)
 82
 83        _expand_group_by(scope, dialect)
 84        _expand_order_by(scope, resolver)
 85
 86        if dialect == "bigquery":
 87            annotator.annotate_scope(scope)
 88
 89    return expression
 90
 91
 92def validate_qualify_columns(expression: E) -> E:
 93    """Raise an `OptimizeError` if any columns aren't qualified"""
 94    all_unqualified_columns = []
 95    for scope in traverse_scope(expression):
 96        if isinstance(scope.expression, exp.Select):
 97            unqualified_columns = scope.unqualified_columns
 98
 99            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
100                column = scope.external_columns[0]
101                for_table = f" for table: '{column.table}'" if column.table else ""
102                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
103
104            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
105                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
106                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
107                # this list here to ensure those in the former category will be excluded.
108                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
109                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
110
111            all_unqualified_columns.extend(unqualified_columns)
112
113    if all_unqualified_columns:
114        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
115
116    return expression
117
118
119def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
120    name_column = []
121    field = unpivot.args.get("field")
122    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
123        name_column.append(field.this)
124
125    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
126    return itertools.chain(name_column, value_columns)
127
128
129def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
130    """
131    Remove table column aliases.
132
133    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
134    """
135    for derived_table in derived_tables:
136        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
137            continue
138        table_alias = derived_table.args.get("alias")
139        if table_alias:
140            table_alias.args.pop("columns", None)
141
142
143def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
144    joins = list(scope.find_all(exp.Join))
145    names = {join.alias_or_name for join in joins}
146    ordered = [key for key in scope.selected_sources if key not in names]
147
148    # Mapping of automatically joined column names to an ordered set of source names (dict).
149    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
150
151    for join in joins:
152        using = join.args.get("using")
153
154        if not using:
155            continue
156
157        join_table = join.alias_or_name
158
159        columns = {}
160
161        for source_name in scope.selected_sources:
162            if source_name in ordered:
163                for column_name in resolver.get_source_columns(source_name):
164                    if column_name not in columns:
165                        columns[column_name] = source_name
166
167        source_table = ordered[-1]
168        ordered.append(join_table)
169        join_columns = resolver.get_source_columns(join_table)
170        conditions = []
171
172        for identifier in using:
173            identifier = identifier.name
174            table = columns.get(identifier)
175
176            if not table or identifier not in join_columns:
177                if (columns and "*" not in columns) and join_columns:
178                    raise OptimizeError(f"Cannot automatically join: {identifier}")
179
180            table = table or source_table
181            conditions.append(
182                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
183            )
184
185            # Set all values in the dict to None, because we only care about the key ordering
186            tables = column_tables.setdefault(identifier, {})
187            if table not in tables:
188                tables[table] = None
189            if join_table not in tables:
190                tables[join_table] = None
191
192        join.args.pop("using")
193        join.set("on", exp.and_(*conditions, copy=False))
194
195    if column_tables:
196        for column in scope.columns:
197            if not column.table and column.name in column_tables:
198                tables = column_tables[column.name]
199                coalesce = [exp.column(column.name, table=table) for table in tables]
200                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
201
202                # Ensure selects keep their output name
203                if isinstance(column.parent, exp.Select):
204                    replacement = alias(replacement, alias=column.name, copy=False)
205
206                scope.replace(column, replacement)
207
208    return column_tables
209
210
211def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
212    expression = scope.expression
213
214    if not isinstance(expression, exp.Select):
215        return
216
217    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
218
219    def replace_columns(
220        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
221    ) -> None:
222        if not node:
223            return
224
225        for column in walk_in_scope(node, prune=lambda node: node.is_star):
226            if not isinstance(column, exp.Column):
227                continue
228
229            table = resolver.get_table(column.name) if resolve_table and not column.table else None
230            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
231            double_agg = (
232                (
233                    alias_expr.find(exp.AggFunc)
234                    and (
235                        column.find_ancestor(exp.AggFunc)
236                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
237                    )
238                )
239                if alias_expr
240                else False
241            )
242
243            if table and (not alias_expr or double_agg):
244                column.set("table", table)
245            elif not column.table and alias_expr and not double_agg:
246                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
247                    if literal_index:
248                        column.replace(exp.Literal.number(i))
249                else:
250                    column = column.replace(exp.paren(alias_expr))
251                    simplified = simplify_parens(column)
252                    if simplified is not column:
253                        column.replace(simplified)
254
255    for i, projection in enumerate(scope.expression.selects):
256        replace_columns(projection)
257
258        if isinstance(projection, exp.Alias):
259            alias_to_expression[projection.alias] = (projection.this, i + 1)
260
261    replace_columns(expression.args.get("where"))
262    replace_columns(expression.args.get("group"), literal_index=True)
263    replace_columns(expression.args.get("having"), resolve_table=True)
264    replace_columns(expression.args.get("qualify"), resolve_table=True)
265
266    scope.clear_cache()
267
268
269def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
270    expression = scope.expression
271    group = expression.args.get("group")
272    if not group:
273        return
274
275    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
276    expression.set("group", group)
277
278
279def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
280    order = scope.expression.args.get("order")
281    if not order:
282        return
283
284    ordereds = order.expressions
285    for ordered, new_expression in zip(
286        ordereds,
287        _expand_positional_references(
288            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
289        ),
290    ):
291        for agg in ordered.find_all(exp.AggFunc):
292            for col in agg.find_all(exp.Column):
293                if not col.table:
294                    col.set("table", resolver.get_table(col.name))
295
296        ordered.set("this", new_expression)
297
298    if scope.expression.args.get("group"):
299        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
300
301        for ordered in ordereds:
302            ordered = ordered.this
303
304            ordered.replace(
305                exp.to_identifier(_select_by_pos(scope, ordered).alias)
306                if ordered.is_int
307                else selects.get(ordered, ordered)
308            )
309
310
311def _expand_positional_references(
312    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
313) -> t.List[exp.Expression]:
314    new_nodes: t.List[exp.Expression] = []
315    ambiguous_projections = None
316
317    for node in expressions:
318        if node.is_int:
319            select = _select_by_pos(scope, t.cast(exp.Literal, node))
320
321            if alias:
322                new_nodes.append(exp.column(select.args["alias"].copy()))
323            else:
324                select = select.this
325
326                if dialect == "bigquery":
327                    if ambiguous_projections is None:
328                        # When a projection name is also a source name and it is referenced in the
329                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
330                        ambiguous_projections = {
331                            s.alias_or_name
332                            for s in scope.expression.selects
333                            if s.alias_or_name in scope.selected_sources
334                        }
335
336                    ambiguous = any(
337                        column.parts[0].name in ambiguous_projections
338                        for column in select.find_all(exp.Column)
339                    )
340                else:
341                    ambiguous = False
342
343                if (
344                    isinstance(select, exp.CONSTANTS)
345                    or select.find(exp.Explode, exp.Unnest)
346                    or ambiguous
347                ):
348                    new_nodes.append(node)
349                else:
350                    new_nodes.append(select.copy())
351        else:
352            new_nodes.append(node)
353
354    return new_nodes
355
356
357def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
358    try:
359        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
360    except IndexError:
361        raise OptimizeError(f"Unknown output column: {node.name}")
362
363
364def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
365    """
366    Converts `Column` instances that represent struct field lookup into chained `Dots`.
367
368    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
369    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
370    """
371    converted = False
372    for column in itertools.chain(scope.columns, scope.stars):
373        if isinstance(column, exp.Dot):
374            continue
375
376        column_table: t.Optional[str | exp.Identifier] = column.table
377        if (
378            column_table
379            and column_table not in scope.sources
380            and (
381                not scope.parent
382                or column_table not in scope.parent.sources
383                or not scope.is_correlated_subquery
384            )
385        ):
386            root, *parts = column.parts
387
388            if root.name in scope.sources:
389                # The struct is already qualified, but we still need to change the AST
390                column_table = root
391                root, *parts = parts
392            else:
393                column_table = resolver.get_table(root.name)
394
395            if column_table:
396                converted = True
397                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
398
399    if converted:
400        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
401        # a `for column in scope.columns` iteration, even though they shouldn't be
402        scope.clear_cache()
403
404
405def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
406    """Disambiguate columns, ensuring each column specifies a source"""
407    for column in scope.columns:
408        column_table = column.table
409        column_name = column.name
410
411        if column_table and column_table in scope.sources:
412            source_columns = resolver.get_source_columns(column_table)
413            if source_columns and column_name not in source_columns and "*" not in source_columns:
414                raise OptimizeError(f"Unknown column: {column_name}")
415
416        if not column_table:
417            if scope.pivots and not column.find_ancestor(exp.Pivot):
418                # If the column is under the Pivot expression, we need to qualify it
419                # using the name of the pivoted source instead of the pivot's alias
420                column.set("table", exp.to_identifier(scope.pivots[0].alias))
421                continue
422
423            # column_table can be an empty string because BigQuery's UNNEST has no table alias
424            column_table = resolver.get_table(column_name)
425            if column_table:
426                column.set("table", column_table)
427
428    for pivot in scope.pivots:
429        for column in pivot.find_all(exp.Column):
430            if not column.table and column.name in resolver.all_columns:
431                column_table = resolver.get_table(column.name)
432                if column_table:
433                    column.set("table", column_table)
434
435
436def _expand_struct_stars(
437    expression: exp.Dot,
438) -> t.List[exp.Alias]:
439    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
440
441    dot_column = t.cast(exp.Column, expression.find(exp.Column))
442    if not dot_column.is_type(exp.DataType.Type.STRUCT):
443        return []
444
445    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
446    dot_column = dot_column.copy()
447    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
448
449    # First part is the table name and last part is the star so they can be dropped
450    dot_parts = expression.parts[1:-1]
451
452    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
453    for part in dot_parts[1:]:
454        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
455            # Unable to expand star unless all fields are named
456            if not isinstance(field.this, exp.Identifier):
457                return []
458
459            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
460                starting_struct = field
461                break
462        else:
463            # There is no matching field in the struct
464            return []
465
466    taken_names = set()
467    new_selections = []
468
469    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
470        name = field.name
471
472        # Ambiguous or anonymous fields can't be expanded
473        if name in taken_names or not isinstance(field.this, exp.Identifier):
474            return []
475
476        taken_names.add(name)
477
478        this = field.this.copy()
479        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
480        new_column = exp.column(
481            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
482        )
483        new_selections.append(alias(new_column, this, copy=False))
484
485    return new_selections
486
487
488def _expand_stars(
489    scope: Scope,
490    resolver: Resolver,
491    using_column_tables: t.Dict[str, t.Any],
492    pseudocolumns: t.Set[str],
493    annotator: TypeAnnotator,
494) -> None:
495    """Expand stars to lists of column selections"""
496
497    new_selections = []
498    except_columns: t.Dict[int, t.Set[str]] = {}
499    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
500    coalesced_columns = set()
501    dialect = resolver.schema.dialect
502
503    pivot_output_columns = None
504    pivot_exclude_columns = None
505
506    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
507    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
508        if pivot.unpivot:
509            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
510
511            field = pivot.args.get("field")
512            if isinstance(field, exp.In):
513                pivot_exclude_columns = {
514                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
515                }
516        else:
517            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
518
519            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
520            if not pivot_output_columns:
521                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
522
523    is_bigquery = dialect == "bigquery"
524    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
525        # Found struct expansion, annotate scope ahead of time
526        annotator.annotate_scope(scope)
527
528    for expression in scope.expression.selects:
529        tables = []
530        if isinstance(expression, exp.Star):
531            tables.extend(scope.selected_sources)
532            _add_except_columns(expression, tables, except_columns)
533            _add_replace_columns(expression, tables, replace_columns)
534        elif expression.is_star:
535            if not isinstance(expression, exp.Dot):
536                tables.append(expression.table)
537                _add_except_columns(expression.this, tables, except_columns)
538                _add_replace_columns(expression.this, tables, replace_columns)
539            elif is_bigquery:
540                struct_fields = _expand_struct_stars(expression)
541                if struct_fields:
542                    new_selections.extend(struct_fields)
543                    continue
544
545        if not tables:
546            new_selections.append(expression)
547            continue
548
549        for table in tables:
550            if table not in scope.sources:
551                raise OptimizeError(f"Unknown table: {table}")
552
553            columns = resolver.get_source_columns(table, only_visible=True)
554            columns = columns or scope.outer_columns
555
556            if pseudocolumns:
557                columns = [name for name in columns if name.upper() not in pseudocolumns]
558
559            if not columns or "*" in columns:
560                return
561
562            table_id = id(table)
563            columns_to_exclude = except_columns.get(table_id) or set()
564
565            if pivot:
566                if pivot_output_columns and pivot_exclude_columns:
567                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
568                    pivot_columns.extend(pivot_output_columns)
569                else:
570                    pivot_columns = pivot.alias_column_names
571
572                if pivot_columns:
573                    new_selections.extend(
574                        alias(exp.column(name, table=pivot.alias), name, copy=False)
575                        for name in pivot_columns
576                        if name not in columns_to_exclude
577                    )
578                    continue
579
580            for name in columns:
581                if name in columns_to_exclude or name in coalesced_columns:
582                    continue
583                if name in using_column_tables and table in using_column_tables[name]:
584                    coalesced_columns.add(name)
585                    tables = using_column_tables[name]
586                    coalesce = [exp.column(name, table=table) for table in tables]
587
588                    new_selections.append(
589                        alias(
590                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
591                            alias=name,
592                            copy=False,
593                        )
594                    )
595                else:
596                    alias_ = replace_columns.get(table_id, {}).get(name, name)
597                    column = exp.column(name, table=table)
598                    new_selections.append(
599                        alias(column, alias_, copy=False) if alias_ != name else column
600                    )
601
602    # Ensures we don't overwrite the initial selections with an empty list
603    if new_selections and isinstance(scope.expression, exp.Select):
604        scope.expression.set("expressions", new_selections)
605
606
607def _add_except_columns(
608    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
609) -> None:
610    except_ = expression.args.get("except")
611
612    if not except_:
613        return
614
615    columns = {e.name for e in except_}
616
617    for table in tables:
618        except_columns[id(table)] = columns
619
620
621def _add_replace_columns(
622    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
623) -> None:
624    replace = expression.args.get("replace")
625
626    if not replace:
627        return
628
629    columns = {e.this.name: e.alias for e in replace}
630
631    for table in tables:
632        replace_columns[id(table)] = columns
633
634
635def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
636    """Ensure all output columns are aliased"""
637    if isinstance(scope_or_expression, exp.Expression):
638        scope = build_scope(scope_or_expression)
639        if not isinstance(scope, Scope):
640            return
641    else:
642        scope = scope_or_expression
643
644    new_selections = []
645    for i, (selection, aliased_column) in enumerate(
646        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
647    ):
648        if selection is None:
649            break
650
651        if isinstance(selection, exp.Subquery):
652            if not selection.output_name:
653                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
654        elif not isinstance(selection, exp.Alias) and not selection.is_star:
655            selection = alias(
656                selection,
657                alias=selection.output_name or f"_col_{i}",
658                copy=False,
659            )
660        if aliased_column:
661            selection.set("alias", exp.to_identifier(aliased_column))
662
663        new_selections.append(selection)
664
665    if isinstance(scope.expression, exp.Select):
666        scope.expression.set("expressions", new_selections)
667
668
669def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
670    """Makes sure all identifiers that need to be quoted are quoted."""
671    return expression.transform(
672        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
673    )  # type: ignore
674
675
676def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
677    """
678    Pushes down the CTE alias columns into the projection.
679
680    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
681
682    Example:
683        >>> import sqlglot
684        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
685        >>> pushdown_cte_alias_columns(expression).sql()
686        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
687
688    Args:
689        expression: Expression to pushdown.
690
691    Returns:
692        The expression with the CTE aliases pushed down into the projection.
693    """
694    for cte in expression.find_all(exp.CTE):
695        if cte.alias_column_names:
696            new_expressions = []
697            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
698                if isinstance(projection, exp.Alias):
699                    projection.set("alias", _alias)
700                else:
701                    projection = alias(projection, alias=_alias)
702                new_expressions.append(projection)
703            cte.this.set("expressions", new_expressions)
704
705    return expression
706
707
708class Resolver:
709    """
710    Helper for resolving columns.
711
712    This is a class so we can lazily load some things and easily share them across functions.
713    """
714
715    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
716        self.scope = scope
717        self.schema = schema
718        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
719        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
720        self._all_columns: t.Optional[t.Set[str]] = None
721        self._infer_schema = infer_schema
722
723    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
724        """
725        Get the table for a column name.
726
727        Args:
728            column_name: The column name to find the table for.
729        Returns:
730            The table name if it can be found/inferred.
731        """
732        if self._unambiguous_columns is None:
733            self._unambiguous_columns = self._get_unambiguous_columns(
734                self._get_all_source_columns()
735            )
736
737        table_name = self._unambiguous_columns.get(column_name)
738
739        if not table_name and self._infer_schema:
740            sources_without_schema = tuple(
741                source
742                for source, columns in self._get_all_source_columns().items()
743                if not columns or "*" in columns
744            )
745            if len(sources_without_schema) == 1:
746                table_name = sources_without_schema[0]
747
748        if table_name not in self.scope.selected_sources:
749            return exp.to_identifier(table_name)
750
751        node, _ = self.scope.selected_sources.get(table_name)
752
753        if isinstance(node, exp.Query):
754            while node and node.alias != table_name:
755                node = node.parent
756
757        node_alias = node.args.get("alias")
758        if node_alias:
759            return exp.to_identifier(node_alias.this)
760
761        return exp.to_identifier(table_name)
762
763    @property
764    def all_columns(self) -> t.Set[str]:
765        """All available columns of all sources in this scope"""
766        if self._all_columns is None:
767            self._all_columns = {
768                column for columns in self._get_all_source_columns().values() for column in columns
769            }
770        return self._all_columns
771
772    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
773        """Resolve the source columns for a given source `name`."""
774        if name not in self.scope.sources:
775            raise OptimizeError(f"Unknown table: {name}")
776
777        source = self.scope.sources[name]
778
779        if isinstance(source, exp.Table):
780            columns = self.schema.column_names(source, only_visible)
781        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
782            columns = source.expression.named_selects
783
784            # in bigquery, unnest structs are automatically scoped as tables, so you can
785            # directly select a struct field in a query.
786            # this handles the case where the unnest is statically defined.
787            if self.schema.dialect == "bigquery":
788                if source.expression.is_type(exp.DataType.Type.STRUCT):
789                    for k in source.expression.type.expressions:  # type: ignore
790                        columns.append(k.name)
791        else:
792            columns = source.expression.named_selects
793
794        node, _ = self.scope.selected_sources.get(name) or (None, None)
795        if isinstance(node, Scope):
796            column_aliases = node.expression.alias_column_names
797        elif isinstance(node, exp.Expression):
798            column_aliases = node.alias_column_names
799        else:
800            column_aliases = []
801
802        if column_aliases:
803            # If the source's columns are aliased, their aliases shadow the corresponding column names.
804            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
805            return [
806                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
807            ]
808        return columns
809
810    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
811        if self._source_columns is None:
812            self._source_columns = {
813                source_name: self.get_source_columns(source_name)
814                for source_name, source in itertools.chain(
815                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
816                )
817            }
818        return self._source_columns
819
820    def _get_unambiguous_columns(
821        self, source_columns: t.Dict[str, t.Sequence[str]]
822    ) -> t.Mapping[str, str]:
823        """
824        Find all the unambiguous columns in sources.
825
826        Args:
827            source_columns: Mapping of names to source columns.
828
829        Returns:
830            Mapping of column name to source name.
831        """
832        if not source_columns:
833            return {}
834
835        source_columns_pairs = list(source_columns.items())
836
837        first_table, first_columns = source_columns_pairs[0]
838
839        if len(source_columns_pairs) == 1:
840            # Performance optimization - avoid copying first_columns if there is only one table.
841            return SingleValuedMapping(first_columns, first_table)
842
843        unambiguous_columns = {col: first_table for col in first_columns}
844        all_columns = set(unambiguous_columns)
845
846        for table, columns in source_columns_pairs[1:]:
847            unique = set(columns)
848            ambiguous = all_columns.intersection(unique)
849            all_columns.update(columns)
850
851            for column in ambiguous:
852                unambiguous_columns.pop(column, None)
853            for column in unique.difference(ambiguous):
854                unambiguous_columns[column] = table
855
856        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'

Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
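Because expand_stars defaults to True, star projections are rewritten into explicit, qualified column lists. A minimal sketch reusing the toy schema from the example above (the output shown follows from that example and is illustrative):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_columns
>>> schema = {"tbl": {"col": "INT"}}  # toy schema, for illustration
>>> qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema).sql()
'SELECT tbl.col AS col FROM tbl'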
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
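A minimal usage sketch: validation is typically run right after qualification and returns the expression unchanged when every column is resolved; unqualified or unresolvable columns raise an OptimizeError instead. The toy schema is illustrative:

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
>>> qualified = qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), {"tbl": {"col": "INT"}})
>>> validate_qualify_columns(qualified).sql()
'SELECT tbl.col AS col FROM tbl'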

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
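qualify_outputs mutates the expression in place (it returns None): unaliased projections receive their output name, or a positional _col_<i> alias when no name can be derived. A small sketch; the rendered SQL is what the rule above suggests:

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_outputs
>>> expression = sqlglot.parse_one("SELECT x + 1, y FROM t")
>>> qualify_outputs(expression)
>>> expression.sql()
'SELECT x + 1 AS _col_0, y AS y FROM t'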

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
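A minimal sketch; with the default identify=True every identifier is quoted using the target dialect's quoting rules (the Snowflake output shown is indicative):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import quote_identifiers
>>> quote_identifiers(sqlglot.parse_one("SELECT a FROM tbl"), dialect="snowflake").sql(dialect="snowflake")
'SELECT "a" FROM "tbl"'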

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
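A Resolver is normally created per scope inside qualify_columns, but it can be built by hand for experimentation. A minimal sketch, assuming a toy single-table schema:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))  # toy schema
>>> resolver.get_table("col").sql()
'tbl'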

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.
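When a column name exists in more than one source, the lookup is ambiguous and the method returns None instead of guessing (assuming no schemaless source is left to fall back on). A sketch with a hypothetical two-table schema:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT id FROM a JOIN b ON a.x = b.x"))
>>> resolver = Resolver(scope, ensure_schema({"a": {"id": "INT", "x": "INT"}, "b": {"id": "INT", "x": "INT"}}))
>>> resolver.get_table("id") is None
True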

all_columns: Set[str]

All available columns of all sources in this scope
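Continuing the two-table get_table sketch above, the property is simply the union of every source's columns (sorted here only to keep the output stable):

>>> sorted(resolver.all_columns)
['id', 'x']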

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source name.
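For an exp.Table source the columns come from the schema; for derived tables and other scopes they come from the source's named selects, with any column aliases on the source shadowing them. Asking for a name that is not a source raises OptimizeError. A minimal sketch against a toy schema:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl AS t"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))  # toy schema
>>> resolver.get_source_columns("t")
['col']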