Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    annotator = TypeAnnotator(schema)
 53    infer_schema = schema.empty if infer_schema is None else infer_schema
 54    dialect = Dialect.get_or_raise(schema.dialect)
 55    pseudocolumns = dialect.PSEUDOCOLUMNS
 56
 57    for scope in traverse_scope(expression):
 58        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 59        _pop_table_column_aliases(scope.ctes)
 60        _pop_table_column_aliases(scope.derived_tables)
 61        using_column_tables = _expand_using(scope, resolver)
 62
 63        if schema.empty and expand_alias_refs:
 64            _expand_alias_refs(scope, resolver)
 65
 66        _qualify_columns(scope, resolver)
 67
 68        if not schema.empty and expand_alias_refs:
 69            _expand_alias_refs(scope, resolver)
 70
 71        if not isinstance(scope.expression, exp.UDTF):
 72            if expand_stars:
 73                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 74            qualify_outputs(scope)
 75
 76        _expand_group_by(scope)
 77        _expand_order_by(scope, resolver)
 78
 79        if dialect == "bigquery":
 80            annotator.annotate_scope(scope)
 81
 82    return expression
 83
 84
 85def validate_qualify_columns(expression: E) -> E:
 86    """Raise an `OptimizeError` if any columns aren't qualified"""
 87    all_unqualified_columns = []
 88    for scope in traverse_scope(expression):
 89        if isinstance(scope.expression, exp.Select):
 90            unqualified_columns = scope.unqualified_columns
 91
 92            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 93                column = scope.external_columns[0]
 94                for_table = f" for table: '{column.table}'" if column.table else ""
 95                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 96
 97            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 98                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 99                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
100                # this list here to ensure those in the former category will be excluded.
101                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
102                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
103
104            all_unqualified_columns.extend(unqualified_columns)
105
106    if all_unqualified_columns:
107        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
108
109    return expression
110
111
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns associated with an UNPIVOT.

    Yields the column on the left-hand side of the `field` IN clause first (when
    there is one), followed by every column found under `unpivot.expressions`.
    """
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_columns = [field.this]
    else:
        name_columns = []

    value_columns = (
        column
        for expression in unpivot.expressions
        for column in expression.find_all(exp.Column)
    )
    return itertools.chain(name_columns, value_columns)
120
121
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Strip the column aliases from the table alias of each derived table.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Derived tables whose parent is a recursive WITH are left alone.
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        is_recursive_cte = isinstance(parent, exp.With) and parent.recursive

        if not is_recursive_cte:
            table_alias = derived_table.args.get("alias")
            if table_alias:
                table_alias.args.pop("columns", None)
134
135
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` of the scope into an explicit `ON` condition.

    Unqualified references to a USING column are afterwards replaced by a COALESCE
    over all the tables that were joined on that column.

    Args:
        scope: The scope whose joins are rewritten (in place).
        resolver: Resolver used to look up the columns of each source.

    Returns:
        Mapping of each USING column name to an ordered set (a dict with `None` values)
        of the source names that were joined on it.

    Raises:
        OptimizeError: If a USING column can't be matched to a source, or if a USING
            join has no source on its left-hand side.
    """
    joins = list(scope.find_all(exp.Join))
    join_names = {join.alias_or_name for join in joins}

    # Sources that are not introduced by a join come first; joined tables are appended
    # to this list in the order the joins are visited.
    ordered = [key for key in scope.selected_sources if key not in join_names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        join_table = join.alias_or_name
        using = join.args.get("using")

        if not using:
            # A table joined with ON (or without a condition) is still part of the left-hand
            # side of any later join, so it must be visible to subsequent USING joins.
            ordered.append(join_table)
            continue

        if not ordered:
            raise OptimizeError(
                f"Cannot automatically join '{join_table}': no source table precedes it"
            )

        # Column name -> first preceding source (in selection order) that exposes it
        columns: t.Dict[str, str] = {}
        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    columns.setdefault(column_name, source_name)

        # Fallback left-hand table, used when the column's source can't be determined
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            column_name = identifier.name
            table = columns.get(column_name)

            if not table or column_name not in join_columns:
                # Only fail when the columns of both sides are actually known
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {column_name}")

            table = table or source_table
            conditions.append(
                exp.column(column_name, table=table).eq(
                    exp.column(column_name, table=join_table)
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(column_name, {})
            tables.setdefault(table, None)
            tables.setdefault(join_table, None)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
202
203
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT scope.

    Projections are processed in order, so an alias can only be referenced by the
    projections that follow it. The WHERE, GROUP BY, HAVING and QUALIFY clauses are
    processed afterwards. Non-SELECT scopes are ignored.
    """
    select = scope.expression

    if not isinstance(select, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of its projection)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column in walk_in_scope(node, prune=lambda n: n.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = None
            if resolve_table and not column.table:
                table = resolver.get_table(column.name)

            alias_expr, position = alias_to_expression.get(column.name, (None, 1))

            # Expanding an aggregate alias inside another aggregate (that isn't a window)
            # would nest aggregations, so such references are not expanded.
            if alias_expr:
                double_agg = (
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )
            else:
                double_agg = False

            if table and (not alias_expr or double_agg):
                column.set("table", table)
                continue

            if column.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                # Literal aliases are referenced by position where requested, otherwise left alone
                if literal_index:
                    column.replace(exp.Literal.number(position))
                continue

            wrapped = column.replace(exp.paren(alias_expr))
            simplified = simplify_parens(wrapped)
            if simplified is not wrapped:
                wrapped.replace(simplified)

    for position, projection in enumerate(select.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    replace_columns(select.args.get("where"))
    replace_columns(select.args.get("group"), literal_index=True)
    replace_columns(select.args.get("having"), resolve_table=True)
    replace_columns(select.args.get("qualify"), resolve_table=True)

    scope.clear_cache()
260
261
def _expand_group_by(scope: Scope) -> None:
    """Expand the positional references (e.g. `GROUP BY 1`) of the scope's GROUP BY, if any."""
    select = scope.expression
    group = select.args.get("group")

    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        select.set("group", group)
270
271
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand the scope's ORDER BY clause, if any.

    Positional references are turned into references to the projection's alias and
    unqualified columns under aggregate functions are qualified. When the query also
    has a GROUP BY, ordered expressions matching a projection are replaced by a
    reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds), alias=True)

    for ordered, new_expression in zip(ordereds, expanded):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not scope.expression.args.get("group"):
        return

    # Projected expression -> column referencing that projection's output name
    selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

    for ordered in ordereds:
        key = ordered.this

        if key.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, key).alias)
        else:
            replacement = selects.get(key, key)

        key.replace(replacement)
300
301
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positions (e.g. the `1` in `GROUP BY 1`) against the scope's projections.

    Args:
        scope: Scope whose projections the positions refer to.
        expressions: The expressions to inspect; non-integer ones are passed through as is.
        alias: If True, a position becomes a column named after the projection's alias.
            Otherwise it becomes a copy of the projected expression, unless that expression
            is a constant or contains an EXPLODE / UNNEST, in which case the position is kept.

    Returns:
        The list of (possibly replaced) expressions, in the original order.
    """
    new_nodes: t.List[exp.Expression] = []

    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        select = _select_by_pos(scope, t.cast(exp.Literal, node))

        if alias:
            new_nodes.append(exp.column(select.args["alias"].copy()))
            continue

        projected = select.this
        keep_position = isinstance(projected, exp.CONSTANTS) or projected.find(
            exp.Explode, exp.Unnest
        )
        new_nodes.append(node if keep_position else projected.copy())

    return new_nodes
323
324
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection of `scope` that the 1-based position `node` refers to.

    Args:
        scope: Scope whose projections are indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The projection at that position, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position doesn't fall within `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Positions are 1-based. Without the lower bound, position 0 would be turned into
    # index -1 and silently resolve to the *last* projection.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
330
331
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate the columns of a scope, so that each one specifies its source.

    Raises:
        OptimizeError: If a column is qualified with a known source that doesn't have it.
    """
    for column in scope.columns:
        table_name = column.table
        name = column.name

        if table_name and table_name in scope.sources:
            known_columns = resolver.get_source_columns(table_name)
            if known_columns and name not in known_columns and "*" not in known_columns:
                raise OptimizeError(f"Unknown column: {name}")

        if not table_name:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside of the PIVOT expression are qualified with the pivot's
                # alias; those inside of it are handled separately at the end.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(name)

            # resolved can be a '' because bigquery unnest has no table alias
            if resolved:
                column.set("table", resolved)
            continue

        if table_name in scope.sources:
            continue

        if scope.parent and table_name in scope.parent.sources and scope.is_correlated_subquery:
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            struct_table = root
            root, *parts = parts
        else:
            struct_table = resolver.get_table(root.name)

        if struct_table:
            column.replace(exp.Dot.build([exp.column(root, table=struct_table), *parts]))

    # Columns under a PIVOT are qualified with the source that provides them
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue

            pivoted_table = resolver.get_table(column.name)
            if pivoted_table:
                column.set("table", pivoted_table)
381
382
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand the star projections of a scope into explicit column selections.

    Args:
        scope: Scope whose projections are expanded (in place).
        resolver: Resolver used to look up the columns of each source.
        using_column_tables: Mapping of USING column names to the sources joined on them;
            such columns are emitted once, as a COALESCE over those sources.
        pseudocolumns: Upper-cased column names that must not be part of an expansion.

    Raises:
        OptimizeError: If a star refers to a table that isn't a source of the scope.
    """
    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = {c.output_name for c in pivot.find_all(exp.Column)}

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] or [
                c.alias_or_name for c in pivot.expressions
            ]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            # Bare star: expands over every selected source
            tables = list(scope.selected_sources)
            star = expression
        elif expression.is_star and not isinstance(expression, exp.Dot):
            # Qualified star (table.*): expands over that table only
            tables = [expression.table]
            star = expression.this
        else:
            new_selections.append(expression)
            continue

        _add_except_columns(star, tables, except_columns)
        _add_replace_columns(star, tables, replace_columns)

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True) or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # The whole expansion is abandoned (the projections are left as they are)
            # as soon as the columns of one source are unknown.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    for name in pivot_columns:
                        if name not in columns_to_exclude:
                            new_selections.append(
                                alias(exp.column(name, table=pivot.alias), name, copy=False)
                            )
                    continue

            renamed_columns = replace_columns.get(table_id, {})

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue

                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    joined_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=joined) for joined in joined_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
485
486
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the names listed under the "except" arg of `expression` for each of `tables`.

    `except_columns` is updated in place and keyed by `id(table)`; all tables share the
    same set of names. Nothing is recorded when there is no "except" arg.
    """
    except_ = expression.args.get("except")

    if except_:
        excluded_names = {e.name for e in except_}
        except_columns.update((id(table), excluded_names) for table in tables)
499
500
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record the entries listed under the "replace" arg of `expression` for each of `tables`.

    `replace_columns` is updated in place and keyed by `id(table)`; all tables share the
    same mapping, which goes from `e.this.name` to `e.alias` for every "replace" entry `e`.
    Nothing is recorded when there is no "replace" arg.
    """
    replace = expression.args.get("replace")

    if replace:
        replacements = {e.this.name: e.alias for e in replace}
        replace_columns.update((id(table), replacements) for table in tables)
513
514
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Make sure that every output column is aliased.

    Args:
        scope_or_expression: Either a scope, or an expression a scope is built from.
            Nothing happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for i, (selection, aliased_column) in enumerate(pairs):
        # There are more outer columns than projections; nothing left to alias
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )

        # An outer column name takes precedence over the projection's own alias
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
547
548
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Make sure all identifiers that need to be quoted are quoted.

    The expression is transformed in place using the dialect's `quote_identifier`.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)  # type: ignore
554
555
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []

        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # `zip` stops at the shorter input: when the CTE declares fewer alias columns than
        # it has projections, the remaining projections are kept as they are instead of
        # being silently dropped from the query.
        new_expressions.extend(projections[len(new_expressions) :])

        cte.this.set("expressions", new_expressions)

    return expression
586
587
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources the columns are resolved against.
            schema: The schema used to look up the columns of tables.
            infer_schema: Whether an otherwise unknown column may be attributed to the
                single source whose columns are unknown.
        """
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; `None` means "not computed yet"
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, the column must belong to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that carries the alias the source is selected under
            while node and node.alias != table_name:
                node = node.parent

        # The walk above ends with `node` being None when no ancestor carries the alias;
        # fall back to the plain table name instead of dereferencing None.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source of the scope.
            only_visible: Passed through to the schema when the source is a table.

        Returns:
            The column names of the source, with the source's column aliases (if any)
            taking the place of the corresponding names.

        Raises:
            OptimizeError: If `name` isn't a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for (column_name, column_alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns already seen in an earlier source are ambiguous from now on
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Each scope produced by `traverse_scope` is processed in turn, and the
    `expression` object that was passed in is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When left as None,
            `schema.empty` is used.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded before the columns
        # are qualified; otherwise they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

        # Scopes are type-annotated only for the BigQuery dialect.
        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 86def validate_qualify_columns(expression: E) -> E:
 87    """Raise an `OptimizeError` if any columns aren't qualified"""
 88    all_unqualified_columns = []
 89    for scope in traverse_scope(expression):
 90        if isinstance(scope.expression, exp.Select):
 91            unqualified_columns = scope.unqualified_columns
 92
 93            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 94                column = scope.external_columns[0]
 95                for_table = f" for table: '{column.table}'" if column.table else ""
 96                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 97
 98            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 99                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
100                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
101                # this list here to ensure those in the former category will be excluded.
102                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
103                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
104
105            all_unqualified_columns.extend(unqualified_columns)
106
107    if all_unqualified_columns:
108        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
109
110    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built
            with `build_scope`. Nothing is done if no scope can be built.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    aliased = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    for position, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # There are more outer column names than projections.
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set(
                    "alias", exp.TableAlias(this=exp.to_identifier(f"_col_{position}"))
                )
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{position}",
                copy=False,
            )

        # An outer column name, when present, replaces the projection's alias.
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression to transform; `copy=False` is passed to `transform`.
        dialect: Dialect whose `quote_identifier` is applied.
        identify: Forwarded to `quote_identifier` as the `identify` argument.

    Returns:
        The result of the transform.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        renamed = []
        # Alias names and projections are paired positionally; `zip` stops at the
        # shorter of the two.
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", column_name)
            else:
                projection = alias(projection, alias=column_name)
            renamed.append(projection)
        cte.this.set("expressions", renamed)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    It is a class so that the per-scope lookups can be computed lazily, cached,
    and shared across the functions that use them.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        unambiguous = self._unambiguous_columns
        if unambiguous is None:
            unambiguous = self._get_unambiguous_columns(self._get_all_source_columns())
            self._unambiguous_columns = unambiguous

        table_name = unambiguous.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the only source whose columns are unknown (none listed,
            # or a star among them), if there is exactly one such source.
            schemaless = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(schemaless) == 1:
                table_name = schemaless[0]

        selected = self.scope.selected_sources
        if table_name not in selected:
            return exp.to_identifier(table_name)

        node, _ = selected[table_name]

        if isinstance(node, exp.Query):
            # Walk up the tree to the ancestor whose alias is the source name.
            while node and node.alias != table_name:
                node = node.parent

        alias_node = node.args.get("alias")
        return exp.to_identifier(alias_node.this if alias_node else table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = set(
                itertools.chain.from_iterable(self._get_all_source_columns().values())
            )
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this scope.
            only_visible: Passed to `schema.column_names` when the source is a table.

        Returns:
            The source's column names, with column aliases applied when the
            selected source declares any.

        Raises:
            OptimizeError: If `name` is not one of the scope's sources.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery" and source.expression.is_type(
                exp.DataType.Type.STRUCT
            ):
                columns.extend(
                    field.name for field in source.expression.type.expressions  # type: ignore
                )
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if not column_aliases:
            return columns

        # If the source's columns are aliased, their aliases shadow the corresponding
        # column names. This can be expensive if there are lots of columns, so it is
        # only done when column aliases exist.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return the cached mapping of source name to columns, building it on first use."""
        cached = self._source_columns
        if cached is not None:
            return cached

        scope = self.scope
        cached = {}
        # Only the source names are needed; the columns are resolved by name.
        for source_name in itertools.chain(scope.selected_sources, scope.lateral_sources):
            cached[source_name] = self.get_source_columns(source_name)

        self._source_columns = cached
        return cached

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Map each column name that occurs in exactly one source to that source.

        Args:
            source_columns: Mapping of source names to their column names.

        Returns:
            Mapping of column name to source name. Column names that appear in
            more than one source are left out.
        """
        if not source_columns:
            return {}

        remaining = iter(source_columns.items())
        first_table, first_columns = next(remaining)

        if len(source_columns) == 1:
            # Performance optimization - with a single table there is nothing to
            # disambiguate, so avoid copying its columns.
            return SingleValuedMapping(first_columns, first_table)

        owners = dict.fromkeys(first_columns, first_table)
        seen = set(owners)

        for table, columns in remaining:
            current = set(columns)
            clashes = seen & current
            seen |= current

            # A name seen in an earlier source is ambiguous from now on.
            for column in clashes:
                owners.pop(column, None)
            for column in current - clashes:
                owners[column] = table

        return owners

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
596    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
597        self.scope = scope
598        self.schema = schema
599        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
600        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
601        self._all_columns: t.Optional[t.Set[str]] = None
602        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
604    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
605        """
606        Get the table for a column name.
607
608        Args:
609            column_name: The column name to find the table for.
610        Returns:
611            The table name if it can be found/inferred.
612        """
613        if self._unambiguous_columns is None:
614            self._unambiguous_columns = self._get_unambiguous_columns(
615                self._get_all_source_columns()
616            )
617
618        table_name = self._unambiguous_columns.get(column_name)
619
620        if not table_name and self._infer_schema:
621            sources_without_schema = tuple(
622                source
623                for source, columns in self._get_all_source_columns().items()
624                if not columns or "*" in columns
625            )
626            if len(sources_without_schema) == 1:
627                table_name = sources_without_schema[0]
628
629        if table_name not in self.scope.selected_sources:
630            return exp.to_identifier(table_name)
631
632        node, _ = self.scope.selected_sources.get(table_name)
633
634        if isinstance(node, exp.Query):
635            while node and node.alias != table_name:
636                node = node.parent
637
638        node_alias = node.args.get("alias")
639        if node_alias:
640            return exp.to_identifier(node_alias.this)
641
642        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
644    @property
645    def all_columns(self) -> t.Set[str]:
646        """All available columns of all sources in this scope"""
647        if self._all_columns is None:
648            self._all_columns = {
649                column for columns in self._get_all_source_columns().values() for column in columns
650            }
651        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
653    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
654        """Resolve the source columns for a given source `name`."""
655        if name not in self.scope.sources:
656            raise OptimizeError(f"Unknown table: {name}")
657
658        source = self.scope.sources[name]
659
660        if isinstance(source, exp.Table):
661            columns = self.schema.column_names(source, only_visible)
662        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
663            columns = source.expression.named_selects
664
665            # in bigquery, unnest structs are automatically scoped as tables, so you can
666            # directly select a struct field in a query.
667            # this handles the case where the unnest is statically defined.
668            if self.schema.dialect == "bigquery":
669                if source.expression.is_type(exp.DataType.Type.STRUCT):
670                    for k in source.expression.type.expressions:  # type: ignore
671                        columns.append(k.name)
672        else:
673            columns = source.expression.named_selects
674
675        node, _ = self.scope.selected_sources.get(name) or (None, None)
676        if isinstance(node, Scope):
677            column_aliases = node.expression.alias_column_names
678        elif isinstance(node, exp.Expression):
679            column_aliases = node.alias_column_names
680        else:
681            column_aliases = []
682
683        if column_aliases:
684            # If the source's columns are aliased, their aliases shadow the corresponding column names.
685            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
686            return [
687                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
688            ]
689        return columns

Resolve the source columns for a given source name.