1
0
Fork 0
sqlglot/sqlglot/optimizer/pushdown_projections.py
Daniel Baumann 91f2cef5f0
Adding upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 21:20:19 +01:00

137 lines
5 KiB
Python

from collections import defaultdict
from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
# Selection to use if selection list is empty
def default_selection(is_agg: bool) -> exp.Alias:
return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
"""
Rewrite sqlglot AST to remove unused columns projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Args:
expression (sqlglot.Expression): expression to optimize
remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expression: optimized expression
"""
# Map of Scope to all columns being selected by outer queries.
schema = ensure_schema(schema)
source_column_alias_count = {}
referenced_columns = defaultdict(set)
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope are completely build by the time we get to it.
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
alias_count = source_column_alias_count.get(scope, 0)
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
if scope.expression.args.get("distinct"):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
left, right = scope.union_scopes
referenced_columns[left] = parent_selections
if any(select.is_star for select in right.expression.selects):
referenced_columns[right] = parent_selections
elif not any(select.is_star for select in left.expression.selects):
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
]
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:
_remove_unused_selections(scope, parent_selections, schema, alias_count)
if scope.expression.is_star:
continue
# Group columns by source name
selects = defaultdict(set)
for col in scope.columns:
table_name = col.table
col_name = col.name
selects[table_name].add(col_name)
# Push the selected columns down to the next scope
for name, (node, source) in scope.selected_sources.items():
if isinstance(source, Scope):
columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
referenced_columns[source].update(columns)
column_aliases = node.alias_column_names
if column_aliases:
source_column_alias_count[source] = len(column_aliases)
return expression
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
order = scope.expression.args.get("order")
if order:
# Assume columns without a qualified table are references to output columns
order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
else:
order_refs = set()
new_selections = []
removed = False
star = False
is_agg = False
select_all = SELECT_ALL in parent_selections
for selection in scope.expression.selects:
name = selection.alias_or_name
if select_all or name in parent_selections or name in order_refs or alias_count > 0:
new_selections.append(selection)
alias_count -= 1
else:
if selection.is_star:
star = True
removed = True
if not is_agg and selection.find(exp.AggFunc):
is_agg = True
if star:
resolver = Resolver(scope, schema)
names = {s.alias_or_name for s in new_selections}
for name in sorted(parent_selections):
if name not in names:
new_selections.append(
alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
)
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(default_selection(is_agg))
scope.expression.select(*new_selections, append=False, copy=False)
if removed:
scope.clear_cache()