1
0
Fork 0
sqlglot/sqlglot/planner.py
Daniel Baumann 768d386bf5
Adding upstream version 7.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 14:46:14 +01:00

301 lines
8.3 KiB
Python

import itertools
import math
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
from sqlglot.optimizer.eliminate_joins import join_condition
class Plan:
    """Execution plan for a SQL expression, exposed as a DAG of steps.

    Attributes:
        expression: The expression the plan was built from.
        root: The final step, as returned by ``Step.from_expression``.
    """

    def __init__(self, expression):
        self.expression = expression
        self.root = Step.from_expression(self.expression)
        # Filled in lazily by the ``dag`` property.
        self._dag = {}

    @property
    def dag(self):
        """Map every step reachable from ``root`` to the set of its dependencies.

        The mapping is built on first access and cached on the instance.
        """
        if not self._dag:
            dag = {}
            pending = {self.root}

            while pending:
                node = pending.pop()
                # A step can be reached through several dependents (diamond-shaped
                # plans); expand it only once instead of re-walking its subtree.
                if node in dag:
                    continue
                dependencies = set(node.dependencies)
                dag[node] = dependencies
                pending.update(dependencies)

            self._dag = dag

        return self._dag

    @property
    def leaves(self):
        """Return a generator over the steps that have no dependencies."""
        return (node for node, deps in self.dag.items() if not deps)
class Step:
    """Base node of the planner DAG.

    Attributes:
        name: Label of the step; ``None`` until the code building the plan sets it.
        dependencies: Steps registered through ``add_dependency``.
        dependents: Steps that registered this step as one of their dependencies.
        projections: Expressions listed under "Projections" when the step is printed.
        limit: ``math.inf`` by default, or the integer LIMIT of the planned query.
        condition: Filter expression taken from WHERE or HAVING, or ``None``.
    """

    @classmethod
    def from_expression(cls, expression, ctes=None):
        """
        Build a DAG of Steps from a SQL expression.

        Giving an expression like:

        SELECT x.a, SUM(x.b)
        FROM x
        JOIN y
          ON x.a = y.a
        GROUP BY x.a

        Transform it into a DAG of the form:

        Aggregate(x.a, SUM(x.b))
          Join(y)
            Scan(x)
            Scan(y)

        This can then more easily be executed on by an engine.

        Args:
            expression: The select expression to plan.
            ctes: Optional mapping of CTE name to the Step already planned for it.

        Returns:
            The last Step of the chain (the root of the DAG for this expression).

        Raises:
            UnsupportedError: If the expression has no FROM clause or has more than
                one FROM source.
        """
        ctes = ctes or {}
        with_ = expression.args.get("with")

        # CTEs break the mold of scope and introduce themselves to all in the context.
        if with_:
            # Copy so CTEs declared here do not leak into the caller's mapping.
            ctes = ctes.copy()
            for cte in with_.expressions:
                step = Step.from_expression(cte.this, ctes)
                step.name = cte.alias
                ctes[step.name] = step

        from_ = expression.args.get("from")

        if from_:
            from_ = from_.expressions
            if len(from_) > 1:
                raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")

            step = Scan.from_expression(from_[0], ctes)
        else:
            raise UnsupportedError("Static selects are unsupported.")

        joins = expression.args.get("joins")

        if joins:
            join = Join.from_joins(joins, ctes)
            join.name = step.name
            join.add_dependency(step)
            step = join

        projections = []  # final selects in this chain of steps representing a select
        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
        aggregations = []
        sequence = itertools.count()

        for e in expression.expressions:
            aggregation = e.find(exp.AggFunc)

            if aggregation:
                # The final projection only references the aggregate's output by name.
                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
                aggregations.append(e)
                for operand in aggregation.unnest_operands():
                    if isinstance(operand, exp.Column):
                        continue
                    # Non-column operands get a generated `_a_N` name and are swapped,
                    # in place, for a column reference to that name.
                    if operand not in operands:
                        operands[operand] = f"_a_{next(sequence)}"
                    operand.replace(exp.column(operands[operand], step.name, quoted=True))
            else:
                projections.append(e)

        where = expression.args.get("where")

        if where:
            step.condition = where.this

        group = expression.args.get("group")

        # Aggregate functions without GROUP BY still need an Aggregate step: the
        # projections built above already point at its outputs, so skipping it would
        # silently drop the collected aggregations.
        if group or aggregations:
            aggregate = Aggregate()
            aggregate.source = step.name
            aggregate.name = step.name
            aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
            aggregate.aggregations = aggregations
            aggregate.group = [
                exp.column(e.alias_or_name, step.name, quoted=True)
                for e in (group.expressions if group else [])
            ]
            aggregate.add_dependency(step)
            step = aggregate

            having = expression.args.get("having")

            if having:
                step.condition = having.this

        order = expression.args.get("order")

        if order:
            sort = Sort()
            sort.name = step.name
            sort.key = order.expressions
            sort.add_dependency(step)
            step = sort
            # Point every column of the sort keys and projections at the sort step.
            for k in sort.key + projections:
                for column in k.find_all(exp.Column):
                    column.set("table", exp.to_identifier(step.name, quoted=True))

        step.projections = projections

        limit = expression.args.get("limit")

        if limit:
            step.limit = int(limit.text("expression"))

        return step

    def __init__(self):
        self.name = None
        self.dependencies = set()
        self.dependents = set()
        self.projections = []
        self.limit = math.inf
        self.condition = None

    def add_dependency(self, dependency):
        """Link ``dependency`` to this step in both directions."""
        self.dependencies.add(dependency)
        dependency.dependents.add(self)

    def __repr__(self):
        return self.to_s()

    def to_s(self, level=0):
        """Render this step and, recursively, its dependencies as indented text.

        Args:
            level: Nesting depth, used to compute the indentation prefix.

        Returns:
            The multi-line string description.
        """
        indent = " " * level
        nested = f"{indent} "

        # Subclasses contribute their own lines through `_to_s`.
        context = self._to_s(f"{nested} ")

        if context:
            context = [f"{nested}Context:"] + context

        lines = [
            f"{indent}- {self.__class__.__name__}: {self.name}",
            *context,
            f"{nested}Projections:",
        ]

        for expression in self.projections:
            lines.append(f"{nested} - {expression.sql()}")

        if self.condition:
            lines.append(f"{nested}Condition: {self.condition.sql()}")

        if self.dependencies:
            lines.append(f"{nested}Dependencies:")
            for dependency in self.dependencies:
                lines.append(" " + dependency.to_s(level + 1))

        return "\n".join(lines)

    def _to_s(self, _indent):
        """Return the subclass-specific lines for `to_s`; the base step has none."""
        return []
class Scan(Step):
    """Step that reads from a single aliased source.

    Attributes:
        source: The source expression the step was built from, or ``None``.
    """

    @classmethod
    def from_expression(cls, expression, ctes=None):
        """Build the step for one aliased FROM/JOIN source.

        Args:
            expression: The aliased source expression.
            ctes: Optional mapping of CTE name to the Step planned for it.

        Returns:
            For an ``exp.Subquery``, the Step planned for its inner expression,
            renamed to the alias. Otherwise a new Scan named after the alias, which
            depends on the matching CTE step when the table name is found in ``ctes``.

        Raises:
            UnsupportedError: If the source has no alias.
        """
        # `ctes` defaults to None; normalise it so the membership test below cannot
        # raise TypeError when the caller omits it.
        ctes = {} if ctes is None else ctes

        table = expression.this
        alias_ = expression.alias

        if not alias_:
            raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")

        if isinstance(expression, exp.Subquery):
            step = Step.from_expression(table, ctes)
            step.name = alias_
            return step

        step = Scan()
        step.name = alias_
        step.source = expression

        if table.name in ctes:
            step.add_dependency(ctes[table.name])

        return step

    def __init__(self):
        super().__init__()
        self.source = None

    def _to_s(self, indent):
        """Return the source line shown in the step's printed context."""
        return [f"{indent}Source: {self.source.sql()}"]
class Write(Step):
    """Step subclass that adds no state or behavior of its own.

    NOTE(review): nothing in the visible code instantiates it; presumably a
    placeholder for write/output steps -- confirm before relying on it.
    """

    pass
class Join(Step):
    """Step that combines its dependencies according to a set of join descriptions.

    Attributes:
        joins: Mapping of joined source name to a dict with the keys ``side``,
            ``join_key``, ``source_key`` and ``condition``.
    """

    @classmethod
    def from_joins(cls, joins, ctes=None):
        """Build a single Join step from a list of join expressions.

        Args:
            joins: The join expressions of a select.
            ctes: Optional mapping of CTE name to the Step planned for it.

        Returns:
            A Join step with one entry in ``joins`` and one dependency per join.
        """
        # `ctes` defaults to None but is consulted with `in` further down the call
        # chain; pass an empty mapping instead so the default is actually usable.
        ctes = {} if ctes is None else ctes

        step = Join()

        for join in joins:
            source_key, join_key, condition = join_condition(join)
            step.joins[join.this.alias_or_name] = {
                "side": join.side,
                "join_key": join_key,
                "source_key": source_key,
                "condition": condition,
            }

            step.add_dependency(Scan.from_expression(join.this, ctes))

        return step

    def __init__(self):
        super().__init__()
        self.joins = {}

    def _to_s(self, indent):
        """Return one line per join (name and side), plus its ON condition if any."""
        lines = []
        for name, join in self.joins.items():
            lines.append(f"{indent}{name}: {join['side']}")
            if join.get("condition"):
                lines.append(f"{indent}On: {join['condition'].sql()}")
        return lines
class Aggregate(Step):
    """Step holding aggregate expressions, their group keys and their operands.

    Attributes:
        aggregations: Expressions listed under "Aggregations" when printed.
        operands: Expressions listed under "Operands" when printed.
        group: Expressions listed under "Group" when printed.
        source: ``None`` until assigned by the code building the plan.
    """

    def __init__(self):
        super().__init__()
        self.source = None
        self.group = []
        self.operands = []
        self.aggregations = []

    def _to_s(self, indent):
        """Return the context lines: aggregations, then group and operands if present."""

        def bullets(expressions):
            # One " - <sql>" line per expression, at the current indentation.
            return [f"{indent} - {expression.sql()}" for expression in expressions]

        lines = [f"{indent}Aggregations:", *bullets(self.aggregations)]

        if self.group:
            lines += [f"{indent}Group:", *bullets(self.group)]

        if self.operands:
            lines += [f"{indent}Operands:", *bullets(self.operands)]

        return lines
class Sort(Step):
    """Step that carries a list of sort key expressions.

    Attributes:
        key: The sort key expressions; ``None`` until assigned.
    """

    def __init__(self):
        super().__init__()
        self.key = None

    def _to_s(self, indent):
        """Return a "Key:" header followed by one line per sort key expression."""
        header = f"{indent}Key:"
        return [header] + [f"{indent} - {expression.sql()}" for expression in self.key]