Edit on GitHub

sqlglot.optimizer.qualify_columns

  1import itertools
  2import typing as t
  3
  4from sqlglot import alias, exp
  5from sqlglot.errors import OptimizeError
  6from sqlglot.optimizer.scope import Scope, traverse_scope
  7from sqlglot.schema import ensure_schema
  8
  9
 10def qualify_columns(expression, schema):
 11    """
 12    Rewrite sqlglot AST to have fully qualified columns.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> schema = {"tbl": {"col": "INT"}}
 17        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 18        >>> qualify_columns(expression, schema).sql()
 19        'SELECT tbl.col AS col FROM tbl'
 20
 21    Args:
 22        expression (sqlglot.Expression): expression to qualify
 23        schema (dict|sqlglot.optimizer.Schema): Database schema
 24    Returns:
 25        sqlglot.Expression: qualified expression
 26    """
 27    schema = ensure_schema(schema)
 28
 29    for scope in traverse_scope(expression):
 30        resolver = Resolver(scope, schema)
 31        _pop_table_column_aliases(scope.ctes)
 32        _pop_table_column_aliases(scope.derived_tables)
 33        _expand_using(scope, resolver)
 34        _qualify_columns(scope, resolver)
 35        if not isinstance(scope.expression, exp.UDTF):
 36            _expand_stars(scope, resolver)
 37            _qualify_outputs(scope)
 38        _expand_group_by(scope, resolver)
 39        _expand_order_by(scope)
 40
 41    return expression
 42
 43
 44def validate_qualify_columns(expression):
 45    """Raise an `OptimizeError` if any columns aren't qualified"""
 46    unqualified_columns = []
 47    for scope in traverse_scope(expression):
 48        if isinstance(scope.expression, exp.Select):
 49            unqualified_columns.extend(scope.unqualified_columns)
 50            if scope.external_columns and not scope.is_correlated_subquery:
 51                column = scope.external_columns[0]
 52                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
 53
 54    if unqualified_columns:
 55        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 56    return expression
 57
 58
 59def _pop_table_column_aliases(derived_tables):
 60    """
 61    Remove table column aliases.
 62
 63    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 64    """
 65    for derived_table in derived_tables:
 66        table_alias = derived_table.args.get("alias")
 67        if table_alias:
 68            table_alias.args.pop("columns", None)
 69
 70
 71def _expand_using(scope, resolver):
 72    joins = list(scope.expression.find_all(exp.Join))
 73    names = {join.this.alias for join in joins}
 74    ordered = [key for key in scope.selected_sources if key not in names]
 75
 76    # Mapping of automatically joined column names to source names
 77    column_tables = {}
 78
 79    for join in joins:
 80        using = join.args.get("using")
 81
 82        if not using:
 83            continue
 84
 85        join_table = join.this.alias_or_name
 86
 87        columns = {}
 88
 89        for k in scope.selected_sources:
 90            if k in ordered:
 91                for column in resolver.get_source_columns(k):
 92                    if column not in columns:
 93                        columns[column] = k
 94
 95        ordered.append(join_table)
 96        join_columns = resolver.get_source_columns(join_table)
 97        conditions = []
 98
 99        for identifier in using:
100            identifier = identifier.name
101            table = columns.get(identifier)
102
103            if not table or identifier not in join_columns:
104                raise OptimizeError(f"Cannot automatically join: {identifier}")
105
106            conditions.append(
107                exp.condition(
108                    exp.EQ(
109                        this=exp.column(identifier, table=table),
110                        expression=exp.column(identifier, table=join_table),
111                    )
112                )
113            )
114
115            tables = column_tables.setdefault(identifier, [])
116            if table not in tables:
117                tables.append(table)
118            if join_table not in tables:
119                tables.append(join_table)
120
121        join.args.pop("using")
122        join.set("on", exp.and_(*conditions))
123
124    if column_tables:
125        for column in scope.columns:
126            if not column.table and column.name in column_tables:
127                tables = column_tables[column.name]
128                coalesce = [exp.column(column.name, table=table) for table in tables]
129                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
130
131                # Ensure selects keep their output name
132                if isinstance(column.parent, exp.Select):
133                    replacement = exp.alias_(replacement, alias=column.name)
134
135                scope.replace(column, replacement)
136
137
def _expand_group_by(scope, resolver):
    """
    Expand references in the GROUP BY clause of `scope.expression`.

    Unqualified columns are qualified with their source table when the resolver
    finds one; otherwise a column whose name matches a select's alias-or-name is
    replaced with a copy of that select's expression. Positional references are
    then replaced via `_expand_positional_references`.

    Args:
        scope: the scope whose GROUP BY is rewritten in place.
        resolver (Resolver): used to find the source table of a column name.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", table)
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            select = selects.get(node.name)
            if select:
                scope.clear_cache()
                # Group by the aliased expression, not the alias node itself
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    group.transform(transform, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)
167
168
def _expand_order_by(scope):
    """
    Replace positional references in the ORDER BY clause of `scope.expression`.

    Each ordered term's `this` is swapped for the result of
    `_expand_positional_references`; non-positional terms are set back unchanged.

    Args:
        scope: the scope whose ORDER BY is rewritten in place.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds))

    for ordered, new_expression in zip(ordereds, expanded):
        ordered.set("this", new_expression)
180
181
def _expand_positional_references(scope, expressions):
    """
    Replace integer positional references with copies of the referenced selects.

    Args:
        scope: the scope whose `selects` the positions refer to (1-based).
        expressions: iterable of nodes; nodes that are not integers are kept as is.
    Returns:
        list: the nodes, with each positional reference replaced by a copy of the
        select at that position (unwrapped from its alias, if it has one).
    Raises:
        OptimizeError: if a position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        selects = scope.selects
        index = int(node.name) - 1

        # Explicit range check: a position of 0 would otherwise index selects[-1]
        # and silently pick the last select
        if not 0 <= index < len(selects):
            raise OptimizeError(f"Unknown output column: {node.name}")

        select = selects[index]
        if isinstance(select, exp.Alias):
            select = select.this
        new_nodes.append(select.copy())
        scope.clear_cache()

    return new_nodes
198
199
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    Columns without a table get the table found by the resolver. Columns whose
    "table" is not a source of this scope are treated as struct accesses and
    rewritten into nested `exp.Dot` nodes. Finally, unqualified columns inside
    ORDER BY and HAVING that name a known source column are qualified.

    Args:
        scope: the scope whose columns are rewritten in place.
        resolver (Resolver): used to look up source columns and tables.
    Raises:
        OptimizeError: if a column names a source of this scope that has known
            columns, none of which is the column or `*`.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources:
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            struct_root, *struct_fields = [
                val for val in reversed(list(column.args.values())) if val is not None
            ]

            if struct_root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                struct_table = struct_root
                struct_root, *struct_fields = struct_fields
            else:
                struct_table = resolver.get_table(struct_root.name)

            if struct_table:
                # Absorb enclosing Dot nodes so the whole access chain is rebuilt at once
                while column.parent and isinstance(column.parent, exp.Dot):
                    column = column.parent
                    struct_fields.append(column.expression)

                new_column = exp.column(struct_root, table=struct_table)
                for field in struct_fields:
                    new_column = exp.Dot(this=new_column, expression=field)

                column.replace(new_column)

    columns_missing_from_scope = []
    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            if (
                not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", column_table)
269
270
def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections.

    `*` expands over every selected source, `<table>.*` over that table only.
    Columns listed in the star's `except` arg are dropped and those in `replace`
    are aliased. If any involved source has no known columns (or exposes `*`),
    the selections are left untouched.

    Args:
        scope: the scope whose select expressions are rewritten in place.
        resolver (Resolver): used to look up the visible columns of each source.
    Raises:
        OptimizeError: if a star refers to a table that is not a source of this scope.
    """

    new_selections = []
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                # Must match the id()-based keys written by the _add_*_columns helpers
                table_id = id(table)
                for name in columns:
                    if name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table)
                        new_selections.append(alias(column, alias_) if alias_ != name else column)
            else:
                # Columns are unknown for this source, so the star can't be expanded
                return
    scope.expression.set("expressions", new_selections)
306
307
def _add_except_columns(expression, tables, except_columns):
    """
    Record the columns excluded by a star's `except` arg.

    Args:
        expression: star node; its `args["except"]` is read.
        tables: the tables the star expands over.
        except_columns (dict): updated in place, mapping `id(table)` to the set of
            excluded column names. Left untouched when there is nothing to exclude.
    """
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    # Every table shares the same set object
    except_columns.update((id(table), columns) for table in tables)
318
319
def _add_replace_columns(expression, tables, replace_columns):
    """
    Record the column replacements requested by a star's `replace` arg.

    Args:
        expression: star node; its `args["replace"]` is read.
        tables: the tables the star expands over.
        replace_columns (dict): updated in place, mapping `id(table)` to a dict of
            `{replaced expression name: alias}`. Left untouched when there is
            nothing to replace.
    """
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    # Every table shares the same mapping object
    replace_columns.update((id(table), columns) for table in tables)
330
331
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    Unaliased subqueries get a `_col_<i>` table alias; other unaliased, non-star
    selections are wrapped in an alias named after their output name, falling back
    to `_col_<i>`. When the scope has an outer column list, its names override the
    aliases position by position.

    Args:
        scope: the scope whose select expressions are rewritten in place.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the selects run out, i.e. when the outer
        # column list is longer than the selects; there is nothing left to alias
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Build the alias around a placeholder, then swap the real selection in
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
353
354
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> "t.Optional[exp.Identifier]":
        """
        Get the table for a column name.

        If the column can't be attributed unambiguously and exactly one source has
        no known columns (none at all, or `*`), that source is assumed.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: key of the source in `scope.sources`.
            only_visible: forwarded to `schema.column_names` for table sources.
        Returns:
            The column names of the source.
        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once in exactly one source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        unambiguous_columns = {}
        # Every column name seen so far, including names duplicated within one source,
        # so that a later source can never claim a name that already exists elsewhere
        seen = set()

        for table, columns in source_columns.items():
            present = set(columns)

            # Names shared with an earlier source are ambiguous from now on
            for column in seen.intersection(present):
                unambiguous_columns.pop(column, None)

            for column in self._find_unique_columns(columns).difference(seen):
                unambiguous_columns[column] = table

            seen.update(present)

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Each scope yielded by `traverse_scope` is rewritten in place. Star expansion
    and output aliasing are skipped for scopes whose expression is a UDTF.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression (the object that was passed in)
    """
    schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema)

        # Table column aliases are dropped before anything is resolved
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression (sqlglot.Expression): expression to qualify
  • schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:

sqlglot.Expression: qualified expression

def validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is an `exp.Select` are inspected.

    Args:
        expression (sqlglot.Expression): expression to validate
    Returns:
        sqlglot.Expression: the expression that was passed in
    Raises:
        OptimizeError: if a scope that is not a correlated subquery has external
            columns, or if any inspected scope still has unqualified columns.
    """
    unqualified_columns = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns.extend(scope.unqualified_columns)

        # External columns are only legitimate inside correlated subqueries
        if scope.external_columns and not scope.is_correlated_subquery:
            column = scope.external_columns[0]
            raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression

Raise an OptimizeError if any columns aren't qualified

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> "t.Optional[exp.Identifier]":
        """
        Get the table for a column name.

        If the column can't be attributed unambiguously and exactly one source has
        no known columns (none at all, or `*`), that source is assumed.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: key of the source in `scope.sources`.
            only_visible: forwarded to `schema.column_names` for table sources.
        Returns:
            The column names of the source.
        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once in exactly one source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        unambiguous_columns = {}
        # Every column name seen so far, including names duplicated within one source,
        # so that a later source can never claim a name that already exists elsewhere
        seen = set()

        for table, columns in source_columns.items():
            present = set(columns)

            # Names shared with an earlier source are ambiguous from now on
            for column in seen.intersection(present):
                unambiguous_columns.pop(column, None)

            for column in self._find_unique_columns(columns).difference(seen):
                unambiguous_columns[column] = table

            seen.update(present)

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope, schema)
363    def __init__(self, scope, schema):
364        self.scope = scope
365        self.schema = schema
366        self._source_columns = None
367        self._unambiguous_columns = None
368        self._all_columns = None
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
370    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
371        """
372        Get the table for a column name.
373
374        Args:
375            column_name: The column name to find the table for.
376        Returns:
377            The table name if it can be found/inferred.
378        """
379        if self._unambiguous_columns is None:
380            self._unambiguous_columns = self._get_unambiguous_columns(
381                self._get_all_source_columns()
382            )
383
384        table_name = self._unambiguous_columns.get(column_name)
385
386        if not table_name:
387            sources_without_schema = tuple(
388                source
389                for source, columns in self._get_all_source_columns().items()
390                if not columns or "*" in columns
391            )
392            if len(sources_without_schema) == 1:
393                table_name = sources_without_schema[0]
394
395        if table_name not in self.scope.selected_sources:
396            return exp.to_identifier(table_name)
397
398        node, _ = self.scope.selected_sources.get(table_name)
399
400        if isinstance(node, exp.Subqueryable):
401            while node and node.alias != table_name:
402                node = node.parent
403
404        node_alias = node.args.get("alias")
405        if node_alias:
406            return node_alias.this
407
408        return exp.to_identifier(
409            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
410        )

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns

All available columns of all sources in this scope

def get_source_columns(self, name, only_visible=False):
421    def get_source_columns(self, name, only_visible=False):
422        """Resolve the source columns for a given source `name`"""
423        if name not in self.scope.sources:
424            raise OptimizeError(f"Unknown table: {name}")
425
426        source = self.scope.sources[name]
427
428        # If referencing a table, return the columns from the schema
429        if isinstance(source, exp.Table):
430            return self.schema.column_names(source, only_visible)
431
432        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
433            return source.expression.alias_column_names
434
435        # Otherwise, if referencing another scope, return that scope's named selects
436        return source.expression.named_selects

Resolve the source columns for a given source name