Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import annotate_types
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    infer_schema = schema.empty if infer_schema is None else infer_schema
 53    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 54
 55    for scope in traverse_scope(expression):
 56        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 57        _pop_table_column_aliases(scope.ctes)
 58        _pop_table_column_aliases(scope.derived_tables)
 59        using_column_tables = _expand_using(scope, resolver)
 60
 61        if schema.empty and expand_alias_refs:
 62            _expand_alias_refs(scope, resolver)
 63
 64        _qualify_columns(scope, resolver)
 65
 66        if not schema.empty and expand_alias_refs:
 67            _expand_alias_refs(scope, resolver)
 68
 69        if not isinstance(scope.expression, exp.UDTF):
 70            if expand_stars:
 71                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 72            qualify_outputs(scope)
 73
 74        _expand_group_by(scope)
 75        _expand_order_by(scope, resolver)
 76
 77    return expression
 78
 79
 80def validate_qualify_columns(expression: E) -> E:
 81    """Raise an `OptimizeError` if any columns aren't qualified"""
 82    all_unqualified_columns = []
 83    for scope in traverse_scope(expression):
 84        if isinstance(scope.expression, exp.Select):
 85            unqualified_columns = scope.unqualified_columns
 86
 87            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 88                column = scope.external_columns[0]
 89                for_table = f" for table: '{column.table}'" if column.table else ""
 90                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 91
 92            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 93                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 94                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 95                # this list here to ensure those in the former category will be excluded.
 96                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 97                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 98
 99            all_unqualified_columns.extend(unqualified_columns)
100
101    if all_unqualified_columns:
102        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
103
104    return expression
105
106
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns of an UNPIVOT: first the column on the left-hand side
    of its IN `field` (if there is one), then every column found in its expressions.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    name_columns = [field.this] if has_name_column else []

    value_columns = itertools.chain.from_iterable(
        e.find_all(exp.Column) for e in unpivot.expressions
    )
    return itertools.chain(name_columns, value_columns)
115
116
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Strip the column list from each derived table's alias.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Tables whose parent is a recursive WITH are left untouched.
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        in_recursive_with = isinstance(parent, exp.With) and parent.recursive

        if not in_recursive_with:
            table_alias = derived_table.args.get("alias")
            if table_alias:
                table_alias.args.pop("columns", None)
129
130
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in the scope into an explicit `ON` condition.

    Unqualified references to a USING column elsewhere in the scope are replaced with
    a COALESCE over the tables that take part in the join for that column.

    Args:
        scope: The scope whose joins are rewritten in place.
        resolver: Used to look up the columns of each source.

    Returns:
        Mapping of each USING column name to a dict whose keys are the names of the
        sources joined on that column, in join order (the values are always None).

    Raises:
        OptimizeError: If a USING column cannot be found on both sides of the join
            while the columns of both sides are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not introduced by a join; joined tables are appended as they are processed.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # Column name -> first source (among those seen so far) that provides it.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Fallback left-hand table: the most recently seen source before this join.
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when the columns of both sides are actually known
                # (non-empty, and no "*" entry on the left-hand side).
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # Swap the USING clause for the equivalent conjunction of equality conditions.
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
197
198
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with the aliased expressions.

    Only applies to SELECT scopes. References are expanded in later projections and in
    the WHERE, GROUP BY, HAVING and QUALIFY clauses. The scope's cache is cleared at
    the end, since the AST is modified in place.

    Args:
        scope: The scope to rewrite.
        resolver: Used to look up the table of a column name when `resolve_table` is set.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of the projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        """
        Expand alias references among the columns found under `node`.

        Args:
            node: The expression to process; nothing happens if it is missing.
            resolve_table: Whether to look up a table for columns that have none.
            literal_index: Whether a reference to an alias of a literal is replaced
                with the projection's position.
        """
        if not node:
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when the aliased expression contains an aggregate and the column sits inside
            # an aggregate whose nearest Window/Select ancestor is not a Window, i.e. expanding
            # the alias would nest one aggregate inside another.
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # Aliases of literals are never inlined in this branch; with
                    # `literal_index` they become the projection's position instead.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Inline the aliased expression in parentheses, then drop the
                    # parentheses again if they turn out to be unnecessary.
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    # Each projection is processed before its own alias is registered, so a projection
    # can only have references to aliases defined earlier in the list expanded.
    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()
255
256
def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in the scope's GROUP BY, if it has one."""
    select = scope.expression
    group_by = select.args.get("group")

    if group_by:
        expanded = _expand_positional_references(scope, group_by.expressions)
        group_by.set("expressions", expanded)
        select.set("group", group_by)
265
266
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in the scope's ORDER BY and qualify columns that
    appear inside aggregate functions there.

    Args:
        scope: The scope whose ORDER BY is rewritten in place; nothing happens if it has none.
        resolver: Used to look up the table of unqualified columns under aggregates.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        # Qualify table-less columns nested in aggregate functions of the order key.
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Projected expression -> column referencing that projection's alias (or name).
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            # Integer order keys become the alias of the projection at that position; other
            # keys that equal a projected expression become a reference to that projection.
            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
295
296
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positions (e.g. `GROUP BY 1`) against the scope's projections.

    Args:
        scope: The scope whose projections the positions refer to.
        expressions: The expressions to process; non-integer nodes are passed through.
        alias: If set, a position becomes a column named after the projection's alias.
            Otherwise it becomes a copy of the projected expression, unless that
            expression is a constant or contains an Explode/Unnest, in which case
            the position is kept as is.

    Returns:
        The new list of expressions, in the same order as the input.
    """
    expanded: t.List[exp.Expression] = []

    for node in expressions:
        if not node.is_int:
            expanded.append(node)
            continue

        projection = _select_by_pos(scope, t.cast(exp.Literal, node))

        if alias:
            expanded.append(exp.column(projection.args["alias"].copy()))
            continue

        target = projection.this
        keep_position = isinstance(target, exp.CONSTANTS) or target.find(exp.Explode, exp.Unnest)
        expanded.append(node if keep_position else target.copy())

    return expanded
318
319
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection of `scope` referenced by the 1-based position `node`.

    Args:
        scope: The scope whose projections are indexed.
        node: An integer literal holding the position.

    Returns:
        The projection at that position, which must be an alias.

    Raises:
        OptimizeError: If the position does not refer to an existing projection.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Positions are 1-based. Checking the lower bound explicitly prevents position 0
    # (index -1) from silently resolving to the last projection.
    if position < 1 or position > len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
325
326
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source

    Args:
        scope: The scope whose columns are qualified in place.
        resolver: Used to look up source columns and the table of a column name.

    Raises:
        OptimizeError: If a column names a source of the scope whose known columns
            do not include it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # Only complain when the source's columns are actually known.
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside of the Pivot expression are qualified with the alias of the
                # scope's first pivot. Columns under the Pivot expression skip this branch
                # and are resolved through the resolver instead.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    # Qualify any still table-less column under a pivot whose name is a known source column.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
376
377
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections

    Args:
        scope: The scope whose projections are rewritten in place.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Mapping of USING column names to the sources joined on
            them, as returned by `_expand_using`.
        pseudocolumns: Upper-cased column names that are left out of the expansion.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    # Both mappings are keyed by id() of the table name object taken from `tables` below.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    # USING columns that were already emitted as a COALESCE, so they are not repeated.
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    # Only the first pivot of the scope is considered.
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            # Bare `*`: expand over every selected source.
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star and not isinstance(expression, exp.Dot):
            # Qualified star (`tbl.*`): expand over that table only.
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Without known columns the star cannot be expanded: leave all projections as they are.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    # NOTE: this rebinds `tables`, but the enclosing `for table in tables`
                    # loop keeps iterating over the list it started with.
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
480
481
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the names listed under the expression's "except" arg for each of `tables`.

    Entries of `except_columns` are keyed by `id(table)`; every table shares the same
    set of names. Nothing is recorded if the expression has no "except" arg.
    """
    excluded = expression.args.get("except")

    if excluded:
        names = {e.name for e in excluded}
        except_columns.update({id(table): names for table in tables})
494
495
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record the renames listed under the expression's "replace" arg for each of `tables`.

    Each rename maps the name of the replaced expression to its alias. Entries of
    `replace_columns` are keyed by `id(table)`; every table shares the same mapping.
    Nothing is recorded if the expression has no "replace" arg.
    """
    replacements = expression.args.get("replace")

    if replacements:
        renames = {e.this.name: e.alias for e in replacements}
        replace_columns.update({id(table): renames for table in tables})
508
509
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Accepts either a scope or an expression; for an expression, a scope is built
    first and nothing happens if that does not yield a `Scope`. Projections lacking
    an output name are aliased as `_col_<position>`, and names from the scope's
    outer columns override the alias of the projection at the same position.
    """
    scope = scope_or_expression
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    aliased_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for position, (projection, outer_name) in enumerate(pairs):
        # Outer columns beyond the last projection are ignored.
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                fallback = exp.to_identifier(f"_col_{position}")
                projection.set("alias", exp.TableAlias(this=fallback))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{position}",
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased_selections)
542
543
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote the identifiers of `expression` in place, using the dialect's quoting rules.

    Args:
        expression: The expression to transform (not copied).
        dialect: The dialect whose `quote_identifier` is applied to the nodes.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The transformed expression.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
549
550
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    If a CTE declares fewer alias columns than it has projections, the remaining
    projections are kept unchanged.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []

        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # `zip` stops at the shorter input, so projections without a matching alias
        # column must be carried over explicitly, or they would be dropped from the query.
        new_expressions.extend(projections[len(new_expressions) :])

        cte.this.set("expressions", new_expressions)

    return expression
581
582
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources are used for resolution.
            schema: The schema used to look up the columns of tables.
            infer_schema: Whether a column that cannot be found may be attributed to
                the only source whose columns are unknown.
        """
        self.scope = scope
        self.schema = schema
        # The three caches below are computed lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that carries the alias the source is known by.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above ends with `node` being None if no ancestor carries the alias;
        # fall back to the plain table name instead of failing on `None.args`.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source in the scope.
            only_visible: Forwarded to the schema's column lookup for table sources.

        Returns:
            The source's column names, with any column aliases applied.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                expression = source.expression
                annotate_types(expression)

                if expression.is_type(exp.DataType.Type.STRUCT):
                    for k in expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        # Every column name seen so far, whether ambiguous or not.
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A name seen in an earlier source is ambiguous and stays excluded for good.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column of a sqlglot AST with the source it belongs to.

    The expression is modified in place, scope by scope, and then returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When left as None,
            it defaults to whether the schema is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before qualification;
        # with a schema, they are expanded after it.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        scope_is_udtf = isinstance(scope.expression, exp.UDTF)
        if not scope_is_udtf and expand_stars:
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
        if not scope_is_udtf:
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 81def validate_qualify_columns(expression: E) -> E:
 82    """Raise an `OptimizeError` if any columns aren't qualified"""
 83    all_unqualified_columns = []
 84    for scope in traverse_scope(expression):
 85        if isinstance(scope.expression, exp.Select):
 86            unqualified_columns = scope.unqualified_columns
 87
 88            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 89                column = scope.external_columns[0]
 90                for_table = f" for table: '{column.table}'" if column.table else ""
 91                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 92
 93            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 94                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 95                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 96                # this list here to ensure those in the former category will be excluded.
 97                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 98                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 99
100            all_unqualified_columns.extend(unqualified_columns)
101
102    if all_unqualified_columns:
103        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
104
105    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections that are neither aliases nor stars are wrapped in an alias named after their
    output name, or `_col_<position>` when they have none. Unnamed subquery projections get a
    `_col_<position>` table alias. Outer column names of the scope, when present, override the
    alias of the projection at the same position.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. Nothing
            happens if no scope can be built from the expression.
    """
    scope = scope_or_expression
    if isinstance(scope, exp.Expression):
        scope = build_scope(scope)
        if not isinstance(scope, Scope):
            return

    qualified = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for position, (projection, outer_name) in enumerate(pairs):
        # zip_longest pads with None once the projections run out; extra outer names are ignored.
        if projection is None:
            break

        fallback_name = f"_col_{position}"

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        qualified.append(projection)

    # Only SELECT expressions get their projection list replaced.
    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", qualified)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression to transform; it is transformed without copying (`copy=False`).
        dialect: Dialect whose `quote_identifier` is applied to the nodes.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of transforming `expression`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Aliases are applied positionally. Projections beyond the end of the alias list are kept
    unchanged, and aliases beyond the end of the projection list are ignored.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        new_expressions = []
        for position, projection in enumerate(cte.this.expressions):
            # Only the leading projections have a matching alias. The rest must still be
            # carried over, otherwise reassigning `expressions` below would drop them.
            if position < len(column_aliases):
                _alias = column_aliases[position]
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
            new_expressions.append(projection)
        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.

    Args:
        scope: Scope whose sources are used for resolution.
        schema: Schema consulted for the column names of table sources.
        infer_schema: Whether `get_table` may fall back to the single source that has no
            known columns when a column cannot be resolved unambiguously.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; `None` means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up the tree until reaching the node that carries the source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run past the root without finding the alias; in that case
        # fall back to the plain name instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this resolver's scope.
            only_visible: Forwarded to `Schema.column_names` for table sources.

        Returns:
            The column names of the source, with any column aliases declared on the selected
            source replacing the names at the same positions.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                expression = source.expression
                annotate_types(expression)

                if expression.is_type(exp.DataType.Type.STRUCT):
                    for k in expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # A column seen in any earlier source is ambiguous from now on.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
591    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
592        self.scope = scope
593        self.schema = schema
594        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
595        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
596        self._all_columns: t.Optional[t.Set[str]] = None
597        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
599    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
600        """
601        Get the table for a column name.
602
603        Args:
604            column_name: The column name to find the table for.
605        Returns:
606            The table name if it can be found/inferred.
607        """
608        if self._unambiguous_columns is None:
609            self._unambiguous_columns = self._get_unambiguous_columns(
610                self._get_all_source_columns()
611            )
612
613        table_name = self._unambiguous_columns.get(column_name)
614
615        if not table_name and self._infer_schema:
616            sources_without_schema = tuple(
617                source
618                for source, columns in self._get_all_source_columns().items()
619                if not columns or "*" in columns
620            )
621            if len(sources_without_schema) == 1:
622                table_name = sources_without_schema[0]
623
624        if table_name not in self.scope.selected_sources:
625            return exp.to_identifier(table_name)
626
627        node, _ = self.scope.selected_sources.get(table_name)
628
629        if isinstance(node, exp.Query):
630            while node and node.alias != table_name:
631                node = node.parent
632
633        node_alias = node.args.get("alias")
634        if node_alias:
635            return exp.to_identifier(node_alias.this)
636
637        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
639    @property
640    def all_columns(self) -> t.Set[str]:
641        """All available columns of all sources in this scope"""
642        if self._all_columns is None:
643            self._all_columns = {
644                column for columns in self._get_all_source_columns().values() for column in columns
645            }
646        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
648    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
649        """Resolve the source columns for a given source `name`."""
650        if name not in self.scope.sources:
651            raise OptimizeError(f"Unknown table: {name}")
652
653        source = self.scope.sources[name]
654
655        if isinstance(source, exp.Table):
656            columns = self.schema.column_names(source, only_visible)
657        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
658            columns = source.expression.named_selects
659
660            # in bigquery, unnest structs are automatically scoped as tables, so you can
661            # directly select a struct field in a query.
662            # this handles the case where the unnest is statically defined.
663            if self.schema.dialect == "bigquery":
664                expression = source.expression
665                annotate_types(expression)
666
667                if expression.is_type(exp.DataType.Type.STRUCT):
668                    for k in expression.type.expressions:  # type: ignore
669                        columns.append(k.name)
670        else:
671            columns = source.expression.named_selects
672
673        node, _ = self.scope.selected_sources.get(name) or (None, None)
674        if isinstance(node, Scope):
675            column_aliases = node.expression.alias_column_names
676        elif isinstance(node, exp.Expression):
677            column_aliases = node.alias_column_names
678        else:
679            column_aliases = []
680
681        if column_aliases:
682            # If the source's columns are aliased, their aliases shadow the corresponding column names.
683            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
684            return [
685                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
686            ]
687        return columns

Resolve the source columns for a given source name.