162 lines
5.2 KiB
Python
162 lines
5.2 KiB
Python
from sqlglot import expressions as exp
|
|
from sqlglot.optimizer.normalize import normalized
|
|
from sqlglot.optimizer.scope import Scope, traverse_scope
|
|
from sqlglot.optimizer.simplify import simplify
|
|
|
|
|
|
def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize. Removed joins are
            popped out of the tree, so the expression is modified in place.
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        # `args.get("joins", [])` still yields None when the key exists with a None value,
        # which would make `reversed` below raise; normalize to an empty list.
        joins = scope.expression.args.get("joins") or []

        # Reverse the joins so we can remove chains of unused joins.
        # Iterate over a snapshot because popping a join mutates the parent's join list.
        for join in reversed(list(joins)):
            alias = join.this.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)

    return expression
|
|
|
|
|
|
def _should_eliminate_join(scope, join, alias):
    """
    Decide whether `join` (whose joined source is registered as `alias`) can be dropped.

    The join is removable only when all of these hold:
      * the joined source is itself a `Scope` (a derived table), not some other source;
      * no column of that source is referenced outside the join's own ON clause;
      * either it is a LEFT JOIN whose extracted join keys cover every unique output
        of the inner scope, or it has no ON clause and the inner scope produces a
        single output row.
    """
    inner = scope.sources.get(alias)

    if not isinstance(inner, Scope):
        return False

    if _join_is_used(scope, join, alias):
        return False

    if join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner, join):
        return True

    # NOTE(review): this branch does not look at the join side. For an inner/cross join,
    # a "single row" inner source that qualifies only via LIMIT 1 may in fact return zero
    # rows, and then removing the join would change the result -- confirm this is intended.
    return not join.args.get("on") and _has_single_output_row(inner)
|
|
|
|
|
|
def _join_is_used(scope, join, alias):
    """
    Return True if any column of source `alias` is referenced in `scope`
    somewhere other than the ON clause of `join` itself.

    Columns are matched by object identity, so only the exact column nodes that
    live inside this join's ON clause are ignored.
    """
    condition = join.args.get("on")
    ignored_ids = {id(col) for col in condition.find_all(exp.Column)} if condition else set()

    for col in scope.source_columns(alias):
        if id(col) not in ignored_ids and col:
            return True
    return False
|
|
|
|
|
|
def _is_joined_on_all_unique_outputs(scope, join):
    """
    Return True if the join keys on the inner side of `join` cover every output
    column of `scope` that is known to be unique per row.

    If `scope` has no known unique outputs, this is False.
    """
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    # Element 1 of join_condition's result is the list of inner-side join keys;
    # they are compared with the unique outputs by column name.
    key_names = {key.name for key in join_condition(join)[1]}
    return unique_outputs.issubset(key_names)
|
|
|
|
|
|
def _unique_outputs(scope):
    """
    Determine output columns of `scope` that must have a unique combination per row.

    Sources of uniqueness that are recognized:
      * plain DISTINCT (or a distinct set operation): every named select;
      * GROUP BY: the selects that are grouped expressions, provided every
        grouped expression appears in the output;
      * a scope with a single output row: every named select.

    Returns:
        set[str]: output names whose combined values are unique; empty if unknown.
    """
    distinct = scope.expression.args.get("distinct")
    # `DISTINCT ON (...)` only deduplicates on its ON expressions, so the selected columns
    # need not be unique (e.g. SELECT DISTINCT ON (c) b ...). Don't take the DISTINCT
    # shortcut in that case; fall through to the checks below. Set operations store a plain
    # boolean under "distinct", hence the isinstance guard before reading `.args`.
    distinct_on = isinstance(distinct, exp.Expression) and distinct.args.get("on")
    if distinct and not distinct_on:
        return set(scope.expression.named_selects)

    group = scope.expression.args.get("group")
    if group:
        grouped_expressions = set(group.expressions)
        grouped_outputs = set()
        unique_outputs = set()

        for select in scope.selects:
            output = select.unalias()
            if output in grouped_expressions:
                grouped_outputs.add(output)
                unique_outputs.add(select.alias_or_name)

        # All the grouped expressions must be in the output
        if not grouped_expressions.difference(grouped_outputs):
            return unique_outputs
        return set()

    if _has_single_output_row(scope):
        return set(scope.expression.named_selects)

    return set()
|
|
|
|
|
|
def _has_single_output_row(scope):
    """
    Return a truthy value if the SELECT in `scope` is expected to produce a single row.

    Only `exp.Select` scopes qualify. A Select counts when every projection is an
    aggregate function, when it has LIMIT 1, or when it has no FROM clause.
    """
    query = scope.expression
    if not isinstance(query, exp.Select):
        return False

    only_aggregates = all(
        isinstance(projection.unalias(), exp.AggFunc) for projection in scope.selects
    )
    return only_aggregates or _is_limit_1(scope) or not query.args.get("from")
|
|
|
|
|
|
def _is_limit_1(scope):
    """
    Return True if the query in `scope` has a limit whose count is the literal 1.

    The count is read from the limit node's "expression" argument. A limit-like node
    without that argument (NOTE(review): presumably e.g. a FETCH clause in some sqlglot
    versions -- verify) is conservatively treated as "not LIMIT 1" instead of raising.
    """
    limit = scope.expression.args.get("limit")
    if not limit:
        return False

    count = limit.args.get("expression")
    return count is not None and count.this == "1"
|
|
|
|
|
|
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Top-level equality conjuncts of a CNF-normalized ON clause that compare an
    expression over the joined table with an expression over other tables are
    split out as (source key, join key) pairs; everything else is returned as the
    remaining predicate. The join's own ON clause is not modified (a copy is used).

    Args:
        join (exp.Join)

    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression | None]:
            Tuple of (source key, join key, remaining predicate). The remaining
            predicate is None when only key equalities were present.
    """
    name = join.this.alias_or_name
    on = (join.args.get("on") or exp.TRUE).copy()
    source_key = []
    join_key = []

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # `condition.replace(...)` has no effect on a node without a parent, so a lone
        # predicate (e.g. ON x.a = y.b) would be reported both as a key and as the
        # remaining predicate. Wrapping it in an AND gives it a parent.
        if not isinstance(on, exp.And):
            on = exp.And(this=on, expression=exp.TRUE.copy())

        # Materialize the conjuncts first because nodes are replaced while walking.
        for condition in list(on.flatten()):
            if not isinstance(condition, exp.EQ):
                continue

            left, right = condition.unnest_operands()
            left_tables = exp.column_table_names(left)
            right_tables = exp.column_table_names(right)

            if name in left_tables and name not in right_tables:
                join_key.append(left)
                source_key.append(right)
                # Insert a fresh TRUE node; grafting the shared exp.TRUE singleton
                # into this tree would re-parent it.
                condition.replace(exp.TRUE.copy())
            elif name in right_tables and name not in left_tables:
                join_key.append(right)
                source_key.append(left)
                condition.replace(exp.TRUE.copy())

    on = simplify(on)
    remaining_condition = None if on == exp.TRUE else on

    return source_key, join_key, remaining_condition
|