320 lines
12 KiB
Python
320 lines
12 KiB
Python
from collections import defaultdict
|
|
|
|
from sqlglot import expressions as exp
|
|
from sqlglot.helper import find_new_name
|
|
from sqlglot.optimizer.scope import Scope, traverse_scope
|
|
from sqlglot.optimizer.simplify import simplify
|
|
|
|
|
|
def merge_subqueries(expression, leave_tables_isolated=False):
|
|
"""
|
|
Rewrite sqlglot AST to merge derived tables into the outer query.
|
|
|
|
This also merges CTEs if they are selected from only once.
|
|
|
|
Example:
|
|
>>> import sqlglot
|
|
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
|
|
>>> merge_subqueries(expression).sql()
|
|
'SELECT x.a FROM x JOIN y'
|
|
|
|
If `leave_tables_isolated` is True, this will not merge inner queries into outer
|
|
queries if it would result in multiple table selects in a single query:
|
|
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
|
|
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
|
|
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
|
|
|
|
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
|
|
|
|
Args:
|
|
expression (sqlglot.Expression): expression to optimize
|
|
leave_tables_isolated (bool):
|
|
Returns:
|
|
sqlglot.Expression: optimized expression
|
|
"""
|
|
expression = merge_ctes(expression, leave_tables_isolated)
|
|
expression = merge_derived_tables(expression, leave_tables_isolated)
|
|
return expression
|
|
|
|
|
|
# If a derived table has these Select args, it can't be merged
|
|
UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
|
|
"expressions",
|
|
"from",
|
|
"joins",
|
|
"where",
|
|
"order",
|
|
"hint",
|
|
}
|
|
|
|
|
|
def merge_ctes(expression, leave_tables_isolated=False):
|
|
scopes = traverse_scope(expression)
|
|
|
|
# All places where we select from CTEs.
|
|
# We key on the CTE scope so we can detect CTES that are selected from multiple times.
|
|
cte_selections = defaultdict(list)
|
|
for outer_scope in scopes:
|
|
for table, inner_scope in outer_scope.selected_sources.values():
|
|
if isinstance(inner_scope, Scope) and inner_scope.is_cte:
|
|
cte_selections[id(inner_scope)].append(
|
|
(
|
|
outer_scope,
|
|
inner_scope,
|
|
table,
|
|
)
|
|
)
|
|
|
|
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
|
|
for outer_scope, inner_scope, table in singular_cte_selections:
|
|
inner_select = inner_scope.expression.unnest()
|
|
from_or_join = table.find_ancestor(exp.From, exp.Join)
|
|
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|
node_to_replace = table
|
|
if isinstance(node_to_replace.parent, exp.Alias):
|
|
node_to_replace = node_to_replace.parent
|
|
alias = node_to_replace.alias
|
|
else:
|
|
alias = table.name
|
|
|
|
_rename_inner_sources(outer_scope, inner_scope, alias)
|
|
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
|
|
_merge_expressions(outer_scope, inner_scope, alias)
|
|
_merge_joins(outer_scope, inner_scope, from_or_join)
|
|
_merge_where(outer_scope, inner_scope, from_or_join)
|
|
_merge_order(outer_scope, inner_scope)
|
|
_merge_hints(outer_scope, inner_scope)
|
|
_pop_cte(inner_scope)
|
|
return expression
|
|
|
|
|
|
def merge_derived_tables(expression, leave_tables_isolated=False):
|
|
for outer_scope in traverse_scope(expression):
|
|
for subquery in outer_scope.derived_tables:
|
|
inner_select = subquery.unnest()
|
|
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
|
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|
alias = subquery.alias_or_name
|
|
inner_scope = outer_scope.sources[alias]
|
|
|
|
_rename_inner_sources(outer_scope, inner_scope, alias)
|
|
_merge_from(outer_scope, inner_scope, subquery, alias)
|
|
_merge_expressions(outer_scope, inner_scope, alias)
|
|
_merge_joins(outer_scope, inner_scope, from_or_join)
|
|
_merge_where(outer_scope, inner_scope, from_or_join)
|
|
_merge_order(outer_scope, inner_scope)
|
|
_merge_hints(outer_scope, inner_scope)
|
|
return expression
|
|
|
|
|
|
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
|
|
"""
|
|
Return True if `inner_select` can be merged into outer query.
|
|
|
|
Args:
|
|
outer_scope (Scope)
|
|
inner_select (exp.Select)
|
|
leave_tables_isolated (bool)
|
|
from_or_join (exp.From|exp.Join)
|
|
Returns:
|
|
bool: True if can be merged
|
|
"""
|
|
return (
|
|
isinstance(outer_scope.expression, exp.Select)
|
|
and isinstance(inner_select, exp.Select)
|
|
and isinstance(inner_select, exp.Select)
|
|
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
|
and inner_select.args.get("from")
|
|
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
|
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
|
|
and not (
|
|
isinstance(from_or_join, exp.Join)
|
|
and inner_select.args.get("where")
|
|
and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
|
|
)
|
|
and not (
|
|
isinstance(from_or_join, exp.From)
|
|
and inner_select.args.get("where")
|
|
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
|
|
)
|
|
)
|
|
|
|
|
|
def _rename_inner_sources(outer_scope, inner_scope, alias):
|
|
"""
|
|
Renames any sources in the inner query that conflict with names in the outer query.
|
|
|
|
Args:
|
|
outer_scope (sqlglot.optimizer.scope.Scope)
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
alias (str)
|
|
"""
|
|
taken = set(outer_scope.selected_sources)
|
|
conflicts = taken.intersection(set(inner_scope.selected_sources))
|
|
conflicts = conflicts - {alias}
|
|
|
|
for conflict in conflicts:
|
|
new_name = find_new_name(taken, conflict)
|
|
|
|
source, _ = inner_scope.selected_sources[conflict]
|
|
new_alias = exp.to_identifier(new_name)
|
|
|
|
if isinstance(source, exp.Subquery):
|
|
source.set("alias", exp.TableAlias(this=new_alias))
|
|
elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
|
|
source.parent.set("alias", new_alias)
|
|
elif isinstance(source, exp.Table):
|
|
source.replace(exp.alias_(source.copy(), new_alias))
|
|
|
|
for column in inner_scope.source_columns(conflict):
|
|
column.set("table", exp.to_identifier(new_name))
|
|
|
|
inner_scope.rename_source(conflict, new_name)
|
|
|
|
|
|
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
|
"""
|
|
Merge FROM clause of inner query into outer query.
|
|
|
|
Args:
|
|
outer_scope (sqlglot.optimizer.scope.Scope)
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
node_to_replace (exp.Subquery|exp.Table)
|
|
alias (str)
|
|
"""
|
|
new_subquery = inner_scope.expression.args.get("from").expressions[0]
|
|
node_to_replace.replace(new_subquery)
|
|
for join_hint in outer_scope.join_hints:
|
|
tables = join_hint.find_all(exp.Table)
|
|
for table in tables:
|
|
if table.alias_or_name == node_to_replace.alias_or_name:
|
|
new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
|
|
table.set("this", exp.to_identifier(new_table.alias_or_name))
|
|
outer_scope.remove_source(alias)
|
|
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
|
|
|
|
|
|
def _merge_joins(outer_scope, inner_scope, from_or_join):
|
|
"""
|
|
Merge JOIN clauses of inner query into outer query.
|
|
|
|
Args:
|
|
outer_scope (sqlglot.optimizer.scope.Scope)
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
from_or_join (exp.From|exp.Join)
|
|
"""
|
|
|
|
new_joins = []
|
|
comma_joins = inner_scope.expression.args.get("from").expressions[1:]
|
|
for subquery in comma_joins:
|
|
new_joins.append(exp.Join(this=subquery, kind="CROSS"))
|
|
outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
|
|
|
|
joins = inner_scope.expression.args.get("joins") or []
|
|
for join in joins:
|
|
new_joins.append(join)
|
|
outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
|
|
|
|
if new_joins:
|
|
outer_joins = outer_scope.expression.args.get("joins", [])
|
|
|
|
# Maintain the join order
|
|
if isinstance(from_or_join, exp.From):
|
|
position = 0
|
|
else:
|
|
position = outer_joins.index(from_or_join) + 1
|
|
outer_joins[position:position] = new_joins
|
|
|
|
outer_scope.expression.set("joins", outer_joins)
|
|
|
|
|
|
def _merge_expressions(outer_scope, inner_scope, alias):
|
|
"""
|
|
Merge projections of inner query into outer query.
|
|
|
|
Args:
|
|
outer_scope (sqlglot.optimizer.scope.Scope)
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
alias (str)
|
|
"""
|
|
# Collect all columns that reference the alias of the inner query
|
|
outer_columns = defaultdict(list)
|
|
for column in outer_scope.columns:
|
|
if column.table == alias:
|
|
outer_columns[column.name].append(column)
|
|
|
|
# Replace columns with the projection expression in the inner query
|
|
for expression in inner_scope.expression.expressions:
|
|
projection_name = expression.alias_or_name
|
|
if not projection_name:
|
|
continue
|
|
columns_to_replace = outer_columns.get(projection_name, [])
|
|
for column in columns_to_replace:
|
|
column.replace(expression.unalias().copy())
|
|
|
|
|
|
def _merge_where(outer_scope, inner_scope, from_or_join):
|
|
"""
|
|
Merge WHERE clause of inner query into outer query.
|
|
|
|
Args:
|
|
outer_scope (sqlglot.optimizer.scope.Scope)
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
from_or_join (exp.From|exp.Join)
|
|
"""
|
|
where = inner_scope.expression.args.get("where")
|
|
if not where or not where.this:
|
|
return
|
|
|
|
if isinstance(from_or_join, exp.Join):
|
|
# Merge predicates from an outer join to the ON clause
|
|
from_or_join.on(where.this, copy=False)
|
|
from_or_join.set("on", simplify(from_or_join.args.get("on")))
|
|
else:
|
|
outer_scope.expression.where(where.this, copy=False)
|
|
outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
|
|
|
|
|
|
def _merge_order(outer_scope, inner_scope):
|
|
"""
|
|
Merge ORDER clause of inner query into outer query.
|
|
|
|
Args:
|
|
outer_scope (sqlglot.optimizer.scope.Scope)
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
"""
|
|
if (
|
|
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
|
|
or len(outer_scope.selected_sources) != 1
|
|
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
|
|
):
|
|
return
|
|
|
|
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
|
|
|
|
|
|
def _merge_hints(outer_scope, inner_scope):
|
|
inner_scope_hint = inner_scope.expression.args.get("hint")
|
|
if not inner_scope_hint:
|
|
return
|
|
outer_scope_hint = outer_scope.expression.args.get("hint")
|
|
if outer_scope_hint:
|
|
for hint_expression in inner_scope_hint.expressions:
|
|
outer_scope_hint.append("expressions", hint_expression)
|
|
else:
|
|
outer_scope.expression.set("hint", inner_scope_hint)
|
|
|
|
|
|
def _pop_cte(inner_scope):
|
|
"""
|
|
Remove CTE from the AST.
|
|
|
|
Args:
|
|
inner_scope (sqlglot.optimizer.scope.Scope)
|
|
"""
|
|
cte = inner_scope.expression.parent
|
|
with_ = cte.parent
|
|
if len(with_.expressions) == 1:
|
|
with_.pop()
|
|
else:
|
|
cte.pop()
|