2025-02-13 21:56:02 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
import itertools
|
2025-02-13 21:56:02 +01:00
|
|
|
import typing as t
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 14:40:43 +01:00
|
|
|
from sqlglot import expressions as exp
|
|
|
|
from sqlglot.helper import find_new_name
|
2025-02-13 21:56:02 +01:00
|
|
|
from sqlglot.optimizer.scope import Scope, build_scope
|
|
|
|
|
|
|
|
if t.TYPE_CHECKING:
    # Names that are already claimed somewhere in the query, keyed by name and
    # mapped to the scope or expression that currently owns that name.
    TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]

    # Query expressions that already back a CTE, mapped to that CTE's alias.
    ExistingCTEsMapping = t.Dict[exp.Expression, str]
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 21:56:02 +01:00
|
|
|
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression: the query to rewrite. It is modified in place.

    Returns:
        The same `expression` object that was passed in.
    """

    def _extract_ctes(scopes):
        # Lazily run `_eliminate` over each scope, yielding only the scopes that
        # actually produced a new CTE. Kept lazy so that the rewrite of one scope
        # happens before the traversal advances to the next one.
        for candidate in scopes:
            extracted = _eliminate(candidate, existing_ctes, taken)
            if extracted:
                yield extracted

    # A wrapped query can sit at the root, e.g. (SELECT * FROM x) LIMIT 1.
    # Rewrite the query it wraps and return the wrapper itself.
    if isinstance(expression, exp.Subquery):
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # alias -> Scope | Table
    # Every name already in use in the expression; newly created CTEs must not
    # collide with any of them.
    taken: TakenNameMapping = {}

    # Aliases of the CTEs declared in the root scope are in use...
    for declared in root.cte_scopes:
        taken[declared.expression.parent.alias] = declared

    # ...and so is the name of every table referenced in any scope.
    for visited in root.traverse():
        for source in visited.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Expression -> alias
    # CTEs that already exist on the root expression, used for deduplication.
    existing_ctes: ExistingCTEsMapping = {}
    recursive = False

    with_ = root.expression.args.get("with")
    if with_:
        recursive = with_.args.get("recursive")
        existing_ctes.update((cte.this, cte.alias) for cte in with_.expressions)

    new_ctes = []

    # New CTEs are being added, but the DAG order has to be maintained: anything
    # extracted from inside an existing CTE must be listed before that CTE.
    for cte_scope in root.cte_scopes:
        # The CTE's own scope is skipped; only what is nested in it is extracted.
        new_ctes.extend(
            _extract_ctes(inner for inner in cte_scope.traverse() if inner is not cte_scope)
        )
        # Then the existing CTE itself.
        new_ctes.append(cte_scope.expression.parent)

    # Finally, everything outside of the existing CTEs.
    for outer in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        new_ctes.extend(_extract_ctes(outer.traverse()))

    if not new_ctes:
        return expression

    # For DDL the WITH clause is attached to the DDL's inner `expression`.
    if isinstance(expression, exp.DDL):
        query = expression.expression
    else:
        query = expression
    query.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression
|
2025-02-13 14:40:43 +01:00
|
|
|
|
|
|
|
|
2025-02-13 21:56:02 +01:00
|
|
|
def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Dispatch `scope` to the eliminator that matches its kind.

    Derived tables are checked first, then CTEs; any other kind of scope is
    left alone.

    Returns:
        Whatever the chosen eliminator returns, or None when the scope is
        neither a derived table nor a CTE.
    """
    if scope.is_derived_table:
        handler = _eliminate_derived_table
    elif scope.is_cte:
        handler = _eliminate_cte
    else:
        return None

    return handler(scope, existing_ctes, taken)
|
|
|
|
|
2025-02-13 14:40:43 +01:00
|
|
|
|
2025-02-13 21:56:02 +01:00
|
|
|
def _eliminate_derived_table(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Replace a derived table with a reference to a CTE.

    Returns:
        The CTE to add to the root query, or None when nothing has to be added
        (the derived table was left untouched, or it duplicates a known CTE).
    """
    enclosing = scope.parent

    # Leave the derived table alone when rewriting it would:
    # - drop the "pivot" arg from a pivoted subquery
    if enclosing.pivots:
        return None
    # - eliminate a lateral correlated subquery
    if isinstance(enclosing.expression, exp.Lateral):
        return None

    # Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers.
    # NOTE(review): this is resolved before `_new_cte` runs because `_new_cte` hands
    # `scope.expression` to a new exp.CTE, which presumably re-parents it - confirm.
    derived_table = scope.expression.parent.unwrap()

    name, cte = _new_cte(scope, existing_ctes, taken)

    # The reference keeps the derived table's alias (falling back to the CTE
    # name) and carries over whatever joins were attached to it.
    reference = exp.alias_(exp.table_(name), alias=derived_table.alias or name)
    reference.set("joins", derived_table.args.get("joins"))
    derived_table.replace(reference)

    return cte
|
|
|
|
|
|
|
|
|
2025-02-13 21:56:02 +01:00
|
|
|
def _eliminate_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Detach a CTE from its WITH clause and re-point its references.

    The CTE node is popped from its WITH clause (and the WITH clause is popped
    too if that leaves it empty). Tables in the parent scope's traversal whose
    source is this scope are then replaced by references to the new name.

    Returns:
        The CTE to add to the root query, or None if it duplicates a known CTE.
    """
    old_cte = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    with_ = old_cte.parent
    old_cte.pop()
    if not with_.expressions:
        # Nothing left in this WITH clause, so remove the clause as well
        with_.pop()

    # Rename references to this CTE
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is not scope:
                continue
            table.replace(exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False))

    return cte
|
|
|
|
|
|
|
|
|
2025-02-13 21:56:02 +01:00
|
|
|
def _new_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
    """
    Pick the name for this scope's CTE and build the CTE if it is not a duplicate.

    Both mappings are updated: the chosen name is recorded in `taken`, and a
    query that was not seen before is recorded in `existing_ctes`.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    query = scope.expression
    duplicate_cte_alias = existing_ctes.get(query)

    # Start from the alias the query already has; make one up if it has none.
    name = query.parent.alias or find_new_name(taken=taken, base="cte")

    if duplicate_cte_alias:
        # An identical query is already a CTE: reuse its alias, build nothing.
        taken[duplicate_cte_alias] = scope
        return duplicate_cte_alias, None

    if taken.get(name):
        # The preferred name is in use, so derive a fresh one from it.
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope
    existing_ctes[query] = name

    cte = exp.CTE(
        this=query,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
    return name, cte
|