Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    infer_schema: t.Optional[bool] = None,
 21) -> exp.Expression:
 22    """
 23    Rewrite sqlglot AST to have fully qualified columns.
 24
 25    Example:
 26        >>> import sqlglot
 27        >>> schema = {"tbl": {"col": "INT"}}
 28        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 29        >>> qualify_columns(expression, schema).sql()
 30        'SELECT tbl.col AS col FROM tbl'
 31
 32    Args:
 33        expression: Expression to qualify.
 34        schema: Database schema.
 35        expand_alias_refs: Whether or not to expand references to aliases.
 36        infer_schema: Whether or not to infer the schema if missing.
 37
 38    Returns:
 39        The qualified expression.
 40    """
 41    schema = ensure_schema(schema)
 42    infer_schema = schema.empty if infer_schema is None else infer_schema
 43    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 44
 45    for scope in traverse_scope(expression):
 46        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 47        _pop_table_column_aliases(scope.ctes)
 48        _pop_table_column_aliases(scope.derived_tables)
 49        using_column_tables = _expand_using(scope, resolver)
 50
 51        if schema.empty and expand_alias_refs:
 52            _expand_alias_refs(scope, resolver)
 53
 54        _qualify_columns(scope, resolver)
 55
 56        if not schema.empty and expand_alias_refs:
 57            _expand_alias_refs(scope, resolver)
 58
 59        if not isinstance(scope.expression, exp.UDTF):
 60            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 61            qualify_outputs(scope)
 62
 63        _expand_group_by(scope)
 64        _expand_order_by(scope, resolver)
 65
 66    return expression
 67
 68
 69def validate_qualify_columns(expression: E) -> E:
 70    """Raise an `OptimizeError` if any columns aren't qualified"""
 71    unqualified_columns = []
 72    for scope in traverse_scope(expression):
 73        if isinstance(scope.expression, exp.Select):
 74            unqualified_columns.extend(scope.unqualified_columns)
 75            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 76                column = scope.external_columns[0]
 77                raise OptimizeError(
 78                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 79                )
 80
 81    if unqualified_columns:
 82        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 83    return expression
 84
 85
 86def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 87    """
 88    Remove table column aliases.
 89
 90    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 91    """
 92    for derived_table in derived_tables:
 93        table_alias = derived_table.args.get("alias")
 94        if table_alias:
 95            table_alias.args.pop("columns", None)
 96
 97
 98def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 99    joins = list(scope.find_all(exp.Join))
100    names = {join.alias_or_name for join in joins}
101    ordered = [key for key in scope.selected_sources if key not in names]
102
103    # Mapping of automatically joined column names to an ordered set of source names (dict).
104    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
105
106    for join in joins:
107        using = join.args.get("using")
108
109        if not using:
110            continue
111
112        join_table = join.alias_or_name
113
114        columns = {}
115
116        for source_name in scope.selected_sources:
117            if source_name in ordered:
118                for column_name in resolver.get_source_columns(source_name):
119                    if column_name not in columns:
120                        columns[column_name] = source_name
121
122        source_table = ordered[-1]
123        ordered.append(join_table)
124        join_columns = resolver.get_source_columns(join_table)
125        conditions = []
126
127        for identifier in using:
128            identifier = identifier.name
129            table = columns.get(identifier)
130
131            if not table or identifier not in join_columns:
132                if (columns and "*" not in columns) and join_columns:
133                    raise OptimizeError(f"Cannot automatically join: {identifier}")
134
135            table = table or source_table
136            conditions.append(
137                exp.condition(
138                    exp.EQ(
139                        this=exp.column(identifier, table=table),
140                        expression=exp.column(identifier, table=join_table),
141                    )
142                )
143            )
144
145            # Set all values in the dict to None, because we only care about the key ordering
146            tables = column_tables.setdefault(identifier, {})
147            if table not in tables:
148                tables[table] = None
149            if join_table not in tables:
150                tables[join_table] = None
151
152        join.args.pop("using")
153        join.set("on", exp.and_(*conditions, copy=False))
154
155    if column_tables:
156        for column in scope.columns:
157            if not column.table and column.name in column_tables:
158                tables = column_tables[column.name]
159                coalesce = [exp.column(column.name, table=table) for table in tables]
160                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
161
162                # Ensure selects keep their output name
163                if isinstance(column.parent, exp.Select):
164                    replacement = alias(replacement, alias=column.name, copy=False)
165
166                scope.replace(column, replacement)
167
168    return column_tables
169
170
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases inside a SELECT scope, in place.

    Projections are processed in order, so a projection only sees aliases defined before
    it. Afterwards the WHERE, GROUP BY, HAVING and QUALIFY clauses are processed and the
    scope's cache is cleared. Nothing happens if the scope's expression is not a SELECT.

    Args:
        scope: The scope to rewrite.
        resolver: Used to find the source table of columns in HAVING and QUALIFY.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of its projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for candidate, *_ in walk_in_scope(node):
            if not isinstance(candidate, exp.Column):
                continue

            table = None
            if resolve_table and not candidate.table:
                table = resolver.get_table(candidate.name)

            alias_expr, position = alias_to_expression.get(candidate.name, (None, 1))

            # True-ish when the alias contains an aggregate and the reference itself sits
            # inside an aggregate.
            if alias_expr:
                double_agg = alias_expr.find(exp.AggFunc) and candidate.find_ancestor(exp.AggFunc)
            else:
                double_agg = False

            if table and (not alias_expr or double_agg):
                candidate.set("table", table)
                continue

            if candidate.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                # Literal aliases are only rewritten (to their position) when requested.
                if literal_index:
                    candidate.replace(exp.Literal.number(position))
                continue

            wrapped = candidate.replace(exp.paren(alias_expr))
            simplified = simplify_parens(wrapped)
            if simplified is not wrapped:
                wrapped.replace(simplified)

    for position, projection in enumerate(select.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    clause_options = (
        ("where", {}),
        ("group", {"literal_index": True}),
        ("having", {"resolve_table": True}),
        ("qualify", {"resolve_table": True}),
    )
    for clause, options in clause_options:
        replace_columns(select.args.get(clause), **options)

    scope.clear_cache()
220
221
def _expand_group_by(scope: Scope) -> None:
    """
    Expand positional references in the scope's GROUP BY clause, in place.

    Does nothing when the scope's expression has no GROUP BY.
    """
    select = scope.expression
    group = select.args.get("group")

    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        select.set("group", group)
230
231
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in the scope's ORDER BY clause, in place.

    Unqualified columns found inside aggregate functions of an ordered term are
    qualified through the resolver. If the query also has a GROUP BY, each ordered
    term that equals the expression of a projection is replaced by a column named
    after that projection's alias or name.

    Args:
        scope: The scope to rewrite.
        resolver: Used to find the source table of unqualified columns.
    """
    select = scope.expression
    order = select.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds), alias=True)

    for ordered, new_expression in zip(ordereds, expanded):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not select.args.get("group"):
        return

    # Projection expression -> column that refers to the projection by its output name.
    selects = {s.this: exp.column(s.alias_or_name) for s in select.selects}

    for ordered in ordereds:
        target = ordered.this

        if target.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, target).alias)
        else:
            replacement = selects.get(target, target)

        target.replace(replacement)
260
261
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positional references against the scope's projections.

    Args:
        scope: The scope whose projections the positions refer to.
        expressions: The nodes to inspect; non-integer nodes are passed through as-is.
        alias: If True, a position is replaced by a column built from a copy of the
            referenced projection's alias. Otherwise it is replaced by a copy of the
            projection's expression, unless that expression is a literal, in which case
            the positional node itself is kept.

    Returns:
        A new list with one entry per input node.
    """
    new_nodes: t.List[exp.Expression] = []

    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        select = _select_by_pos(scope, t.cast(exp.Literal, node))

        if alias:
            new_nodes.append(exp.column(select.args["alias"].copy()))
            continue

        projected = select.this
        new_nodes.append(node if isinstance(projected, exp.Literal) else projected.copy())

    return new_nodes
283
284
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection of the scope referenced by a 1-based positional literal.

    Args:
        scope: The scope whose projections are indexed.
        node: Literal holding the position, e.g. the `2` in `ORDER BY 2`.

    Returns:
        The referenced projection, passed through `assert_is(exp.Alias)`.

    Raises:
        OptimizeError: If the position does not lie within `1..len(selects)`.
    """
    selects = scope.expression.selects
    index = int(node.this) - 1

    # Bounds are checked explicitly: position 0 would otherwise become index -1, which
    # Python resolves to the *last* projection instead of reporting an error.
    if not 0 <= index < len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[index].assert_is(exp.Alias)
290
291
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Columns are modified in place. Afterwards, unqualified columns inside the scope's
    pivots whose name is known to the resolver are qualified as well.

    Raises:
        OptimizeError: If a column names a source of this scope whose columns are known
            and do not include the column.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # The scope has pivots and this column is not inside a Pivot node, so it
                # is qualified with the alias of the first pivot.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if resolved:
                column.set("table", resolved)
            continue

        if column_table in scope.sources:
            # Already qualified with a source of this scope: only validate the name.
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")
            continue

        if (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                struct_table = root
                root, *parts = parts
            else:
                struct_table = resolver.get_table(root.name)

            if struct_table:
                column.replace(exp.Dot.build([exp.column(root, table=struct_table), *parts]))

    for pivot in scope.pivots:
        for pivot_column in pivot.find_all(exp.Column):
            if pivot_column.table or pivot_column.name not in resolver.all_columns:
                continue

            pivot_table = resolver.get_table(pivot_column.name)
            if pivot_table:
                pivot_column.set("table", pivot_table)
341
342
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    The scope's projections are replaced in place. If the columns of any starred table
    are unknown (empty or containing "*"), the function returns immediately and leaves
    the projections untouched.

    Args:
        scope: The scope whose projections are expanded.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Mapping of USING column names to the tables joined on them;
            such columns are emitted once, as a COALESCE over those tables.
        pseudocolumns: Upper-cased column names that are left out of the expansion.

    Raises:
        OptimizeError: If a starred table is not a source of the scope.
    """
    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    has_pivoted_source = pivot and not pivot.args.get("unpivot")

    pivot_columns = None
    pivot_output_columns = None
    if pivot and has_pivoted_source:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}

        # Prefer the pivot's explicit output columns, else fall back to its expressions.
        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] or [
            col.alias_or_name for col in pivot.expressions
        ]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            star = expression
            tables = list(scope.selected_sources)
        elif expression.is_star:
            star = expression.this
            tables = [expression.table]
        else:
            new_selections.append(expression)
            continue

        # EXCEPT / REPLACE modifiers are recorded per table, keyed by id() of the very
        # name objects that are iterated below.
        _add_except_columns(star, tables, except_columns)
        _add_replace_columns(star, tables, replace_columns)

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # The table's columns are unknown, so the star cannot be expanded.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                for name in implicit_columns + pivot_output_columns:
                    if name not in columns_to_exclude:
                        new_selections.append(
                            exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        )
                continue

            renames = replace_columns.get(table_id, {})

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is selected only once, coalesced over its tables.
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce = [
                        exp.column(name, table=joined) for joined in using_column_tables[name]
                    ]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = renames.get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
432
433
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the names listed in the expression's "except" arg for each given table.

    Args:
        expression: The expression whose "except" arg is read.
        tables: The tables the exclusion applies to.
        except_columns: Updated in place; keyed by `id()` of each table. All tables
            share the same set object.
    """
    except_ = expression.args.get("except")

    if except_:
        excluded = {e.name for e in except_}
        except_columns.update((id(table), excluded) for table in tables)
446
447
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record the entries of the expression's "replace" arg for each given table.

    Args:
        expression: The expression whose "replace" arg is read.
        tables: The tables the replacements apply to.
        replace_columns: Updated in place; keyed by `id()` of each table. Each value
            maps the name of a replacement's inner expression to the replacement's
            alias. All tables share the same dict object.
    """
    replace = expression.args.get("replace")

    if replace:
        renamed = {e.this.name: e.alias for e in replace}
        replace_columns.update((id(table), renamed) for table in tables)
460
461
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    The projections of the scope's expression are replaced in place.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If
            no scope can be built from the expression, nothing is done.
    """
    scope = scope_or_expression
    if isinstance(scope, exp.Expression):
        scope = build_scope(scope)
        if not isinstance(scope, Scope):
            return

    new_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    for i, (selection, aliased_column) in enumerate(pairs):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            # Use the selection's own output name, or a positional one if it has none.
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )

        # A name from the outer column list overrides whatever alias is present.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
489
490
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform (without copying it).
        dialect: The dialect whose `quote_identifier` is applied.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)
496
497
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources are inspected.
            schema: The schema used to look up the columns of tables.
            infer_schema: Whether a column may be attributed to the single source whose
                columns are unknown when it cannot be resolved otherwise.
        """
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, the column is assumed to be its.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources[table_name]

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above may run past the root without finding the alias; fall back to
        # the plain name instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source of this scope.
            only_visible: Forwarded to the schema lookup when the source is a table.

        Returns:
            The column names of the source, with column aliases applied positionally.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by name."""
        if self._source_columns is None:
            source_names = itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            self._source_columns = {
                source_name: self.get_source_columns(source_name) for source_name in source_names
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Columns of this source that were already seen in an earlier one.
            ambiguous = all_columns & unique
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique - ambiguous:
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. It is modified in place.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When left as
            None, it defaults to whether the schema is empty.

    Returns:
        The qualified expression (the same object that was passed in).
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when there is no schema,
        # and after qualification when a schema is available.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is a SELECT are inspected.

    Args:
        expression: The expression to validate.

    Returns:
        The expression that was passed in, unchanged.

    Raises:
        OptimizeError: On the first external column found in a scope that is neither a
            correlated subquery nor has pivots, or, after all scopes were visited, if any
            unqualified columns were collected.
    """
    unqualified_columns = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns.extend(scope.unqualified_columns)

        if scope.external_columns and not (scope.is_correlated_subquery or scope.pivots):
            column = scope.external_columns[0]
            # Only mention the table when the column actually names one.
            table_suffix = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{table_suffix}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")

    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    The projections of the scope's expression are replaced in place.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If
            no scope can be built from the expression, nothing is done.
    """
    scope = scope_or_expression
    if isinstance(scope, exp.Expression):
        scope = build_scope(scope)
        if not isinstance(scope, Scope):
            return

    new_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    for i, (selection, aliased_column) in enumerate(pairs):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            # Use the selection's own output name, or a positional one if it has none.
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )

        # A name from the outer column list overrides whatever alias is present.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform (without copying it).
        dialect: The dialect whose `quote_identifier` is applied.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources are used to resolve columns.
            schema: The schema used to look up the column names of table sources.
            infer_schema: Whether `get_table` may fall back to the single source
                that has no known columns (or exposes "*").
        """
        self.scope = scope
        self.schema = schema
        # Caches below are filled in on first use; None means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no known columns (or exposes "*"),
            # assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up the tree until we reach the ancestor that carries the alias
            while node and node.alias != table_name:
                node = node.parent

        if node is None:
            # The walk above ran past the root without finding the aliased ancestor;
            # fall back to the plain name instead of dereferencing None.
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of the source, as keyed in `self.scope.sources`.
            only_visible: Forwarded to `schema.column_names` for table sources.

        Returns:
            The column names of the source, with column aliases (if any) taking the
            place of the names at the same positions.

        Raises:
            OptimizeError: If `name` is not one of the scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (computed once)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, _source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within a single source,
        # so that a later source cannot claim a name an earlier source already exposes.
        seen_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this source's columns (not just its unique ones), so a
            # name duplicated here still invalidates a match recorded for an earlier source.
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
506    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
507        self.scope = scope
508        self.schema = schema
509        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
510        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
511        self._all_columns: t.Optional[t.Set[str]] = None
512        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
514    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
515        """
516        Get the table for a column name.
517
518        Args:
519            column_name: The column name to find the table for.
520        Returns:
521            The table name if it can be found/inferred.
522        """
523        if self._unambiguous_columns is None:
524            self._unambiguous_columns = self._get_unambiguous_columns(
525                self._get_all_source_columns()
526            )
527
528        table_name = self._unambiguous_columns.get(column_name)
529
530        if not table_name and self._infer_schema:
531            sources_without_schema = tuple(
532                source
533                for source, columns in self._get_all_source_columns().items()
534                if not columns or "*" in columns
535            )
536            if len(sources_without_schema) == 1:
537                table_name = sources_without_schema[0]
538
539        if table_name not in self.scope.selected_sources:
540            return exp.to_identifier(table_name)
541
542        node, _ = self.scope.selected_sources.get(table_name)
543
544        if isinstance(node, exp.Subqueryable):
545            while node and node.alias != table_name:
546                node = node.parent
547
548        node_alias = node.args.get("alias")
549        if node_alias:
550            return exp.to_identifier(node_alias.this)
551
552        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
563    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
564        """Resolve the source columns for a given source `name`."""
565        if name not in self.scope.sources:
566            raise OptimizeError(f"Unknown table: {name}")
567
568        source = self.scope.sources[name]
569
570        if isinstance(source, exp.Table):
571            columns = self.schema.column_names(source, only_visible)
572        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
573            columns = source.expression.alias_column_names
574        else:
575            columns = source.expression.named_selects
576
577        node, _ = self.scope.selected_sources.get(name) or (None, None)
578        if isinstance(node, Scope):
579            column_aliases = node.expression.alias_column_names
580        elif isinstance(node, exp.Expression):
581            column_aliases = node.alias_column_names
582        else:
583            column_aliases = []
584
585        # If the source's columns are aliased, their aliases shadow the corresponding column names
586        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

Resolve the source columns for a given source name.