Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get
 10from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 11from sqlglot.optimizer.simplify import simplify_parens
 12from sqlglot.schema import Schema, ensure_schema
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When left as
            `None`, it defaults to `schema.empty`.

    Returns:
        The qualified expression. This is the same object that was passed in; the
        rewrite steps below operate on its scopes.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        # Each scope gets its own resolver, which lazily caches that scope's source columns.
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        # Drop the column-alias lists attached to this scope's CTE and derived table aliases.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded exactly once per scope: before qualification when
        # the schema is empty, after qualification otherwise.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
 77
 78
 79def validate_qualify_columns(expression: E) -> E:
 80    """Raise an `OptimizeError` if any columns aren't qualified"""
 81    all_unqualified_columns = []
 82    for scope in traverse_scope(expression):
 83        if isinstance(scope.expression, exp.Select):
 84            unqualified_columns = scope.unqualified_columns
 85
 86            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 87                column = scope.external_columns[0]
 88                for_table = f" for table: '{column.table}'" if column.table else ""
 89                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 90
 91            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 92                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 93                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 94                # this list here to ensure those in the former category will be excluded.
 95                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 96                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 97
 98            all_unqualified_columns.extend(unqualified_columns)
 99
100    if all_unqualified_columns:
101        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
102
103    return expression
104
105
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns of an UNPIVOT: the column on the left of its `IN` field
    (when the field is an `IN` over a column), followed by every column found in the
    unpivot's expressions.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    name_columns = [field.this] if has_name_column else []

    value_columns = (
        column
        for expression in unpivot.expressions
        for column in expression.find_all(exp.Column)
    )
    return itertools.chain(name_columns, value_columns)
114
115
116def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
117    """
118    Remove table column aliases.
119
120    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
121    """
122    for derived_table in derived_tables:
123        table_alias = derived_table.args.get("alias")
124        if table_alias:
125            table_alias.args.pop("columns", None)
126
127
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in the scope into an explicit `ON` condition.

    Each USING column becomes an equality between the joined table and an earlier source.
    Unqualified references to a USING column elsewhere in the scope are then replaced by
    a `COALESCE` over all the tables that take part in joins on that column.

    Args:
        scope: The scope whose joins are rewritten in place.
        resolver: Used to look up the column names of each source.

    Returns:
        A mapping of USING column name to a dict whose keys are the names of the sources
        joined on that column, in join order (the dict values are always `None`).

    Raises:
        OptimizeError: If a USING column cannot be found on both sides of the join even
            though the column lists of both sides are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Selected sources that are not introduced by a join; joined tables are appended
    # below as each join is processed.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # Column name -> first already-available source that exposes it.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Fallback left-hand side: the most recently added source.
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when the column lists of both sides are actually known,
                # i.e. non-empty and (for the left side) not a "*" placeholder.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # Swap the USING clause for the conjunction of the equalities built above.
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            # Unqualified references to a USING column are replaced by a COALESCE over
            # every table joined on that column.
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
199
200
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases inside a SELECT scope.

    Projections are visited in order, so a projection can only be expanded with aliases
    defined by earlier projections. Afterwards the WHERE, GROUP BY, HAVING and QUALIFY
    clauses are processed against the full alias mapping. Non-SELECT scopes are ignored.

    Args:
        scope: The scope to rewrite in place.
        resolver: Used to find the source table of columns in HAVING / QUALIFY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of the projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # Rewrites the columns found under `node` (stars are pruned from the walk).
        # `resolve_table`: also try to resolve a source table for unqualified columns.
        # `literal_index`: replace references to literal-valued aliases by the
        # projection's position.
        if not node:
            return

        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when the aliased expression contains an aggregate and the reference
            # itself sits inside an aggregate; such references are not expanded.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                # Treat the name as a real source column and qualify it.
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # Literal aliases become the projection position when `literal_index`
                    # is set; with only `resolve_table` set they are left untouched.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Inline the aliased expression, wrapped in parentheses that are
                    # simplified away again when `simplify_parens` returns a new node.
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        # Register the alias only after expanding the projection itself.
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # NOTE(review): presumably resets the scope's cached lookups after the in-place
    # rewrites above; confirm against Scope.clear_cache.
    scope.clear_cache()
251
252
253def _expand_group_by(scope: Scope) -> None:
254    expression = scope.expression
255    group = expression.args.get("group")
256    if not group:
257        return
258
259    group.set("expressions", _expand_positional_references(scope, group.expressions))
260    expression.set("group", group)
261
262
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in the scope's ORDER BY and qualify columns used
    inside its aggregate functions.

    When the query also has a GROUP BY, ORDER BY terms that match a projection are
    additionally replaced by a reference to that projection's alias.

    Args:
        scope: The scope to rewrite in place.
        resolver: Used to find the source table of unqualified columns.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    # Pair each ordered term with its expansion; positional references become columns
    # named after the referenced projection's alias (`alias=True`).
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        # Unqualified columns inside aggregate functions get their source table.
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Projected expression -> column referencing that projection's output name.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            # Integer terms become the identifier of the referenced projection's alias;
            # other terms are swapped for the matching projection's column, if any.
            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
291
292
293def _expand_positional_references(
294    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
295) -> t.List[exp.Expression]:
296    new_nodes: t.List[exp.Expression] = []
297    for node in expressions:
298        if node.is_int:
299            select = _select_by_pos(scope, t.cast(exp.Literal, node))
300
301            if alias:
302                new_nodes.append(exp.column(select.args["alias"].copy()))
303            else:
304                select = select.this
305
306                if isinstance(select, exp.Literal):
307                    new_nodes.append(node)
308                else:
309                    new_nodes.append(select.copy())
310        else:
311            new_nodes.append(node)
312
313    return new_nodes
314
315
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal.

    Args:
        scope: The scope whose SELECT list is indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The projection at that position, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position is smaller than 1 or larger than the number
            of projections.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Positions are 1-based. Check the range explicitly: indexing with `position - 1`
    # would turn position 0 into index -1 and silently pick the last projection.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
321
322
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Args:
        scope: The scope whose columns are qualified in place.
        resolver: Used to look up source columns and to find the table of a column.

    Raises:
        OptimizeError: If a column is qualified with a source of this scope whose
            columns are known and do not include it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # Only complain when the source's columns are known (non-empty, no "*").
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # In a pivoted scope, columns that are NOT under the Pivot expression are
                # qualified with the alias of the first pivot. Columns under the Pivot are
                # handled by the loop at the end of this function.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    # Unqualified columns under a Pivot expression are qualified with the table the
    # resolver finds for them, provided their name is a known source column.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
372
373
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Args:
        scope: The scope whose projections are rewritten in place.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Mapping of USING column name to the sources joined on it,
            as returned by `_expand_using`; such columns are emitted once as a COALESCE.
        pseudocolumns: Upper-cased column names that must not appear in the expansion.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    # Both mappings are keyed by `id(table)` of the table-name objects collected below.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    # USING columns that have already been emitted as a COALESCE.
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    # Only the first pivot of the scope is considered.
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            # Bare `*`: expand over every selected source.
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            # Qualified star (`tbl.*`): expand over that table only.
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # The column list is unknown: give up on the whole expansion. Returning here
            # leaves the scope's original projections untouched.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    # Source columns not consumed by the pivot, followed by its outputs.
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    # Pivoted columns are selected from the pivot's alias.
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted only once, as a COALESCE over all the
                    # tables joined on it.
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    # Rebinding `tables` does not affect the enclosing `for table in tables`
                    # loop, which keeps iterating over the original list.
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    # Columns named in a REPLACE clause are aliased to the replacement's alias.
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
477
478
479def _add_except_columns(
480    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
481) -> None:
482    except_ = expression.args.get("except")
483
484    if not except_:
485        return
486
487    columns = {e.name for e in except_}
488
489    for table in tables:
490        except_columns[id(table)] = columns
491
492
493def _add_replace_columns(
494    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
495) -> None:
496    replace = expression.args.get("replace")
497
498    if not replace:
499        return
500
501    columns = {e.this.name: e.alias for e in replace}
502
503    for table in tables:
504        replace_columns[id(table)] = columns
505
506
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: Either a scope, or an expression from which a scope is
            built. Nothing happens if no scope can be built from the expression.
    """
    scope = scope_or_expression
    if isinstance(scope, exp.Expression):
        scope = build_scope(scope)
        if not isinstance(scope, Scope):
            return

    aliased = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    for index, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # There are more outer column names than projections: nothing left to alias.
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{index}")))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{index}",
                copy=False,
            )

        # A column name imposed from the outside takes precedence over the alias above.
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    scope.expression.set("expressions", aliased)
538
539
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform (without copying it).
        dialect: The dialect whose `quote_identifier` is applied to the expression.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
545
546
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        new_expressions = []
        for index, projection in enumerate(cte.this.expressions):
            # Only the leading projections have a matching alias. The remaining ones are
            # kept unchanged rather than dropped, which is what pairing the two lists
            # with `zip` would do when the alias list is the shorter one.
            if index < len(column_aliases):
                _alias = column_aliases[index]
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        cte.this.set("expressions", new_expressions)

    return expression
577
578
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources are used for resolution.
            schema: The schema used to look up the columns of tables.
            infer_schema: Whether a column that cannot be resolved may be attributed to
                the only source whose columns are unknown.
        """
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; `None` means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, the column is attributed to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up until reaching the node that carries the source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above ends with `node` being None when no ancestor carries the alias;
        # fall back to the plain table name instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source of this scope.
            only_visible: Forwarded to the schema lookup when the source is a table.

        Returns:
            The column names of the source, with any column aliases declared on the
            source applied positionally.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column name seen so far, including names duplicated within the
        # first source: those are ambiguous and must not be claimed by a later source.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When left as
            `None`, it defaults to `schema.empty`.

    Returns:
        The qualified expression. This is the same object that was passed in; the
        rewrite steps below operate on its scopes.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        # Each scope gets its own resolver, which lazily caches that scope's source columns.
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        # Drop the column-alias lists attached to this scope's CTE and derived table aliases.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded exactly once per scope: before qualification when
        # the schema is empty, after qualification otherwise.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 80def validate_qualify_columns(expression: E) -> E:
 81    """Raise an `OptimizeError` if any columns aren't qualified"""
 82    all_unqualified_columns = []
 83    for scope in traverse_scope(expression):
 84        if isinstance(scope.expression, exp.Select):
 85            unqualified_columns = scope.unqualified_columns
 86
 87            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 88                column = scope.external_columns[0]
 89                for_table = f" for table: '{column.table}'" if column.table else ""
 90                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 91
 92            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 93                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 94                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 95                # this list here to ensure those in the former category will be excluded.
 96                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 97                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 98
 99            all_unqualified_columns.extend(unqualified_columns)
100
101    if all_unqualified_columns:
102        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
103
104    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Give every output column of a scope an alias, modifying the scope's expression in place.

    Args:
        scope_or_expression: Either a scope, or an expression from which a scope is built.
            Nothing happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    aliased_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    for position, (projection, outer_name) in enumerate(pairs):
        # The projections ran out before the outer column names did.
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set(
                    "alias", exp.TableAlias(this=exp.to_identifier(f"_col_{position}"))
                )
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{position}",
                copy=False,
            )

        # A name coming from the outer column list overrides the alias chosen above.
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(projection)

    scope.expression.set("expressions", aliased_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote all identifiers in `expression` that need quoting.

    Args:
        expression: The expression to transform (it is not copied).
        dialect: The dialect whose `quote_identifier` method is applied to the nodes.
        identify: Forwarded as-is to the dialect's `quote_identifier`.

    Returns:
        The result of transforming `expression`.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_aliases, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # `zip` stops at the shorter input, so when the CTE declares fewer column aliases
        # than it has projections, carry the remaining projections over unchanged instead
        # of silently dropping them from the query.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Column resolution helper bound to a single scope.

    It is a class so that derived lookups can be computed lazily, cached on the instance
    and shared between the functions that work on the same scope.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema
        # Caches, filled in on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Find the table that a column name belongs to.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the only source whose columns are unknown (no columns or a star),
            # provided there is exactly one such source.
            schemaless_sources = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless_sources) == 1:
                table_name = schemaless_sources[0]

        selected_sources = self.scope.selected_sources
        if table_name in selected_sources:
            node, _ = selected_sources.get(table_name)

            if isinstance(node, exp.Subqueryable):
                # Walk up the tree until reaching the ancestor aliased as `table_name`.
                while node and node.alias != table_name:
                    node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """The set of column names found across all sources in this scope."""
        if self._all_columns is None:
            self._all_columns = set(
                itertools.chain.from_iterable(self._get_all_source_columns().values())
            )
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not one of the scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if not column_aliases:
            return columns

        # Column aliases shadow the corresponding column names. Pairing them up can be
        # expensive for wide sources, hence it only happens when aliases actually exist.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map each selected and lateral source name to its columns, caching the result."""
        if self._source_columns is None:
            source_names = itertools.chain(
                self.scope.selected_sources, self.scope.lateral_sources
            )
            self._source_columns = {
                source_name: self.get_source_columns(source_name) for source_name in source_names
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        pairs = iter(source_columns.items())

        first_table, first_columns = next(pairs)
        unambiguous = dict.fromkeys(self._find_unique_columns(first_columns), first_table)
        seen_columns = set(unambiguous)

        for table, columns in pairs:
            unique = self._find_unique_columns(columns)
            clashing = seen_columns.intersection(unique)
            seen_columns.update(columns)

            # A column already seen in an earlier source is no longer unambiguous.
            for column in clashing:
                unambiguous.pop(column, None)
            for column in unique.difference(clashing):
                unambiguous[column] = table

        return unambiguous

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the columns that occur exactly once in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        seen: t.Set[str] = set()
        repeated: t.Set[str] = set()
        for column in columns:
            if column in seen:
                repeated.add(column)
            else:
                seen.add(column)
        return seen - repeated

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
587    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
588        self.scope = scope
589        self.schema = schema
590        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
591        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
592        self._all_columns: t.Optional[t.Set[str]] = None
593        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
595    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
596        """
597        Get the table for a column name.
598
599        Args:
600            column_name: The column name to find the table for.
601        Returns:
602            The table name if it can be found/inferred.
603        """
604        if self._unambiguous_columns is None:
605            self._unambiguous_columns = self._get_unambiguous_columns(
606                self._get_all_source_columns()
607            )
608
609        table_name = self._unambiguous_columns.get(column_name)
610
611        if not table_name and self._infer_schema:
612            sources_without_schema = tuple(
613                source
614                for source, columns in self._get_all_source_columns().items()
615                if not columns or "*" in columns
616            )
617            if len(sources_without_schema) == 1:
618                table_name = sources_without_schema[0]
619
620        if table_name not in self.scope.selected_sources:
621            return exp.to_identifier(table_name)
622
623        node, _ = self.scope.selected_sources.get(table_name)
624
625        if isinstance(node, exp.Subqueryable):
626            while node and node.alias != table_name:
627                node = node.parent
628
629        node_alias = node.args.get("alias")
630        if node_alias:
631            return exp.to_identifier(node_alias.this)
632
633        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
635    @property
636    def all_columns(self) -> t.Set[str]:
637        """All available columns of all sources in this scope"""
638        if self._all_columns is None:
639            self._all_columns = {
640                column for columns in self._get_all_source_columns().values() for column in columns
641            }
642        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
644    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
645        """Resolve the source columns for a given source `name`."""
646        if name not in self.scope.sources:
647            raise OptimizeError(f"Unknown table: {name}")
648
649        source = self.scope.sources[name]
650
651        if isinstance(source, exp.Table):
652            columns = self.schema.column_names(source, only_visible)
653        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
654            columns = source.expression.alias_column_names
655        else:
656            columns = source.expression.named_selects
657
658        node, _ = self.scope.selected_sources.get(name) or (None, None)
659        if isinstance(node, Scope):
660            column_aliases = node.expression.alias_column_names
661        elif isinstance(node, exp.Expression):
662            column_aliases = node.alias_column_names
663        else:
664            column_aliases = []
665
666        if column_aliases:
667            # If the source's columns are aliased, their aliases shadow the corresponding column names.
668            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
669            return [
670                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
671            ]
672        return columns

Resolve the source columns for a given source name.