400 lines
13 KiB
Python
400 lines
13 KiB
Python
import itertools
|
|
|
|
from sqlglot import alias, exp
|
|
from sqlglot.errors import OptimizeError
|
|
from sqlglot.optimizer.schema import ensure_schema
|
|
from sqlglot.optimizer.scope import traverse_scope
|
|
|
|
|
|
def qualify_columns(expression, schema):
    """
    Fully qualify the column references of a sqlglot AST.

    Every scope produced by `traverse_scope` is run through the same pipeline:
    table column aliases are dropped, USING joins and GROUP BY / ORDER BY
    references are expanded, columns are qualified, and (for non-UDTF scopes)
    stars are expanded and outputs aliased. Finally the scope is checked for
    references to unknown tables.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: the `expression` object that was passed in
    """
    normalized_schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        resolver = _Resolver(scope, normalized_schema)

        # Normalisation steps that must run before columns are qualified
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _expand_group_by(scope, resolver)

        _qualify_columns(scope, resolver)
        _expand_order_by(scope)

        # Star expansion and output aliasing are skipped for UDTF scopes
        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)

        _check_unknown_tables(scope)

    return expression
|
|
|
|
|
|
def _pop_table_column_aliases(derived_tables):
|
|
"""
|
|
Remove table column aliases.
|
|
|
|
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
|
"""
|
|
for derived_table in derived_tables:
|
|
if isinstance(derived_table, exp.UDTF):
|
|
continue
|
|
table_alias = derived_table.args.get("alias")
|
|
if table_alias:
|
|
table_alias.args.pop("columns", None)
|
|
|
|
|
|
def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING (cols)` into explicit `JOIN ... ON` conditions.

    For every join with a USING list, each listed column is matched against the
    first already-visible source that provides it and an equality with the
    joined table's column is built; the conditions are AND-ed into the join's
    "on" arg and the "using" arg is removed. Afterwards, unqualified references
    to a USING column anywhere in the scope are replaced with a COALESCE over
    all tables that take part in joining that column.

    Args:
        scope: scope whose expression contains the joins
        resolver (_Resolver): used to look up the columns of each source
    Raises:
        OptimizeError: if a USING column is missing from the joined table or
            from every source that precedes it.
    """
    joins = list(scope.expression.find_all(exp.Join))

    # Sources are keyed by alias-or-name, so joined tables must be identified
    # the same way. Using only `.alias` yields "" for unaliased tables, which
    # leaves them in `ordered` and lets a USING column resolve against the
    # joined table itself.
    joined_names = {join.this.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in joined_names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        using = join.args.get("using")
        if not using:
            continue

        join_table = join.this.alias_or_name

        # First visible source (in selection order) that provides each column
        columns = {}
        for source_name in scope.selected_sources:
            if source_name not in ordered:
                continue
            for column in resolver.get_source_columns(source_name):
                columns.setdefault(column, source_name)

        # From now on the joined table is visible to subsequent joins
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            name = identifier.name
            table = columns.get(name)

            if not table or name not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {name}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(name, table=table),
                        expression=exp.column(name, table=join_table),
                    )
                )
            )

            # Remember every table participating in this column, in join order
            tables = column_tables.setdefault(name, [])
            for participant in (table, join_table):
                if participant not in tables:
                    tables.append(participant)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if not column_tables:
        return

    for column in scope.columns:
        if column.table or column.name not in column_tables:
            continue

        tables = column_tables[column.name]
        coalesce = [exp.column(column.name, table=table) for table in tables]
        replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

        # Ensure selects keep their output name
        if isinstance(column.parent, exp.Select):
            replacement = exp.alias_(replacement, alias=column.name)

        scope.replace(column, replacement)
|
|
|
|
|
|
def _expand_group_by(scope, resolver):
|
|
group = scope.expression.args.get("group")
|
|
if not group:
|
|
return
|
|
|
|
# Replace references to select aliases
|
|
def transform(node, *_):
|
|
if isinstance(node, exp.Column) and not node.table:
|
|
table = resolver.get_table(node.name)
|
|
|
|
# Source columns get priority over select aliases
|
|
if table:
|
|
node.set("table", exp.to_identifier(table))
|
|
return node
|
|
|
|
selects = {s.alias_or_name: s for s in scope.selects}
|
|
|
|
select = selects.get(node.name)
|
|
if select:
|
|
scope.clear_cache()
|
|
if isinstance(select, exp.Alias):
|
|
select = select.this
|
|
return select.copy()
|
|
|
|
return node
|
|
|
|
group.transform(transform, copy=False)
|
|
group.set("expressions", _expand_positional_references(scope, group.expressions))
|
|
scope.expression.set("group", group)
|
|
|
|
|
|
def _expand_order_by(scope):
|
|
order = scope.expression.args.get("order")
|
|
if not order:
|
|
return
|
|
|
|
ordereds = order.expressions
|
|
for ordered, new_expression in zip(
|
|
ordereds,
|
|
_expand_positional_references(scope, (o.this for o in ordereds)),
|
|
):
|
|
ordered.set("this", new_expression)
|
|
|
|
|
|
def _expand_positional_references(scope, expressions):
|
|
new_nodes = []
|
|
for node in expressions:
|
|
if node.is_int:
|
|
try:
|
|
select = scope.selects[int(node.name) - 1]
|
|
except IndexError:
|
|
raise OptimizeError(f"Unknown output column: {node.name}")
|
|
if isinstance(select, exp.Alias):
|
|
select = select.this
|
|
new_nodes.append(select.copy())
|
|
scope.clear_cache()
|
|
else:
|
|
new_nodes.append(node)
|
|
|
|
return new_nodes
|
|
|
|
|
|
def _qualify_columns(scope, resolver):
|
|
"""Disambiguate columns, ensuring each column specifies a source"""
|
|
for column in scope.columns:
|
|
column_table = column.table
|
|
column_name = column.name
|
|
|
|
if (
|
|
column_table
|
|
and column_table in scope.sources
|
|
and column_name not in resolver.get_source_columns(column_table)
|
|
):
|
|
raise OptimizeError(f"Unknown column: {column_name}")
|
|
|
|
if not column_table:
|
|
column_table = resolver.get_table(column_name)
|
|
|
|
if not scope.is_subquery and not scope.is_udtf:
|
|
if column_name not in resolver.all_columns:
|
|
raise OptimizeError(f"Unknown column: {column_name}")
|
|
|
|
if column_table is None:
|
|
raise OptimizeError(f"Ambiguous column: {column_name}")
|
|
|
|
# column_table can be a '' because bigquery unnest has no table alias
|
|
if column_table:
|
|
column.set("table", exp.to_identifier(column_table))
|
|
|
|
|
|
def _expand_stars(scope, resolver):
|
|
"""Expand stars to lists of column selections"""
|
|
|
|
new_selections = []
|
|
except_columns = {}
|
|
replace_columns = {}
|
|
|
|
for expression in scope.selects:
|
|
if isinstance(expression, exp.Star):
|
|
tables = list(scope.selected_sources)
|
|
_add_except_columns(expression, tables, except_columns)
|
|
_add_replace_columns(expression, tables, replace_columns)
|
|
elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
|
|
tables = [expression.table]
|
|
_add_except_columns(expression.this, tables, except_columns)
|
|
_add_replace_columns(expression.this, tables, replace_columns)
|
|
else:
|
|
new_selections.append(expression)
|
|
continue
|
|
|
|
for table in tables:
|
|
if table not in scope.sources:
|
|
raise OptimizeError(f"Unknown table: {table}")
|
|
columns = resolver.get_source_columns(table, only_visible=True)
|
|
table_id = id(table)
|
|
for name in columns:
|
|
if name not in except_columns.get(table_id, set()):
|
|
alias_ = replace_columns.get(table_id, {}).get(name, name)
|
|
column = exp.column(name, table)
|
|
new_selections.append(alias(column, alias_) if alias_ != name else column)
|
|
|
|
scope.expression.set("expressions", new_selections)
|
|
|
|
|
|
def _add_except_columns(expression, tables, except_columns):
|
|
except_ = expression.args.get("except")
|
|
|
|
if not except_:
|
|
return
|
|
|
|
columns = {e.name for e in except_}
|
|
|
|
for table in tables:
|
|
except_columns[id(table)] = columns
|
|
|
|
|
|
def _add_replace_columns(expression, tables, replace_columns):
|
|
replace = expression.args.get("replace")
|
|
|
|
if not replace:
|
|
return
|
|
|
|
columns = {e.this.name: e.alias for e in replace}
|
|
|
|
for table in tables:
|
|
replace_columns[id(table)] = columns
|
|
|
|
|
|
def _qualify_outputs(scope):
|
|
"""Ensure all output columns are aliased"""
|
|
new_selections = []
|
|
|
|
for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
|
|
if isinstance(selection, exp.Column):
|
|
# convoluted setter because a simple selection.replace(alias) would require a copy
|
|
alias_ = alias(exp.column(""), alias=selection.name)
|
|
alias_.set("this", selection)
|
|
selection = alias_
|
|
elif not isinstance(selection, exp.Alias):
|
|
alias_ = alias(exp.column(""), f"_col_{i}")
|
|
alias_.set("this", selection)
|
|
selection = alias_
|
|
|
|
if aliased_column:
|
|
selection.set("alias", exp.to_identifier(aliased_column))
|
|
|
|
new_selections.append(selection)
|
|
|
|
scope.expression.set("expressions", new_selections)
|
|
|
|
|
|
def _check_unknown_tables(scope):
|
|
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
|
|
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
|
|
|
|
|
|
class _Resolver:
|
|
"""
|
|
Helper for resolving columns.
|
|
|
|
This is a class so we can lazily load some things and easily share them across functions.
|
|
"""
|
|
|
|
def __init__(self, scope, schema):
|
|
self.scope = scope
|
|
self.schema = schema
|
|
self._source_columns = None
|
|
self._unambiguous_columns = None
|
|
self._all_columns = None
|
|
|
|
def get_table(self, column_name):
|
|
"""
|
|
Get the table for a column name.
|
|
|
|
Args:
|
|
column_name (str)
|
|
Returns:
|
|
(str) table name
|
|
"""
|
|
if self._unambiguous_columns is None:
|
|
self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
|
|
return self._unambiguous_columns.get(column_name)
|
|
|
|
@property
|
|
def all_columns(self):
|
|
"""All available columns of all sources in this scope"""
|
|
if self._all_columns is None:
|
|
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
|
|
return self._all_columns
|
|
|
|
def get_source_columns(self, name, only_visible=False):
|
|
"""Resolve the source columns for a given source `name`"""
|
|
if name not in self.scope.sources:
|
|
raise OptimizeError(f"Unknown table: {name}")
|
|
|
|
source = self.scope.sources[name]
|
|
|
|
# If referencing a table, return the columns from the schema
|
|
if isinstance(source, exp.Table):
|
|
try:
|
|
return self.schema.column_names(source, only_visible)
|
|
except Exception as e:
|
|
raise OptimizeError(str(e)) from e
|
|
|
|
# Otherwise, if referencing another scope, return that scope's named selects
|
|
return source.expression.named_selects
|
|
|
|
def _get_all_source_columns(self):
|
|
if self._source_columns is None:
|
|
self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
|
|
return self._source_columns
|
|
|
|
def _get_unambiguous_columns(self, source_columns):
|
|
"""
|
|
Find all the unambiguous columns in sources.
|
|
|
|
Args:
|
|
source_columns (dict): Mapping of names to source columns
|
|
Returns:
|
|
dict: Mapping of column name to source name
|
|
"""
|
|
if not source_columns:
|
|
return {}
|
|
|
|
source_columns = list(source_columns.items())
|
|
|
|
first_table, first_columns = source_columns[0]
|
|
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
|
|
all_columns = set(unambiguous_columns)
|
|
|
|
for table, columns in source_columns[1:]:
|
|
unique = self._find_unique_columns(columns)
|
|
ambiguous = set(all_columns).intersection(unique)
|
|
all_columns.update(columns)
|
|
for column in ambiguous:
|
|
unambiguous_columns.pop(column, None)
|
|
for column in unique.difference(ambiguous):
|
|
unambiguous_columns[column] = table
|
|
|
|
return unambiguous_columns
|
|
|
|
@staticmethod
|
|
def _find_unique_columns(columns):
|
|
"""
|
|
Find the unique columns in a list of columns.
|
|
|
|
Example:
|
|
>>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
|
|
['a', 'c']
|
|
|
|
This is necessary because duplicate column names are ambiguous.
|
|
"""
|
|
counts = {}
|
|
for column in columns:
|
|
counts[column] = counts.get(column, 0) + 1
|
|
return {column for column, count in counts.items() if count == 1}
|