Merging upstream version 7.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
964bd62de9
commit
e6b3d2fe54
42 changed files with 1430 additions and 253 deletions
42
sqlglot/optimizer/eliminate_ctes.py
Normal file
42
sqlglot/optimizer/eliminate_ctes.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from sqlglot.optimizer.scope import Scope, build_scope
|
||||
|
||||
|
||||
def eliminate_ctes(expression):
    """
    Strip every CTE that nothing references out of an expression.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    root = build_scope(expression)
    ref_count = root.ref_count()

    # Visit the scopes last-to-first so that dropping one CTE can free up the
    # CTEs it was the only consumer of (chains of unused CTEs).
    for scope in reversed(list(root.traverse())):
        if not scope.is_cte:
            continue
        if ref_count[id(scope)] > 0:
            continue

        cte_node = scope.expression.parent
        with_node = cte_node.parent
        cte_node.pop()

        # Nothing left inside the WITH clause: remove the clause as well.
        if len(with_node.expressions) == 0:
            with_node.pop()

        # The dropped CTE no longer counts as a consumer of its own sources.
        for _, source in scope.selected_sources.values():
            if isinstance(source, Scope):
                ref_count[id(source)] -= 1

    return expression
|
160
sqlglot/optimizer/eliminate_joins.py
Normal file
160
sqlglot/optimizer/eliminate_joins.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
from sqlglot import expressions as exp
|
||||
from sqlglot.optimizer.normalize import normalized
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def eliminate_joins(expression):
    """
    Drop joins whose source is never referenced outside its own ON clause.

    A join is only dropped when the join condition is known not to produce
    duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    for scope in traverse_scope(expression):
        # Without a table qualifier on every column we cannot tell which source
        # a column belongs to, so we cannot prove a join is unused. This could
        # probably be inferred from the outputs of derived tables, but for now
        # such scopes are skipped.
        if scope.unqualified_columns:
            continue

        # Walk the joins last-to-first so chains of unused joins can be removed.
        for join in reversed(scope.expression.args.get("joins", [])):
            alias = join.this.alias_or_name
            if not _should_eliminate_join(scope, join, alias):
                continue
            join.pop()
            scope.remove_source(alias)

    return expression
|
||||
|
||||
|
||||
def _should_eliminate_join(scope, join, alias):
    """
    Decide whether `join` (known in `scope` as `alias`) can be removed.

    All of the following must hold:
      * the joined source is itself a Scope,
      * no column outside the join's ON clause references the source,
      * either it is a LEFT join on all of the source's unique outputs, or the
        join has no ON clause and the source has a single output row.

    Returns:
        bool: True if the join can be removed
    """
    inner_source = scope.sources.get(alias)

    if not isinstance(inner_source, Scope):
        return False

    if _join_is_used(scope, join, alias):
        return False

    if join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join):
        return True

    return not join.args.get("on") and _has_single_output_row(inner_source)
|
||||
|
||||
|
||||
def _join_is_used(scope, join, alias):
|
||||
# We need to find all columns that reference this join.
|
||||
# But columns in the ON clause shouldn't count.
|
||||
on = join.args.get("on")
|
||||
if on:
|
||||
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
|
||||
else:
|
||||
on_clause_columns = set()
|
||||
return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
|
||||
|
||||
|
||||
def _is_joined_on_all_unique_outputs(scope, join):
    """
    Check that the join keys of `join` cover every unique output of `scope`.

    Returns:
        bool: False when `scope` has no unique outputs; otherwise True only if
        each unique output name appears among the names of the join keys
    """
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    join_keys = join_condition(join)[1]
    join_key_names = {key.name for key in join_keys}
    return unique_outputs.issubset(join_key_names)
|
||||
|
||||
|
||||
def _unique_outputs(scope):
|
||||
"""Determine output columns of `scope` that must have a unique combination per row"""
|
||||
if scope.expression.args.get("distinct"):
|
||||
return set(scope.expression.named_selects)
|
||||
|
||||
group = scope.expression.args.get("group")
|
||||
if group:
|
||||
grouped_expressions = set(group.expressions)
|
||||
grouped_outputs = set()
|
||||
|
||||
unique_outputs = set()
|
||||
for select in scope.selects:
|
||||
output = select.unalias()
|
||||
if output in grouped_expressions:
|
||||
grouped_outputs.add(output)
|
||||
unique_outputs.add(select.alias_or_name)
|
||||
|
||||
# All the grouped expressions must be in the output
|
||||
if not grouped_expressions.difference(grouped_outputs):
|
||||
return unique_outputs
|
||||
else:
|
||||
return set()
|
||||
|
||||
if _has_single_output_row(scope):
|
||||
return set(scope.expression.named_selects)
|
||||
|
||||
return set()
|
||||
|
||||
|
||||
def _has_single_output_row(scope):
    """
    Tell whether the query of `scope` is known to produce a single row.

    This holds for a SELECT that
      * projects only aggregate functions and has no GROUP BY, or
      * has a LIMIT of 1, or
      * has no FROM clause.

    Returns:
        bool: True if one of the conditions above is met
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    # An all-aggregate projection collapses to one row only when nothing is
    # grouped; with a GROUP BY there is one row per group, so the shortcut
    # must not fire.
    # NOTE(review): a HAVING clause or LIMIT 1 over empty input can yield zero
    # rows rather than one - confirm callers tolerate "at most one".
    return (
        (
            not select.args.get("group")
            and all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects)
        )
        or _is_limit_1(scope)
        or not select.args.get("from")
    )
|
||||
|
||||
|
||||
def _is_limit_1(scope):
|
||||
limit = scope.expression.args.get("limit")
|
||||
return limit and limit.expression.this == "1"
|
||||
|
||||
|
||||
def join_condition(join):
    """
    Split a join's ON clause into equality keys and the remaining predicate.

    For example, in

        SELECT
        FROM x
        JOIN y
          ON x.a = y.b AND y.b > 1

    y.b is pulled out as the join key and x.a as the source key.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate). The keys are
            the operands of the matched equalities; the remaining predicate is
            None when the simplified ON clause equals exp.TRUE.
    """
    name = join.this.alias_or_name
    # Operate on a copy; matched equalities are replaced inside it below.
    on = (join.args.get("on") or exp.TRUE).copy()
    source_key = []
    join_key = []

    if normalized(on):
        conditions = on.flatten() if isinstance(on, exp.And) else [on]

        for condition in conditions:
            if not isinstance(condition, exp.EQ):
                continue

            left, right = condition.unnest_operands()
            joined_on_left = name in exp.column_table_names(left)
            joined_on_right = name in exp.column_table_names(right)

            # Only an equality with the joined table on exactly one side is a key.
            if joined_on_left == joined_on_right:
                continue

            key, other = (left, right) if joined_on_left else (right, left)
            join_key.append(key)
            source_key.append(other)
            condition.replace(exp.TRUE)

    on = simplify(on)
    remaining_condition = None if on == exp.TRUE else on

    return source_key, join_key, remaining_condition
|
|
@ -8,7 +8,7 @@ from sqlglot.optimizer.simplify import simplify
|
|||
|
||||
def eliminate_subqueries(expression):
|
||||
"""
|
||||
Rewrite subqueries as CTES, deduplicating if possible.
|
||||
Rewrite derived tables as CTES, deduplicating if possible.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
|
|
|
@ -119,6 +119,23 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|||
Returns:
|
||||
bool: True if can be merged
|
||||
"""
|
||||
|
||||
def _is_a_window_expression_in_unmergable_operation():
    """
    Tell whether the outer scope references a window-expression output of the
    inner select from inside a WHERE, GROUP, ORDER, JOIN, HAVING or aggregate
    function.

    Reads `inner_select` and `outer_scope` from the enclosing function.
    """
    window_alias_names = {
        window.parent.alias_or_name for window in inner_select.find_all(exp.Window)
    }
    inner_select_name = inner_select.parent.alias_or_name
    unmergable_ancestors = (exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)

    window_columns_in_unmergable = [
        column
        for column in outer_scope.columns
        if column.find_ancestor(*unmergable_ancestors)
        if column.table == inner_select_name and column.name in window_alias_names
    ]
    return any(window_columns_in_unmergable)
|
||||
|
||||
return (
|
||||
isinstance(outer_scope.expression, exp.Select)
|
||||
and isinstance(inner_select, exp.Select)
|
||||
|
@ -137,6 +154,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|||
and inner_select.args.get("where")
|
||||
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
|
||||
)
|
||||
and not _is_a_window_expression_in_unmergable_operation()
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
|
||||
from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
|
@ -23,6 +25,8 @@ RULES = (
|
|||
optimize_joins,
|
||||
eliminate_subqueries,
|
||||
merge_subqueries,
|
||||
eliminate_joins,
|
||||
eliminate_ctes,
|
||||
quote_identities,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.normalize import normalized
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.optimizer.scope import build_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
|
@ -22,15 +20,10 @@ def pushdown_predicates(expression):
|
|||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
scope_ref_count = defaultdict(lambda: 0)
|
||||
scopes = traverse_scope(expression)
|
||||
scopes.reverse()
|
||||
root = build_scope(expression)
|
||||
scope_ref_count = root.ref_count()
|
||||
|
||||
for scope in scopes:
|
||||
for _, source in scope.selected_sources.values():
|
||||
scope_ref_count[id(source)] += 1
|
||||
|
||||
for scope in scopes:
|
||||
for scope in reversed(list(root.traverse())):
|
||||
select = scope.expression
|
||||
where = select.args.get("where")
|
||||
if where:
|
||||
|
@ -152,9 +145,11 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
|
|||
return {}
|
||||
nodes[table] = node
|
||||
elif isinstance(node, exp.Select) and len(tables) == 1:
|
||||
# We can't push down window expressions
|
||||
has_window_expression = any(select for select in node.selects if select.find(exp.Window))
|
||||
# we can't push down predicates to select statements if they are referenced in
|
||||
# multiple places.
|
||||
if not node.args.get("group") and scope_ref_count[id(source)] < 2:
|
||||
if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
|
||||
nodes[table] = node
|
||||
return nodes
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import itertools
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlglot import exp
|
||||
|
@ -314,6 +315,16 @@ class Scope:
|
|||
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
|
||||
return self._external_columns
|
||||
|
||||
@property
def unqualified_columns(self):
    """
    Columns of the current scope that carry no table qualifier.

    Returns:
        list[exp.Column]: the columns whose `table` is empty
    """
    return [column for column in self.columns if not column.table]
|
||||
|
||||
@property
|
||||
def join_hints(self):
|
||||
"""
|
||||
|
@ -403,6 +414,21 @@ class Scope:
|
|||
yield from child_scope.traverse()
|
||||
yield self
|
||||
|
||||
def ref_count(self):
    """
    Tally how often each source is selected from across this scope tree.

    Every entry of `selected_sources` of every traversed scope adds one to
    the tally of its source. Sources that were never seen read as 0.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    selected = (
        source
        for scope in self.traverse()
        for _node, source in scope.selected_sources.values()
    )
    for source in selected:
        counts[id(source)] += 1

    return counts
|
||||
|
||||
|
||||
def traverse_scope(expression):
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue