Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    expand_stars: bool = True,
 21    infer_schema: t.Optional[bool] = None,
 22) -> exp.Expression:
 23    """
 24    Rewrite sqlglot AST to have fully qualified columns.
 25
 26    Example:
 27        >>> import sqlglot
 28        >>> schema = {"tbl": {"col": "INT"}}
 29        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 30        >>> qualify_columns(expression, schema).sql()
 31        'SELECT tbl.col AS col FROM tbl'
 32
 33    Args:
 34        expression: Expression to qualify.
 35        schema: Database schema.
 36        expand_alias_refs: Whether or not to expand references to aliases.
 37        expand_stars: Whether or not to expand star queries. This is a necessary step
 38            for most of the optimizer's rules to work; do not set to False unless you
 39            know what you're doing!
 40        infer_schema: Whether or not to infer the schema if missing.
 41
 42    Returns:
 43        The qualified expression.
 44
 45    Notes:
 46        - Currently only handles a single PIVOT or UNPIVOT operator
 47    """
 48    schema = ensure_schema(schema)
 49    infer_schema = schema.empty if infer_schema is None else infer_schema
 50    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 51
 52    for scope in traverse_scope(expression):
 53        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 54        _pop_table_column_aliases(scope.ctes)
 55        _pop_table_column_aliases(scope.derived_tables)
 56        using_column_tables = _expand_using(scope, resolver)
 57
 58        if schema.empty and expand_alias_refs:
 59            _expand_alias_refs(scope, resolver)
 60
 61        _qualify_columns(scope, resolver)
 62
 63        if not schema.empty and expand_alias_refs:
 64            _expand_alias_refs(scope, resolver)
 65
 66        if not isinstance(scope.expression, exp.UDTF):
 67            if expand_stars:
 68                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 69            qualify_outputs(scope)
 70
 71        _expand_group_by(scope)
 72        _expand_order_by(scope, resolver)
 73
 74    return expression
 75
 76
 77def validate_qualify_columns(expression: E) -> E:
 78    """Raise an `OptimizeError` if any columns aren't qualified"""
 79    all_unqualified_columns = []
 80    for scope in traverse_scope(expression):
 81        if isinstance(scope.expression, exp.Select):
 82            unqualified_columns = scope.unqualified_columns
 83
 84            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 85                column = scope.external_columns[0]
 86                for_table = f" for table: '{column.table}'" if column.table else ""
 87                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 88
 89            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 90                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 91                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 92                # this list here to ensure those in the former category will be excluded.
 93                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 94                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 95
 96            all_unqualified_columns.extend(unqualified_columns)
 97
 98    if all_unqualified_columns:
 99        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
100
101    return expression
102
103
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns an UNPIVOT introduces.

    First comes the name column (the column of the `field` IN clause, if it is a plain
    column). Then come every column found in the UNPIVOT's value expressions.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    leading = [field.this] if has_name_column else []

    value_expressions = unpivot.expressions
    value_columns = itertools.chain.from_iterable(
        value_expression.find_all(exp.Column) for value_expression in value_expressions
    )
    return itertools.chain(leading, value_columns)
112
113
114def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
115    """
116    Remove table column aliases.
117
118    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
119    """
120    for derived_table in derived_tables:
121        table_alias = derived_table.args.get("alias")
122        if table_alias:
123            table_alias.args.pop("columns", None)
124
125
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` of the scope into `JOIN ... ON` equalities.

    For each USING column, the left-hand side is the first already-joined source (in
    `scope.selected_sources` order) that exposes the column. When none does, the most
    recently joined source is used. Afterwards, unqualified references to a USING column
    are replaced by a COALESCE over all sources joined on it. When such a reference is
    directly a SELECT projection, the COALESCE is aliased with the column name.

    Raises:
        OptimizeError: If a USING column is missing on one side while both sides' columns
            are known (and the left side's columns do not include "*").

    Returns:
        Mapping of USING column name to an insertion-ordered dict of the source names
        joined on it (the dict values are always None; only key order matters).
    """
    joins = list(scope.find_all(exp.Join))
    join_names = {join.alias_or_name for join in joins}
    # Sources already on the left-hand side of the join chain, in order
    ordered = [name for name in scope.selected_sources if name not in join_names]

    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")
        if not using:
            continue

        join_table = join.alias_or_name

        # First left-hand source providing each column name
        left_columns: t.Dict[str, str] = {}
        for source_name in scope.selected_sources:
            if source_name not in ordered:
                continue
            for column_name in resolver.get_source_columns(source_name):
                left_columns.setdefault(column_name, source_name)

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            name = identifier.name
            table = left_columns.get(name)

            missing = not table or name not in join_columns
            both_sides_known = left_columns and "*" not in left_columns and join_columns
            if missing and both_sides_known:
                raise OptimizeError(f"Cannot automatically join: {name}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(name, table=table),
                        expression=exp.column(name, table=join_table),
                    )
                )
            )

            participants = column_tables.setdefault(name, {})
            participants.setdefault(table, None)
            participants.setdefault(join_table, None)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if not column_tables:
        return column_tables

    for column in scope.columns:
        if column.table or column.name not in column_tables:
            continue

        sources = column_tables[column.name]
        coalesce = [exp.column(column.name, table=source) for source in sources]
        replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

        # Ensure selects keep their output name
        if isinstance(column.parent, exp.Select):
            replacement = alias(replacement, alias=column.name, copy=False)

        scope.replace(column, replacement)

    return column_tables
197
198
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to SELECT aliases in projections, WHERE, GROUP BY, HAVING and QUALIFY.

    Projections are visited left to right, so a projection only sees aliases defined by
    earlier projections. An unqualified column naming an alias is replaced by the
    parenthesized aliased expression, with redundant parentheses removed via
    `simplify_parens`. The exceptions are:

    - if the alias expression contains an aggregate and the column already sits inside
      an aggregate, the alias is not inlined;
    - a reference to a literal-valued alias becomes that projection's 1-based position
      in GROUP BY, and is left unchanged in HAVING / QUALIFY;
    - in HAVING / QUALIFY, an unqualified column that is not inlined is qualified with the
      table `resolver` finds for it.

    Only SELECT scopes are handled; the scope's caches are cleared afterwards.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return

    # alias name -> (aliased expression, 1-based position of its projection)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, position = alias_to_expression.get(column.name, (None, 1))
            # Inlining an aggregate alias inside an aggregate would nest aggregates
            double_agg = (
                bool(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
                continue

            if column.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                if literal_index:
                    column.replace(exp.Literal.number(position))
                continue

            replaced = column.replace(exp.paren(alias_expr))
            simplified = simplify_parens(replaced)
            if simplified is not replaced:
                replaced.replace(simplified)

    for position, projection in enumerate(select.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    replace_columns(select.args.get("where"))
    replace_columns(select.args.get("group"), literal_index=True)
    replace_columns(select.args.get("having"), resolve_table=True)
    replace_columns(select.args.get("qualify"), resolve_table=True)

    scope.clear_cache()
249
250
def _expand_group_by(scope: Scope) -> None:
    """
    Replace positional GROUP BY references (e.g. `GROUP BY 1`) with copies of the
    projections they point to. References to literal projections keep their position.
    """
    select = scope.expression
    group = select.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        select.set("group", group)
259
260
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the ORDER BY clause of the scope's expression.

    - Positional references (e.g. `ORDER BY 1`) become columns named after the
      referenced projection's alias.
    - Unqualified columns inside aggregate functions are qualified via `resolver`.
    - When the query also has a GROUP BY, ORDER BY expressions equal to a projection's
      expression are replaced by a column referencing that projection's output name.
    """
    select = scope.expression
    order = select.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds), alias=True)

    for ordered, new_expression in zip(ordereds, expanded):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not select.args.get("group"):
        return

    output_columns = {s.this: exp.column(s.alias_or_name) for s in select.selects}

    for ordered in ordereds:
        key = ordered.this
        if key.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, key).alias)
        else:
            replacement = output_columns.get(key, key)
        key.replace(replacement)
289
290
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer-literal positional references against the scope's projections.

    Args:
        scope: Scope whose SELECT projections are referenced.
        expressions: Expressions to expand; non-integer ones are returned as is.
        alias: If True, a reference becomes a column named after the projection's alias.
            Otherwise it becomes a copy of the aliased expression, except that literal
            projections leave the positional reference untouched.

    Returns:
        A new list with one expression per input expression.
    """

    def expand(node: exp.Expression) -> exp.Expression:
        if not node.is_int:
            return node

        projection = _select_by_pos(scope, t.cast(exp.Literal, node))
        if alias:
            return exp.column(projection.args["alias"].copy())

        target = projection.this
        return node if isinstance(target, exp.Literal) else target.copy()

    return [expand(node) for node in expressions]
312
313
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. `ORDER BY 2`).

    Args:
        scope: Scope whose SELECT projections are indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The referenced projection, checked with `assert_is(exp.Alias)`.

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`. The range is checked
            explicitly because a position of 0 would otherwise become index -1 and
            silently select the last projection.
    """
    selects = scope.expression.selects
    position = int(node.this)
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")
    return selects[position - 1].assert_is(exp.Alias)
319
320
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    - Unqualified columns outside a PIVOT get the pivot's alias when the scope is
      pivoted. Otherwise they get the table `resolver` finds for them, if any.
    - Qualified columns whose table is a source of the scope are validated against that
      source's known columns.
    - Qualified columns whose "table" is not a source are treated as struct field
      accesses and rebuilt as `Dot(<table>.<root>, field, ...)`. The exception is a
      correlated subquery referencing a source of its parent scope.
    - Finally, unqualified columns inside PIVOT operators are qualified with the source
      they come from.

    Raises:
        OptimizeError: If a qualified column is not among its source's known columns.
    """
    for column in scope.columns:
        table_name = column.table
        name = column.name

        if not table_name:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT refer to its output, so they take the pivot's
                # alias. Columns under the PIVOT are handled by the final loop below.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(name)
            # The resolved table can be '' (e.g. a BigQuery UNNEST has no table alias)
            if resolved:
                column.set("table", resolved)
            continue

        if table_name in scope.sources:
            known_columns = resolver.get_source_columns(table_name)
            if known_columns and name not in known_columns and "*" not in known_columns:
                raise OptimizeError(f"Unknown column: {name}")
            continue

        parent = scope.parent
        if parent and table_name in parent.sources and scope.is_correlated_subquery:
            # Correlated reference to an outer source; nothing to do here
            continue

        # Structs are used like tables (e.g. "struct"."field"), so they are qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # Struct is already qualified, but the AST representation still changes
            owner = root
            root, *parts = parts
        else:
            owner = resolver.get_table(root.name)

        if owner:
            column.replace(exp.Dot.build([exp.column(root, table=owner), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue
            owner = resolver.get_table(column.name)
            if owner:
                column.set("table", owner)
370
371
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Both bare stars (`*`, expanded over every selected source) and table stars (`t.*`)
    are rewritten into explicit column projections. The expansion:
    - honours the star's own EXCEPT / REPLACE modifiers;
    - coalesces columns that took part in a JOIN ... USING;
    - omits pseudocolumns;
    - for a single PIVOT / UNPIVOT, projects its output columns.

    If any source referenced by a star has unknown columns (none, or a "*" entry), the
    function returns early and the projections are left untouched.

    Args:
        scope: Scope whose SELECT projections are rewritten in place.
        resolver: Used to look up each source's visible columns.
        using_column_tables: Output of `_expand_using`. It maps a USING column name to
            the ordered source names whose values are coalesced.
        pseudocolumns: Names compared against upper-cased column names; matches are
            omitted from the expansion.

    Raises:
        OptimizeError: If a star references a table that is not a source of the scope.
    """
    new_selections: t.List[exp.Expression] = []
    coalesced_columns: t.Set[str] = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot):
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            star_tables = list(scope.selected_sources)
            star = expression
        elif expression.is_star:
            star_tables = [expression.table]
            star = expression.this
        else:
            new_selections.append(expression)
            continue

        # EXCEPT / REPLACE modifiers belong to the star they are attached to, so they are
        # collected per star. The helpers only write entries when a modifier is present.
        # With shared dicts, `* EXCEPT (a), *` would therefore drop `a` from the second
        # star as well, because it inherited the entry the first star recorded for the
        # same table.
        except_columns: t.Dict[int, t.Set[str]] = {}
        replace_columns: t.Dict[int, t.Dict[str, str]] = {}
        _add_except_columns(star, star_tables, except_columns)
        _add_replace_columns(star, star_tables, replace_columns)

        for table in star_tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown columns: give up and keep every projection as it was
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot and pivot_output_columns and pivot_exclude_columns:
                implicit_columns = [c for c in columns if c not in pivot_exclude_columns]
                if implicit_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted once, as COALESCE over all joined sources
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce = [
                        exp.column(name, table=source) for source in using_column_tables[name]
                    ]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
470
471
472def _add_except_columns(
473    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
474) -> None:
475    except_ = expression.args.get("except")
476
477    if not except_:
478        return
479
480    columns = {e.name for e in except_}
481
482    for table in tables:
483        except_columns[id(table)] = columns
484
485
486def _add_replace_columns(
487    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
488) -> None:
489    replace = expression.args.get("replace")
490
491    if not replace:
492        return
493
494    columns = {e.this.name: e.alias for e in replace}
495
496    for table in tables:
497        replace_columns[id(table)] = columns
498
499
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A Scope, or an Expression for which a scope is built with
            `build_scope`. Nothing happens if that does not produce a Scope.

    The scope's projections are rewritten in place:
    - a Subquery without an output name gets the table alias `_col_<i>`;
    - any other projection that is neither an alias nor a star is aliased with its
      output name, or `_col_<i>` when it has none;
    - when the scope has an outer column list, its i-th name overrides the alias of the
      i-th projection. Outer names beyond the last projection are ignored.
    """
    if isinstance(scope_or_expression, exp.Expression):
        built = build_scope(scope_or_expression)
        if not isinstance(built, Scope):
            return
        scope = built
    else:
        scope = scope_or_expression

    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    aliased_selections = []

    for i, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(projection, alias=projection.output_name or f"_col_{i}")

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(projection)

    scope.expression.set("expressions", aliased_selections)
530
531
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The dialect's `quote_identifier` is applied to every node through `transform`, with
    `copy=False`, so `expression` is modified in place. `identify` is forwarded to it.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
537
538
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    The i-th alias column name becomes the alias of the i-th projection of the CTE's
    query.
    - Projections beyond the alias column list are kept unchanged.
    - CTEs whose projections contain a star are skipped: a star stands for an unknown
      number of columns, so names cannot be matched positionally, and aliasing the star
      itself would yield invalid SQL such as `* AS c`.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        if any(projection.is_star for projection in projections):
            continue

        new_expressions = []
        for index, projection in enumerate(projections):
            # Keep trailing projections that have no corresponding alias name instead of
            # dropping them (plain `zip` would silently truncate the projection list).
            if index < len(alias_names):
                name = alias_names[index]
                if isinstance(projection, exp.Alias):
                    # Use an Identifier, like `alias()` below, so the name is quoted when needed
                    projection.set("alias", exp.to_identifier(name))
                else:
                    projection = alias(projection, alias=name)
            new_expressions.append(projection)

        cte.this.set("expressions", new_expressions)

    return expression
569
570
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (None until first needed)
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the column names that occur exactly once
        across all sources. Failing that, and if schema inference is enabled, the only
        source whose columns are unknown (empty or containing "*") is used, provided
        there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred. For a selected source this is the
            alias it is known by, when one can be found.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor carrying the alias this source is known by
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run past the root without finding the alias
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Column aliases declared on the source (e.g. `AS t(a, b)`) shadow the
        corresponding column names positionally.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column name is unambiguous only if it occurs exactly once across all sources.
        A name repeated within one source is as ambiguous as a name shared by two
        sources. Occurrences are counted over every source at once so that a duplicate
        in an earlier source is not "forgotten" when a later source also has the name.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        owners: t.Dict[str, str] = {}
        for table, columns in source_columns.items():
            for column in columns:
                owners[column] = table

        every_column = [column for columns in source_columns.values() for column in columns]
        return {column: owners[column] for column in self._find_unique_columns(every_column)}

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. It is rewritten in place and also returned.
        schema: Database schema (anything accepted by `ensure_schema`).
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, schema
            inference is enabled exactly when `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    # Alias references are expanded before column qualification when there is no
    # schema, and after it when there is one.
    expand_refs_before = expand_alias_refs and schema.empty
    expand_refs_after = expand_alias_refs and not schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        for derived in (scope.ctes, scope.derived_tables):
            _pop_table_column_aliases(derived)

        using_column_tables = _expand_using(scope, resolver)

        if expand_refs_before:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_refs_after:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 78def validate_qualify_columns(expression: E) -> E:
 79    """Raise an `OptimizeError` if any columns aren't qualified"""
 80    all_unqualified_columns = []
 81    for scope in traverse_scope(expression):
 82        if isinstance(scope.expression, exp.Select):
 83            unqualified_columns = scope.unqualified_columns
 84
 85            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 86                column = scope.external_columns[0]
 87                for_table = f" for table: '{column.table}'" if column.table else ""
 88                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 89
 90            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 91                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 92                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 93                # this list here to ensure those in the former category will be excluded.
 94                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 95                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 96
 97            all_unqualified_columns.extend(unqualified_columns)
 98
 99    if all_unqualified_columns:
100        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
101
102    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Every projection that is neither an alias nor a star is wrapped in an alias named after
    its output name, falling back to `_col_<index>` when it has none. A subquery projection
    without an output name instead gets a `_col_<index>` table alias. When the scope has an
    outer column list (presumably the column names of an enclosing table alias, e.g.
    `AS t(x, y)`), those names override the projection aliases positionally.

    Args:
        scope_or_expression: The scope to process, or an expression from which a scope is
            built with `build_scope`. Nothing happens if no scope can be built.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the projections run out; surplus outer names are ignored
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    # Only a SELECT owns its projections in the "expressions" arg. For other scope
    # expressions (e.g. set operations such as UNION), `selects` presumably comes from a
    # child query, so setting "expressions" on them would add a stray arg and re-parent
    # the child's projection nodes.
    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to process. It is transformed without copying (`copy=False`).
        dialect: The dialect whose `quote_identifier` is applied to every node.
        identify: Passed through `transform` to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        query = cte.this
        # Only a SELECT keeps its projections under "expressions"; for other CTE bodies
        # (e.g. VALUES, whose "expressions" are rows) there is nothing to alias.
        if not alias_names or not isinstance(query, exp.Select):
            continue

        projections = query.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                # Store a proper Identifier (quoted when needed), as `alias()` does below,
                # rather than a raw string.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # Keep any projections that have no corresponding alias instead of dropping them.
        new_expressions.extend(projections[len(alias_names) :])
        query.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches, filled on first use by the methods below
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.

        Returns:
            The identifier of the table the column belongs to, or None if it can't be
            found or inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # With inference enabled, a column no known schema claims is attributed to the
            # single source whose columns are unknown (empty or containing a star).
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor whose alias matches the selected source name
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the top of the tree; fall back to the source name then
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of one of this scope's sources.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names. Any column aliases declared for the source replace
            them positionally, and extra aliases are appended.

        Raises:
            OptimizeError: If `name` is not one of this scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (computed once, cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it, and that source exposes
        it exactly once.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        unambiguous_columns: t.Dict[str, str] = {}
        seen_columns: t.Set[str] = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            # Intersect with *all* of this source's columns (duplicates included), so a
            # column repeated here still invalidates an earlier source's claim on it.
            ambiguous = seen_columns.intersection(columns)
            # Record every column, not just the unique ones, so later sources see duplicates too.
            seen_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
579    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
580        self.scope = scope
581        self.schema = schema
582        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
583        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
584        self._all_columns: t.Optional[t.Set[str]] = None
585        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
587    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
588        """
589        Get the table for a column name.
590
591        Args:
592            column_name: The column name to find the table for.
593        Returns:
594            The table name if it can be found/inferred.
595        """
596        if self._unambiguous_columns is None:
597            self._unambiguous_columns = self._get_unambiguous_columns(
598                self._get_all_source_columns()
599            )
600
601        table_name = self._unambiguous_columns.get(column_name)
602
603        if not table_name and self._infer_schema:
604            sources_without_schema = tuple(
605                source
606                for source, columns in self._get_all_source_columns().items()
607                if not columns or "*" in columns
608            )
609            if len(sources_without_schema) == 1:
610                table_name = sources_without_schema[0]
611
612        if table_name not in self.scope.selected_sources:
613            return exp.to_identifier(table_name)
614
615        node, _ = self.scope.selected_sources.get(table_name)
616
617        if isinstance(node, exp.Subqueryable):
618            while node and node.alias != table_name:
619                node = node.parent
620
621        node_alias = node.args.get("alias")
622        if node_alias:
623            return exp.to_identifier(node_alias.this)
624
625        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The identifier of the table the column belongs to, or None if it can't be found or inferred.

all_columns: Set[str]
627    @property
628    def all_columns(self) -> t.Set[str]:
629        """All available columns of all sources in this scope"""
630        if self._all_columns is None:
631            self._all_columns = {
632                column for columns in self._get_all_source_columns().values() for column in columns
633            }
634        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
636    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
637        """Resolve the source columns for a given source `name`."""
638        if name not in self.scope.sources:
639            raise OptimizeError(f"Unknown table: {name}")
640
641        source = self.scope.sources[name]
642
643        if isinstance(source, exp.Table):
644            columns = self.schema.column_names(source, only_visible)
645        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
646            columns = source.expression.alias_column_names
647        else:
648            columns = source.expression.named_selects
649
650        node, _ = self.scope.selected_sources.get(name) or (None, None)
651        if isinstance(node, Scope):
652            column_aliases = node.expression.alias_column_names
653        elif isinstance(node, exp.Expression):
654            column_aliases = node.alias_column_names
655        else:
656            column_aliases = []
657
658        # If the source's columns are aliased, their aliases shadow the corresponding column names
659        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

Resolve the source columns for a given source name.