Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    expand_stars: bool = True,
 21    infer_schema: t.Optional[bool] = None,
 22) -> exp.Expression:
 23    """
 24    Rewrite sqlglot AST to have fully qualified columns.
 25
 26    Example:
 27        >>> import sqlglot
 28        >>> schema = {"tbl": {"col": "INT"}}
 29        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 30        >>> qualify_columns(expression, schema).sql()
 31        'SELECT tbl.col AS col FROM tbl'
 32
 33    Args:
 34        expression: Expression to qualify.
 35        schema: Database schema.
 36        expand_alias_refs: Whether or not to expand references to aliases.
 37        expand_stars: Whether or not to expand star queries. This is a necessary step
 38            for most of the optimizer's rules to work; do not set to False unless you
 39            know what you're doing!
 40        infer_schema: Whether or not to infer the schema if missing.
 41
 42    Returns:
 43        The qualified expression.
 44
 45    Notes:
 46        - Currently only handles a single PIVOT or UNPIVOT operator
 47    """
 48    schema = ensure_schema(schema)
 49    infer_schema = schema.empty if infer_schema is None else infer_schema
 50    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 51
 52    for scope in traverse_scope(expression):
 53        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 54        _pop_table_column_aliases(scope.ctes)
 55        _pop_table_column_aliases(scope.derived_tables)
 56        using_column_tables = _expand_using(scope, resolver)
 57
 58        if schema.empty and expand_alias_refs:
 59            _expand_alias_refs(scope, resolver)
 60
 61        _qualify_columns(scope, resolver)
 62
 63        if not schema.empty and expand_alias_refs:
 64            _expand_alias_refs(scope, resolver)
 65
 66        if not isinstance(scope.expression, exp.UDTF):
 67            if expand_stars:
 68                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 69            qualify_outputs(scope)
 70
 71        _expand_group_by(scope)
 72        _expand_order_by(scope, resolver)
 73
 74    return expression
 75
 76
 77def validate_qualify_columns(expression: E) -> E:
 78    """Raise an `OptimizeError` if any columns aren't qualified"""
 79    all_unqualified_columns = []
 80    for scope in traverse_scope(expression):
 81        if isinstance(scope.expression, exp.Select):
 82            unqualified_columns = scope.unqualified_columns
 83
 84            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 85                column = scope.external_columns[0]
 86                for_table = f" for table: '{column.table}'" if column.table else ""
 87                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 88
 89            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 90                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 91                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 92                # this list here to ensure those in the former category will be excluded.
 93                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 94                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 95
 96            all_unqualified_columns.extend(unqualified_columns)
 97
 98    if all_unqualified_columns:
 99        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
100
101    return expression
102
103
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns introduced by an UNPIVOT.

    First the name column (the left operand of the `IN` stored under `field`, when it is
    a plain column), then every column found in the UNPIVOT's value expressions.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    name_columns = [field.this] if has_name_column else []
    value_columns = (
        column for value in unpivot.expressions for column in value.find_all(exp.Column)
    )
    return itertools.chain(name_columns, value_columns)
112
113
114def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
115    """
116    Remove table column aliases.
117
118    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
119    """
120    for derived_table in derived_tables:
121        table_alias = derived_table.args.get("alias")
122        if table_alias:
123            table_alias.args.pop("columns", None)
124
125
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in `scope` into an equivalent `JOIN ... ON ...`.

    For each USING column, an equality condition is built between the joined table and the
    first already-processed source exposing that column (falling back to the last source
    processed so far). Afterwards, unqualified references to a USING column anywhere in the
    scope are replaced by a COALESCE over all tables joined on it; when such a reference is
    a direct SELECT projection, the COALESCE is aliased to the column name.

    Args:
        scope: The scope whose joins are rewritten in place.
        resolver: Used to look up the column names of each source.

    Returns:
        Mapping of each USING column name to a dict whose keys are the names of the tables
        joined on that column, in insertion order (all values are None).

    Raises:
        OptimizeError: If a USING column is missing on one side while the left-hand columns
            are known (non-empty, no `*`) and the joined table's column list is non-empty.
    """
    joins = list(scope.find_all(exp.Join))
    # Names of the join targets; every other selected source is a "left-hand" source.
    names = {join.alias_or_name for join in joins}
    # Sources processed so far; each join's table is appended once it has been handled.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # Column name -> first already-processed source that exposes it.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Fallback left-hand table when a column's source cannot be determined.
        # NOTE(review): raises IndexError if `ordered` is empty (every source is a join target).
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when both sides have column information to check against.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # USING is replaced by the equivalent conjunction of equality conditions.
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Unqualified references to a USING column become COALESCE(t1.col, t2.col, ...).
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
197
198
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to SELECT-list aliases with the aliased expressions, in place.

    Only SELECT scopes are handled. Unqualified columns whose name matches an alias are
    replaced by the aliased expression wrapped in parentheses (simplified where possible),
    in these clauses:

    - Projections: only aliases defined by *earlier* projections are visible, because each
      alias is registered after its own projection has been processed.
    - WHERE: same as projections, using all aliases.
    - GROUP BY (`literal_index=True`): a reference to an alias of a literal becomes the
      1-based position of that projection instead.
    - HAVING / QUALIFY (`resolve_table=True`): unqualified columns are qualified with their
      source table when the name is not an alias, or when expanding would place an
      aggregate inside an aggregate; references to literal aliases are left untouched.

    Expansion is always skipped when it would nest an aggregate inside an aggregate.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based projection position).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when the alias is an aggregate and this reference already sits inside one.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Columns were replaced above; drop the scope's cached lookups (presumably its column
    # lists) so later passes re-scan the rewritten tree.
    scope.clear_cache()
249
250
def _expand_group_by(scope: Scope) -> None:
    """
    Replace positional GROUP BY references (e.g. `GROUP BY 1`) in place.

    Each integer position becomes a copy of the referenced projection's expression, unless
    that expression is itself a literal, in which case the position is kept.
    """
    select = scope.expression
    group = select.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        select.set("group", group)
259
260
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the ORDER BY clause of `scope`, in place.

    - Positional references (`ORDER BY 1`) become a column named after the referenced
      projection's alias.
    - Unqualified columns inside aggregate functions are qualified with their source table.
    - When the query also has a GROUP BY, ORDER BY expressions that match a projection's
      expression are replaced by a column referencing that projection's output name.

    Args:
        scope: The scope whose ORDER BY is rewritten.
        resolver: Used to find the source table of unqualified columns.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Projection expression -> output name. A fresh column node is built for every
        # matching ORDER BY item below; reusing one prebuilt node would attach the same
        # AST node to several parents when items match the same projection.
        output_names = {s.this: s.alias_or_name for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            if ordered.is_int:
                replacement = exp.to_identifier(_select_by_pos(scope, ordered).alias)
            elif ordered in output_names:
                replacement = exp.column(output_names[ordered])
            else:
                replacement = ordered

            ordered.replace(replacement)
289
290
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve 1-based integer positions against the SELECT list of `scope`.

    With `alias=True`, each position becomes a column named after the projection's alias.
    Otherwise it becomes a copy of the projection's underlying expression, unless that
    expression is a literal, in which case the positional reference is kept as-is.
    Non-integer expressions pass through unchanged.
    """
    expanded: t.List[exp.Expression] = []
    for node in expressions:
        if not node.is_int:
            expanded.append(node)
            continue

        select = _select_by_pos(scope, t.cast(exp.Literal, node))
        if alias:
            expanded.append(exp.column(select.args["alias"].copy()))
            continue

        projected = select.this
        expanded.append(node if isinstance(projected, exp.Literal) else projected.copy())

    return expanded
312
313
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. `ORDER BY 2`).

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Check the range explicitly: a position of 0 (or below) would otherwise become a
    # negative index and silently pick a projection counted from the end of the list.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
319
320
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    - Qualified columns whose table is a source of this scope are checked against that
      source's known columns.
    - Unqualified columns are qualified with the pivot's alias when the scope has a pivot
      and the column is outside the PIVOT clause; otherwise with `resolver.get_table`.
    - Qualified columns whose "table" is not a source of this scope (and that are not
      outer-scope references from a correlated subquery) are treated as struct field
      accesses and rebuilt as an `exp.Dot` chain rooted at a qualified column.
    - Finally, unqualified columns inside PIVOT clauses are qualified with their source table.

    Raises:
        OptimizeError: If a qualified column is missing from its source's known columns
            (and that column list is non-empty and contains no `*`).
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT clause refer to the pivot's output, so they are
                # qualified with the pivot's alias. Columns under the Pivot expression are
                # instead qualified with the pivoted source's name in the loop below.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
370
371
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Bare `*` (all selected sources) and `table.*` projections are replaced by explicit
    column references. The expansion honors:

    - `EXCEPT (...)` / `REPLACE (...)` modifiers on the star;
    - USING-join columns, each emitted once as an aliased COALESCE;
    - dialect pseudocolumns, which are dropped;
    - the first PIVOT/UNPIVOT of the scope, whose output columns are selected qualified
      with the pivot's alias.

    If a starred source has no known columns (and the scope has no outer column list), or
    its columns include `*`, the function returns without modifying the projections.

    Args:
        scope: The scope whose SELECT list is rewritten in place.
        resolver: Used to look up each source's visible columns.
        using_column_tables: Output of `_expand_using`: USING column name -> joined tables.
        pseudocolumns: Names compared against each upper-cased column name; matches are dropped.

    Raises:
        OptimizeError: If a starred table is not a source of this scope.
    """

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown columns: leave every projection (including this star) untouched.
            if not columns or "*" in columns:
                return

            # EXCEPT/REPLACE entries are keyed by the identity of the table-name string that
            # was passed to _add_except_columns/_add_replace_columns above.
            # NOTE(review): the maps accumulate across all star projections and id() of str
            # objects may coincide for equal (e.g. interned) names, so a modifier registered
            # for one star could apply to another — confirm this is intended.
            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted once, as COALESCE over all its joined tables.
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    # Rebinds `tables`; the enclosing `for table in tables` keeps iterating
                    # its original list, and the comprehension's `table` is local to it.
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
475
476
477def _add_except_columns(
478    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
479) -> None:
480    except_ = expression.args.get("except")
481
482    if not except_:
483        return
484
485    columns = {e.name for e in except_}
486
487    for table in tables:
488        except_columns[id(table)] = columns
489
490
491def _add_replace_columns(
492    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
493) -> None:
494    replace = expression.args.get("replace")
495
496    if not replace:
497        return
498
499    columns = {e.this.name: e.alias for e in replace}
500
501    for table in tables:
502        replace_columns[id(table)] = columns
503
504
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Accepts a `Scope`, or an expression from which a scope is built (returning without
    changes if that does not yield a `Scope`). Unaliased, non-star projections are aliased
    to their output name, or `_col_<i>` if they have none; subqueries without a name get a
    `_col_<i>` table alias. Names in `scope.outer_column_list` then override the aliases
    positionally. The projection list is replaced in place.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    outputs = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    for position, (projection, outer_name) in enumerate(pairs):
        # More outer column names than projections: nothing left to alias.
        if projection is None:
            break

        fallback_name = f"_col_{position}"
        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        outputs.append(projection)

    scope.expression.set("expressions", outputs)
536
537
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Applies the dialect's `quote_identifier` through `expression.transform` without copying
    (`copy=False`), forwarding `identify`, and returns the transformed expression.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)
543
544
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to push down. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier (as `alias()` does) rather than a bare string, so that
                # code reading `projection.args["alias"]` gets an expression node.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # Keep projections beyond the alias column list instead of silently dropping them.
        new_expressions.extend(projections[len(column_names):])
        cte.this.set("expressions", new_expressions)

    return expression
575
576
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches (see the methods/properties that fill them).
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        If the column name is not unambiguous and schema inference is enabled, the only
        source without known columns (or with `*`) is assumed, when there is exactly one.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        # The parent walk above can run off the root without finding a matching alias;
        # fall back to the plain source name instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for (column, column_alias) in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Columns of every selected and lateral source, keyed by source name (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous only if its name occurs exactly once across all sources:
        a name repeated within one source, or shared by several sources, is ambiguous.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        occurrences: t.Dict[str, int] = {}
        owner: t.Dict[str, str] = {}
        for table, columns in source_columns.items():
            for column in columns:
                occurrences[column] = occurrences.get(column, 0) + 1
                owner[column] = table

        return {column: owner[column] for column, count in occurrences.items() if count == 1}

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of `expression` is visited; its column references are attached to the
    source they come from, its outputs are aliased and (optionally) its stars expanded.
    The AST is modified in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. Defaults to
            inferring only when `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    # Without a schema, alias references are expanded before columns are qualified;
    # with one, they are expanded afterwards.
    expand_refs_early = expand_alias_refs and schema.empty
    expand_refs_late = expand_alias_refs and not schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if expand_refs_early:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_refs_late:
            _expand_alias_refs(scope, resolver)

        # Table-valued function scopes have no projection list to expand or alias.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 78def validate_qualify_columns(expression: E) -> E:
 79    """Raise an `OptimizeError` if any columns aren't qualified"""
 80    all_unqualified_columns = []
 81    for scope in traverse_scope(expression):
 82        if isinstance(scope.expression, exp.Select):
 83            unqualified_columns = scope.unqualified_columns
 84
 85            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 86                column = scope.external_columns[0]
 87                for_table = f" for table: '{column.table}'" if column.table else ""
 88                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 89
 90            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 91                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 92                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 93                # this list here to ensure those in the former category will be excluded.
 94                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 95                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 96
 97            all_unqualified_columns.extend(unqualified_columns)
 98
 99    if all_unqualified_columns:
100        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
101
102    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Accepts either a `Scope` or an expression. For an expression, a scope is built first, and
    nothing happens if that does not produce a `Scope`. The projections of the scope's
    expression are rewritten in place:

    - a subquery projection without an output name gets the alias `_col_<i>`;
    - any other projection that is neither an alias nor a star is wrapped in an alias named
      after its output name, falling back to `_col_<i>`;
    - when the scope has an outer column list, its i-th entry overrides the alias of the
      i-th projection.
    """
    scope = scope_or_expression
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    aliased_selects = []

    for position, (projection, outer_name) in enumerate(pairs):
        # zip_longest pads with None once the projections run out; stop at that point.
        if projection is None:
            break

        fallback_name = f"_col_{position}"
        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selects.append(projection)

    scope.expression.set("expressions", aliased_selects)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to process; it is transformed in place (`copy=False`).
        dialect: The dialect whose `quote_identifier` rule is applied to every node.
        identify: Forwarded unchanged to the dialect's `quote_identifier`.

    Returns:
        The transformed expression.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection. Projections
        beyond the number of CTE column aliases are kept as they are.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for column_alias, projection in zip(column_aliases, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier, as `alias()` does in the other branch, rather than a
                # bare string.
                projection.set("alias", exp.to_identifier(column_alias))
            else:
                projection = alias(projection, alias=column_alias)
            new_expressions.append(projection)

        # `zip` stops at the shorter input; keep projections that have no matching CTE column
        # alias instead of silently dropping them from the query.
        new_expressions.extend(projections[len(column_aliases) :])
        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; each is filled on first use by the method that needs it.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred, as an identifier. The alias of the
            selected source's node is preferred when it has one.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # A column no source claims unambiguously is attributed to the only source whose
            # column list is empty or contains "*", provided there is exactly one such source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources[table_name]

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries this source's alias.
            while node and node.alias != table_name:
                node = node.parent

            if not node:
                # No ancestor carries the alias: fall back to the source name instead of
                # dereferencing a missing node below.
                return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in the current scope.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names. Column aliases declared for the source replace the
            names at the same positions.

        Raises:
            OptimizeError: if `name` is not one of the scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map each selected and lateral source name to its columns (computed once, then cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it, and provides it only once.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far, including names duplicated within one source, so that
        # a later source can never claim a name that is already ambiguous.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Any name shared with an earlier source is ambiguous, even when it is duplicated in
            # this source (and therefore missing from `unique`).
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
585    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
586        self.scope = scope
587        self.schema = schema
588        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
589        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
590        self._all_columns: t.Optional[t.Set[str]] = None
591        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
593    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
594        """
595        Get the table for a column name.
596
597        Args:
598            column_name: The column name to find the table for.
599        Returns:
600            The table name if it can be found/inferred.
601        """
602        if self._unambiguous_columns is None:
603            self._unambiguous_columns = self._get_unambiguous_columns(
604                self._get_all_source_columns()
605            )
606
607        table_name = self._unambiguous_columns.get(column_name)
608
609        if not table_name and self._infer_schema:
610            sources_without_schema = tuple(
611                source
612                for source, columns in self._get_all_source_columns().items()
613                if not columns or "*" in columns
614            )
615            if len(sources_without_schema) == 1:
616                table_name = sources_without_schema[0]
617
618        if table_name not in self.scope.selected_sources:
619            return exp.to_identifier(table_name)
620
621        node, _ = self.scope.selected_sources.get(table_name)
622
623        if isinstance(node, exp.Subqueryable):
624            while node and node.alias != table_name:
625                node = node.parent
626
627        node_alias = node.args.get("alias")
628        if node_alias:
629            return exp.to_identifier(node_alias.this)
630
631        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
633    @property
634    def all_columns(self) -> t.Set[str]:
635        """All available columns of all sources in this scope"""
636        if self._all_columns is None:
637            self._all_columns = {
638                column for columns in self._get_all_source_columns().values() for column in columns
639            }
640        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
642    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
643        """Resolve the source columns for a given source `name`."""
644        if name not in self.scope.sources:
645            raise OptimizeError(f"Unknown table: {name}")
646
647        source = self.scope.sources[name]
648
649        if isinstance(source, exp.Table):
650            columns = self.schema.column_names(source, only_visible)
651        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
652            columns = source.expression.alias_column_names
653        else:
654            columns = source.expression.named_selects
655
656        node, _ = self.scope.selected_sources.get(name) or (None, None)
657        if isinstance(node, Scope):
658            column_aliases = node.expression.alias_column_names
659        elif isinstance(node, exp.Expression):
660            column_aliases = node.alias_column_names
661        else:
662            column_aliases = []
663
664        # If the source's columns are aliased, their aliases shadow the corresponding column names
665        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

Resolve the source columns for a given source name.