# 1
# 0
# Fork 0
# sqlglot/sqlglot/optimizer/pushdown_predicates.py
# Daniel Baumann d26905e4af
# Merging upstream version 23.7.0.
# Signed-off-by: Daniel Baumann <daniel@debian.org>
# 2025-02-13 21:30:28 +01:00
#
# 209 lines
# 7.6 KiB
# Python

from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite a sqlglot AST so that predicates are pushed down into FROM sources and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect passed through to `pushdown` for every condition
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # Scopes are visited in the reverse of the order `traverse` yields them.
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        joins = select.args.get("joins") or []
        where = select.args.get("where")

        if where:
            candidates = scope.selected_sources
            join_index = {join.alias_or_name: position for position, join in enumerate(joins)}

            # A right join can only push down to itself and not to the source FROM table,
            # so the candidate sources are narrowed to the first right-joined source found.
            for source_name, (node, source) in candidates.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                    candidates = {source_name: (node, source)}
                    break

            pushdown(where.this, candidates, scope_ref_count, dialect, join_index)

        # A join should only push down into itself, not into other joins,
        # so the sources offered for its ON condition are limited to that join alone.
        for join in joins:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue

            own_source = {name: scope.selected_sources[name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count, dialect)

    return expression
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify a condition in place and push its predicates down into the given sources.

    The simplified condition is treated as CNF-like when it is normalized, or when it
    is not in DNF either. CNF-like conditions are split on AND and handed to
    `pushdown_cnf`; all others are split on OR and handed to `pushdown_dnf`.

    Args:
        condition: the WHERE or ON condition; nothing happens when it is falsy
        sources: mapping of source name to a (node, source) pair
        scope_ref_count: reference counts keyed by `id(source)`
        dialect: dialect passed to `simplify`
        join_index: optional mapping of join name to join position, CNF path only
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)

    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if cnf_like else exp.Or

    # The operands are materialized into a list before any pushdown happens,
    # because pushdown_cnf replaces predicates inside the tree as it goes.
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    A predicate that gets pushed is replaced by TRUE in its original position and is
    attached to the target join's ON clause, or to the target select's WHERE / HAVING.

    Args:
        predicates: the AND-ed blocks of the condition
        sources: mapping of source name to a (node, source) pair
        scope_ref_count: reference counts keyed by `id(source)`
        join_index: optional mapping of join name to its position among the joins
    """
    join_positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                referenced_tables = exp.column_table_names(predicate, join_name)
                position = join_positions[join_name]

                # Don't push the predicate if it references tables that appear in later joins
                references_later_join = any(
                    join_positions.get(table, -1) >= position for table in referenced_tables
                )
                if not references_later_join:
                    predicate.replace(exp.true())
                    target.on(predicate, copy=False)
                    break

            if isinstance(target, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(target, predicate)

                # Predicates containing an aggregate function go to HAVING, others to WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    attach = target.having
                else:
                    attach = target.where
                attach(inner_predicate, copy=False)
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the OR-ed blocks of the condition
        sources: mapping of source name to a (node, source) pair
        scope_ref_count: reference counts keyed by `id(source)`
    """
    # Find all the tables that can be pushed down to. These are the tables that are
    # referenced in every block of the DNF, e.g. for
    #   (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down.
    # The intersection is the same whichever block it starts from, so it is computed
    # once instead of once per block.
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set.intersection(*table_sets) if table_sets else set()

    conditions = {}

    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

            if table not in nodes:
                continue

            # OR together every block that can be pushed to this table.
            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE(review): `nodes` here is the mapping computed for the LAST predicate of the
        # loop above; blocks for which `table` was not in `nodes` are not part of
        # `conditions[table]`. Behavior kept as-is -- confirm this is intended.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                # Predicates containing an aggregate function go to HAVING, others to WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by a predicate to the node the predicate may be pushed into.

    Args:
        predicate: the predicate being considered for pushdown
        sources: mapping of source name to a (node, source) pair
        scope_ref_count: reference counts keyed by `id(source)`
    Returns:
        dict of table name to a Join or Select node. The dict is empty when a recursive
        WITH or a join with a side other than RIGHT is encountered.
    """
    tables = exp.column_table_names(predicate)
    single_table = len(tables) == 1
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    targets = {}

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # If the predicate is in a where statement we can try to push it down,
        # so we look for the root join or from statement.
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # A node can reference a CTE which should be pushed down.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not (single_table and isinstance(node, exp.Select)):
            continue

        # We can't push down window expressions.
        has_window_expression = any(select.find(exp.Window) for select in node.selects)

        if node.args.get("group"):
            continue
        # We can't push down predicates to select statements if they are referenced in
        # multiple places.
        if scope_ref_count[id(source)] >= 2:
            continue
        if has_window_expression:
            continue

        targets[table] = node

    return targets
def replace_aliases(source, predicate):
    """
    Rewrite columns in a predicate in terms of the projections of `source`.

    Every Column in `predicate` whose name matches a projection of `source` is swapped
    for a copy of that projection (the aliased expression for an Alias, otherwise the
    projection itself).

    Args:
        source: a select-like expression exposing `selects`
        predicate: the predicate to rewrite
    Returns:
        the result of `predicate.transform` with the matching columns substituted
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            key, target = projection.alias, projection.this
        else:
            key, target = projection.name, projection
        projections[key] = target

    def _replace_alias(column):
        if not isinstance(column, exp.Column):
            return column
        if column.name not in projections:
            return column
        return projections[column.name].copy()

    return predicate.transform(_replace_alias)