Adding upstream version 6.0.4.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
d01130b3f1
commit
527597d2af
122 changed files with 23162 additions and 0 deletions
2
sqlglot/optimizer/__init__.py
Normal file
2
sqlglot/optimizer/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from sqlglot.optimizer.optimizer import optimize
|
||||
from sqlglot.optimizer.schema import Schema
|
48
sqlglot/optimizer/eliminate_subqueries.py
Normal file
48
sqlglot/optimizer/eliminate_subqueries.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import alias, exp, select, table
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def eliminate_subqueries(expression):
    """
    Rewrite duplicate subqueries from sqlglot AST.

    Queries that compare equal are hoisted into one CTE named ``_e_<n>`` and
    every occurrence is replaced by a reference to that CTE.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'

    Args:
        expression (sqlglot.Expression): expression to rewrite
    Returns:
        sqlglot.Expression: expression with duplicated queries replaced by CTE references
    """
    expression = simplify(expression)

    # Group the scope expressions by equality; the first one seen becomes the key.
    occurrences = {}
    for scope in traverse_scope(expression):
        occurrences.setdefault(scope.expression, []).append(scope.expression)

    counter = itertools.count()

    for query, matches in occurrences.items():
        # A query that appears only once has nothing to share.
        if len(matches) == 1:
            continue

        cte_name = f"_e_{next(counter)}"

        for match in matches:
            container = match.parent
            if isinstance(container, exp.Subquery):
                reference = alias(table(cte_name), container.alias_or_name, table=True)
                container.replace(reference)
            elif isinstance(container, exp.Union):
                match.replace(select("*").from_(cte_name))

        expression.with_(cte_name, as_=query, copy=False)

    return expression
|
16
sqlglot/optimizer/expand_multi_table_selects.py
Normal file
16
sqlglot/optimizer/expand_multi_table_selects.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from sqlglot import exp
|
||||
|
||||
|
||||
def expand_multi_table_selects(expression):
    """
    Turn the extra sources of a multi-source FROM clause into CROSS joins.

    For every FROM clause, each source after the first is joined onto the
    parent of the FROM clause with ``join_type="CROSS"`` and then removed from
    the FROM clause's expression list.

    Args:
        expression (sqlglot.Expression): expression to rewrite in place
    Returns:
        sqlglot.Expression: the same expression
    """
    for from_ in expression.find_all(exp.From):
        owner = from_.parent
        # Work on a slice so the underlying list can be mutated while looping.
        extras = from_.expressions[1:]

        for extra in extras:
            owner.join(extra, join_type="CROSS", copy=False)
            from_.expressions.remove(extra)

    return expression
|
31
sqlglot/optimizer/isolate_table_selects.py
Normal file
31
sqlglot/optimizer/isolate_table_selects.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from sqlglot import alias, exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
|
||||
def isolate_table_selects(expression):
    """
    Wrap each aliased table in its own ``SELECT *`` subquery.

    Scopes with exactly one selected source are left untouched. In every other
    scope, each table source is replaced by a subquery that selects everything
    from that table and carries the original alias.

    Args:
        expression (sqlglot.Expression): expression to rewrite in place
    Returns:
        sqlglot.Expression: the same expression
    Raises:
        OptimizeError: if a table source is not wrapped in an alias
    """
    for scope in traverse_scope(expression):
        sources = scope.selected_sources
        if len(sources) == 1:
            continue

        for _, source in sources.values():
            if not isinstance(source, exp.Table):
                continue

            aliased = source.parent
            if not isinstance(aliased, exp.Alias):
                raise OptimizeError(
                    "Tables require an alias. Run qualify_tables optimization."
                )

            inner_table = alias(source, source.name or aliased.alias, table=True)
            isolated = exp.select("*").from_(inner_table, copy=False)
            aliased.replace(isolated.subquery(aliased.alias, copy=False))

    return expression
|
136
sqlglot/optimizer/normalize.py
Normal file
136
sqlglot/optimizer/normalize.py
Normal file
|
@ -0,0 +1,136 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import while_changing
|
||||
from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
|
||||
|
||||
|
||||
def normalize(expression, dnf=False, max_distance=128):
    """
    Rewrite sqlglot AST into conjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression (sqlglot.Expression): expression to normalize
        dnf (bool): rewrite in disjunctive normal form instead
        max_distance (int): the maximal estimated distance from cnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """

    def apply_distribution(node):
        return distributive_law(node, dnf, max_distance)

    simplified = simplify(expression)
    # Keep distributing until `while_changing` reports no further change.
    converged = while_changing(simplified, apply_distribution)
    return simplify(converged)
|
||||
|
||||
|
||||
def normalized(expression, dnf=False):
    """
    Tell whether `expression` is already in normal form.

    For CNF (the default) no AND connector may have an OR ancestor; for DNF
    (`dnf=True`) no OR connector may have an AND ancestor.

    Args:
        expression (sqlglot.Expression): expression to inspect
        dnf (bool): check for disjunctive normal form instead
    Returns:
        bool: True if no offending nesting is found
    """
    if dnf:
        outer, inner = exp.And, exp.Or
    else:
        outer, inner = exp.Or, exp.And

    for connector in expression.find_all(inner):
        if connector.find_ancestor(outer):
            return False
    return True
|
||||
|
||||
|
||||
def normalization_distance(expression, dnf=False):
    """
    The difference in the number of predicates between the current expression and the normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute to dnf distance instead
    Returns:
        int: difference
    """
    expanded_size = sum(_predicate_lengths(expression, dnf))
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    # `connector_count + 1` is the figure the expanded size is compared against.
    return expanded_size - (connector_count + 1)
|
||||
|
||||
|
||||
def _predicate_lengths(expression, dnf):
    """
    Returns a list of predicate lengths when expanded to normalized form.

    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
    """
    expression = expression.unnest()

    # Anything that is not a connector counts as a single predicate.
    if not isinstance(expression, exp.Connector):
        return [1]

    left, right = expression.args.values()
    left_lengths = _predicate_lengths(left, dnf)
    right_lengths = _predicate_lengths(right, dnf)

    distributing = exp.And if dnf else exp.Or
    if isinstance(expression, distributing):
        # Distribution pairs every left block with every right block.
        combined = []
        for a in left_lengths:
            for b in right_lengths:
                combined.append(a + b)
        return combined

    return left_lengths + right_lengths
|
||||
|
||||
|
||||
def distributive_law(expression, dnf, max_distance):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression (sqlglot.Expression): expression to rewrite
        dnf (bool): distribute AND over OR instead of OR over AND
        max_distance (int): connector expressions whose `normalization_distance`
            exceeds this value are returned unchanged
    Returns:
        sqlglot.Expression: the distributed expression, or `expression` when
            no distribution applies
    """
    # Skip the rewrite when the estimated cost is above the allowed maximum.
    if isinstance(expression.unnest(), exp.Connector):
        if normalization_distance(expression, dnf) > max_distance:
            return expression

    # `from_exp` is the connector that gets distributed over `to_exp` operands.
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    # Apply the law to the children before looking at this node.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        # Builder functions matching the two connector types.
        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both operands qualify: the one with more connectors is passed first.
            if len(tuple(a.find_all(exp.Connector))) > len(
                tuple(b.find_all(exp.Connector))
            ):
                return _distribute(a, b, from_func, to_func)
            return _distribute(b, a, from_func, to_func)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func)

    return expression
|
||||
|
||||
|
||||
def _distribute(a, b, from_func, to_func):
    """
    Combine `a` with both sides of `b` and simplify the result.

    Args:
        a (sqlglot.Expression): operand combined with each side of `b`
        b (sqlglot.Expression): connector whose `left` and `right` are used
        from_func (callable): builder for the connector being distributed
        to_func (callable): builder for the connector joining the two results
    Returns:
        sqlglot.Expression: the result of `_simplify` on the combined expression
    """
    if not isinstance(a, exp.Connector):
        a = to_func(from_func(a, b.left), from_func(a, b.right))
        return _simplify(a)

    def spread(child):
        return to_func(
            exp.paren(from_func(child, b.left)),
            exp.paren(from_func(child, b.right)),
        )

    # `a` is a connector: rewrite each of its children in place.
    exp.replace_children(a, spread)
    return _simplify(a)
|
||||
|
||||
|
||||
def _simplify(node):
    """Apply `flatten` and `uniq_sort` to `node`, then recursively to its children."""
    tidy = uniq_sort(flatten(node))
    exp.replace_children(tidy, _simplify)
    return tidy
|
75
sqlglot/optimizer/optimize_joins.py
Normal file
75
sqlglot/optimizer/optimize_joins.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import tsort
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorder joins based on predicate dependencies.

    Args:
        expression (sqlglot.Expression): expression to rewrite in place
    Returns:
        sqlglot.Expression: the expression after `reorder_joins` and `normalize`
    """
    for select in expression.find_all(exp.Select):
        # Table name -> joins whose ON condition mentions that table.
        dependents = {}
        # Joins whose ON condition references no other table.
        cross_joins = []

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            referenced = other_table_names(join, name)

            if not referenced:
                cross_joins.append((name, join))
                continue

            for table in referenced:
                dependents.setdefault(table, []).append(join)

        for name, join in cross_joins:
            for dep in dependents.get(name, []):
                on = dep.args["on"]
                on = on.replace(simplify(on))

                if not isinstance(on, exp.Connector):
                    continue

                # Move predicates that mention the cross-joined table onto that join.
                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.TRUE)
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
|
||||
|
||||
|
||||
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    Args:
        expression (sqlglot.Expression): expression to rewrite in place
    Returns:
        sqlglot.Expression: the same expression
    """
    for from_ in expression.find_all(exp.From):
        head_name = from_.expressions[0].alias_or_name
        select = from_.parent

        joins_by_name = {}
        for join in select.args.get("joins", []):
            joins_by_name[join.this.alias_or_name] = join

        # The first FROM source has no dependencies; each join depends on the
        # other tables named in its ON condition.
        dag = {head_name: []}
        for name, join in joins_by_name.items():
            dag[name] = other_table_names(join, name)

        ordered = [joins_by_name[name] for name in tsort(dag) if name != head_name]
        select.set("joins", ordered)

    return expression
|
||||
|
||||
|
||||
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Every join whose kind is not "CROSS" has its kind cleared.
    """
    non_cross = (join for join in expression.find_all(exp.Join) if join.kind != "CROSS")
    for join in non_cross:
        join.set("kind", None)
    return expression
|
||||
|
||||
|
||||
def other_table_names(join, exclude):
    """
    List the table names referenced by the join's ON condition, except `exclude`.

    A join without an ON condition is treated as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.TRUE
    referenced = exp.column_table_names(condition)
    return [table_name for table_name in referenced if table_name != exclude]
|
43
sqlglot/optimizer/optimizer.py
Normal file
43
sqlglot/optimizer/optimizer.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.normalize import normalize
|
||||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
from sqlglot.optimizer.quote_identities import quote_identities
|
||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
|
||||
|
||||
def optimize(expression, schema=None, db=None, catalog=None):
    """
    Rewrite a sqlglot AST into an optimized form.

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema (dict|sqlglot.optimizer.Schema): database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
        db (str): specify the default database, as might be set by a `USE DATABASE db` statement
        catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Rules are applied in this exact order, each one receiving the previous result.
    steps = (
        lambda tree: qualify_tables(tree, db=db, catalog=catalog),
        isolate_table_selects,
        lambda tree: qualify_columns(tree, schema),
        pushdown_projections,
        normalize,
        unnest_subqueries,
        expand_multi_table_selects,
        pushdown_predicates,
        optimize_joins,
        eliminate_subqueries,
        quote_identities,
    )

    # Work on a copy so the caller's expression is not modified.
    optimized = expression.copy()
    for step in steps:
        optimized = step(optimized)
    return optimized
|
176
sqlglot/optimizer/pushdown_predicates.py
Normal file
176
sqlglot/optimizer/pushdown_predicates.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.optimizer.normalize import normalized
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    scopes = traverse_scope(expression)

    for scope in reversed(scopes):
        query = scope.expression

        where = query.args.get("where")
        if where:
            pushdown(where.this, scope.selected_sources)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        joins = query.args.get("joins") or []
        for join in joins:
            join_name = join.this.alias_or_name
            own_source = {join_name: scope.selected_sources[join_name]}
            pushdown(join.args.get("on"), own_source)

    return expression
|
||||
|
||||
|
||||
def pushdown(condition, sources):
    """
    Simplify `condition` and push its predicates down into `sources`.

    The condition is handled by `pushdown_cnf` when it is normalized as CNF or
    is not normalized as DNF, and by `pushdown_dnf` otherwise.

    Args:
        condition (sqlglot.Expression): condition to push down; falsy values are ignored
        sources (dict): mapping of source name to the pair looked up by `nodes_for_predicate`
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))

    if normalized(condition) or not normalized(condition, dnf=True):
        connector, push = exp.And, pushdown_cnf
    else:
        connector, push = exp.Or, pushdown_dnf

    # Only a top-level connector of the matching kind is split into predicates.
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    push(predicates, sources)
|
||||
|
||||
|
||||
def pushdown_cnf(predicates, scope):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each pushed predicate is replaced by TRUE where it stood and added to the
    ON clause of a join or the WHERE clause of a select.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope)

        for target in targets.values():
            if isinstance(target, exp.Join):
                predicate.replace(exp.TRUE)
                target.on(predicate, copy=False)
                # A predicate moved into a join is not pushed anywhere else.
                break

            if isinstance(target, exp.Select):
                predicate.replace(exp.TRUE)
                target.where(replace_aliases(target, predicate), copy=False)
|
||||
|
||||
|
||||
def pushdown_dnf(predicates, scope):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates (list): the blocks of the DNF condition
        scope (dict): mapping of source name to the pair looked up by
            `nodes_for_predicate` (`pushdown` passes its `sources` argument here)
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set()

    for a in predicates:
        a_tables = set(exp.column_table_names(a))

        # keep only the tables that every other block references as well
        for b in predicates:
            a_tables &= set(exp.column_table_names(b))

        pushdown_tables.update(a_tables)

    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # combine them with ORS
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope)

            if table not in nodes:
                continue

            predicate_condition = None

            # AND together the conditions of this block that involve `table`
            for column in predicate.find_all(exp.Column):
                if column.table == table:
                    condition = column.find_ancestor(exp.Condition)
                    predicate_condition = (
                        exp.and_(predicate_condition, condition)
                        if predicate_condition
                        else condition
                    )

            # OR the per-block conditions together for this table
            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )

        # `nodes` holds the lookup made for the last predicate of the loop above
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                node.where(replace_aliases(node, predicate), copy=False)
|
||||
|
||||
|
||||
def nodes_for_predicate(predicate, sources):
    """
    Find, for each table referenced by `predicate`, the node it can be pushed to.

    Args:
        predicate (sqlglot.Expression): predicate to inspect
        sources (dict): mapping of source name to a ``(node, source)`` pair
    Returns:
        dict: table name -> join or select node; empty as soon as one of the
            joins has a side
    """
    tables = exp.column_table_names(predicate)
    enclosing = predicate.find_ancestor(exp.Join, exp.Where)
    in_where = isinstance(enclosing, exp.Where)

    found = {}

    for table in tables:
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side:
                return {}
            found[table] = node
        elif (
            isinstance(node, exp.Select)
            and len(tables) == 1
            and not node.args.get("group")
        ):
            found[table] = node

    return found
|
||||
|
||||
|
||||
def replace_aliases(source, predicate):
    """
    Rewrite the columns of `predicate` using the projections of `source`.

    A column whose name matches a projection of `source` is replaced by the
    aliased expression, or by the projection itself when it has no alias.

    Args:
        source (sqlglot.Expression): select whose projections provide the mapping
        predicate (sqlglot.Expression): predicate to transform
    Returns:
        sqlglot.Expression: the result of `predicate.transform`
    """
    aliases = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            key, target = projection.alias, projection.this
        else:
            key, target = projection.name, projection
        aliases[key] = target

    def _resolve(node):
        if not isinstance(node, exp.Column):
            return node
        return aliases.get(node.name, node)

    return predicate.transform(_resolve)
|
85
sqlglot/optimizer/pushdown_projections.py
Normal file
85
sqlglot/optimizer/pushdown_projections.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
# Sentinel value meaning that an outer query selects ALL columns
|
||||
SELECT_ALL = object()
|
||||
|
||||
|
||||
def pushdown_projections(expression):
    """
    Rewrite sqlglot AST to remove unused columns projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if query.args.get("distinct"):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(query, exp.Union):
            left, right = scope.union
            referenced_columns[left] = parent_selections
            referenced_columns[right] = parent_selections

        if not isinstance(query, exp.Select):
            continue

        _remove_unused_selections(scope, parent_selections)

        # Group columns by source name
        columns_by_source = defaultdict(set)
        for column in scope.columns:
            columns_by_source[column.table].add(column.name)

        # Push the selected columns down to the next scope
        for name, (_, source) in scope.selected_sources.items():
            if isinstance(source, Scope):
                referenced_columns[source].update(columns_by_source.get(name) or set())

    return expression
|
||||
|
||||
|
||||
def _remove_unused_selections(scope, parent_selections):
    """
    Drop the projections of `scope` that nothing refers to.

    A projection is kept when the parent selects everything, when the parent
    selects it by name, or when an unqualified ORDER BY column names it.

    Args:
        scope (Scope): scope whose select expression is rewritten in place
        parent_selections (set): column names selected by outer queries, possibly
            containing the `SELECT_ALL` sentinel
    """
    order = scope.expression.args.get("order")

    # Assume columns without a qualified table are references to output columns
    order_refs = (
        {c.name for c in order.find_all(exp.Column) if not c.table} if order else set()
    )

    keep_all = SELECT_ALL in parent_selections
    kept = [
        selection
        for selection in scope.selects
        if keep_all
        or selection.alias_or_name in parent_selections
        or selection.alias_or_name in order_refs
    ]

    # If there are no remaining selections, just select a single constant
    if not kept:
        kept.append(alias("1", "_"))

    scope.expression.set("expressions", kept)
|
422
sqlglot/optimizer/qualify_columns.py
Normal file
422
sqlglot/optimizer/qualify_columns.py
Normal file
|
@ -0,0 +1,422 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.schema import ensure_schema
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
# Expression types that are exempt from parts of column qualification:
# their column aliases are not popped, and scopes built on them skip
# star expansion.
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
|
||||
|
||||
|
||||
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    # Each scope is processed with its own resolver; the steps run in this order.
    for scope in traverse_scope(expression):
        resolver = _Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)
        _qualify_columns(scope, resolver)
        # Scopes whose expression is one of SKIP_QUALIFY keep their projections as they are.
        if not isinstance(scope.expression, SKIP_QUALIFY):
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)
        _check_unknown_tables(scope)

    return expression
|
||||
|
||||
|
||||
def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Entries that are instances of `SKIP_QUALIFY` are left untouched.
    """
    qualifiable = (
        candidate for candidate in derived_tables if not isinstance(candidate, SKIP_QUALIFY)
    )

    for derived in qualifiable:
        table_alias = derived.args.get("alias")
        if not table_alias:
            continue
        table_alias.args.pop("columns", None)
|
||||
|
||||
|
||||
def _expand_using(scope, resolver):
    """
    Replace the USING clause of each join with an equivalent ON condition.

    Unqualified references to a USING column are then replaced by a COALESCE
    over the tables that were joined on that column.

    Args:
        scope (Scope): scope whose joins are rewritten in place
        resolver (_Resolver): used to look up the columns of each source
    Raises:
        OptimizeError: if a USING column is missing from the joined table or
            from every source that precedes it
    """
    joins = list(scope.expression.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    # Sources that are not introduced by a join; joined tables are appended as they are processed.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # Column name -> first source in `ordered` that provides it
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Remember every table that takes part in the join on this column
            tables = column_tables.setdefault(identifier, [])
            if table not in tables:
                tables.append(table)
            if join_table not in tables:
                tables.append(join_table)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)
|
||||
|
||||
|
||||
def _expand_group_by(scope, resolver):
    """
    Qualify the GROUP BY columns and expand select aliases and positional references.

    Args:
        scope (Scope): scope whose GROUP BY clause is rewritten in place
        resolver (_Resolver): used to find the table of an unqualified column
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", exp.to_identifier(table))
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            select = selects.get(node.name)
            if select:
                # The tree is about to change, so cached scope data is dropped.
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    group.transform(transform, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)
|
||||
|
||||
|
||||
def _expand_order_by(scope):
    """
    Expand positional references (e.g. ``ORDER BY 1``) in the ORDER BY clause.

    Args:
        scope (Scope): scope whose ORDER BY clause is rewritten in place
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    resolved = _expand_positional_references(scope, (o.this for o in ordereds))

    for ordered, new_expression in zip(ordereds, resolved):
        ordered.set("this", new_expression)
|
||||
|
||||
|
||||
def _expand_positional_references(scope, expressions):
|
||||
new_nodes = []
|
||||
for node in expressions:
|
||||
if node.is_int:
|
||||
try:
|
||||
select = scope.selects[int(node.name) - 1]
|
||||
except IndexError:
|
||||
raise OptimizeError(f"Unknown output column: {node.name}")
|
||||
if isinstance(select, exp.Alias):
|
||||
select = select.this
|
||||
new_nodes.append(select.copy())
|
||||
scope.clear_cache()
|
||||
else:
|
||||
new_nodes.append(node)
|
||||
|
||||
return new_nodes
|
||||
|
||||
|
||||
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source

    Args:
        scope (Scope): scope whose columns are qualified in place
        resolver (_Resolver): used to look up source columns and column tables
    Raises:
        OptimizeError: if a column is unknown or ambiguous
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        # A qualified column must exist in the source it names
        if (
            column_table
            and column_table in scope.sources
            and column_name not in resolver.get_source_columns(column_table)
        ):
            raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # These checks are skipped for subquery and unnest scopes
            if not scope.is_subquery and not scope.is_unnest:
                if column_name not in resolver.all_columns:
                    raise OptimizeError(f"Unknown column: {column_name}")

                if column_table is None:
                    raise OptimizeError(f"Ambiguous column: {column_name}")

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", exp.to_identifier(column_table))
|
||||
|
||||
|
||||
def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections

    Args:
        scope (Scope): scope whose projections are rewritten in place
        resolver (_Resolver): used to look up the columns of each source
    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope
    """

    new_selections = []
    # Both mappings are keyed by id() of the table-name object, not by the name itself
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            # Unqualified star: expand every selected source
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif isinstance(expression, exp.Column) and isinstance(
            expression.this, exp.Star
        ):
            # Qualified star: expand only the named table
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table)
            table_id = id(table)
            for name in columns:
                if name not in except_columns.get(table_id, set()):
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table)
                    new_selections.append(
                        alias(column, alias_) if alias_ != name else column
                    )

    scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
def _add_except_columns(expression, tables, except_columns):
|
||||
except_ = expression.args.get("except")
|
||||
|
||||
if not except_:
|
||||
return
|
||||
|
||||
columns = {e.name for e in except_}
|
||||
|
||||
for table in tables:
|
||||
except_columns[id(table)] = columns
|
||||
|
||||
|
||||
def _add_replace_columns(expression, tables, replace_columns):
|
||||
replace = expression.args.get("replace")
|
||||
|
||||
if not replace:
|
||||
return
|
||||
|
||||
columns = {e.this.name: e.alias for e in replace}
|
||||
|
||||
for table in tables:
|
||||
replace_columns[id(table)] = columns
|
||||
|
||||
|
||||
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased

    Plain columns are aliased with their own name, other unaliased projections
    get ``_col_<position>``. A name from `scope.outer_column_list` at the same
    position overrides the alias.
    """
    qualified = []
    pairs = itertools.zip_longest(scope.selects, scope.outer_column_list)

    for position, (selection, outer_name) in enumerate(pairs):
        if isinstance(selection, exp.Column):
            output_name = selection.name
            needs_alias = True
        elif isinstance(selection, exp.Alias):
            output_name = None
            needs_alias = False
        else:
            output_name = f"_col_{position}"
            needs_alias = True

        if needs_alias:
            # convoluted setter because a simple selection.replace(alias) would require a copy
            wrapper = alias(exp.column(""), alias=output_name)
            wrapper.set("this", selection)
            selection = wrapper

        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        qualified.append(selection)

    scope.expression.set("expressions", qualified)
|
||||
|
||||
|
||||
def _check_unknown_tables(scope):
|
||||
if (
|
||||
scope.external_columns
|
||||
and not scope.is_unnest
|
||||
and not scope.is_correlated_subquery
|
||||
):
|
||||
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
|
||||
|
||||
|
||||
class _Resolver:
|
||||
"""
|
||||
Helper for resolving columns.
|
||||
|
||||
This is a class so we can lazily load some things and easily share them across functions.
|
||||
"""
|
||||
|
||||
def __init__(self, scope, schema):
|
||||
self.scope = scope
|
||||
self.schema = schema
|
||||
self._source_columns = None
|
||||
self._unambiguous_columns = None
|
||||
self._all_columns = None
|
||||
|
||||
def get_table(self, column_name):
|
||||
"""
|
||||
Get the table for a column name.
|
||||
|
||||
Args:
|
||||
column_name (str)
|
||||
Returns:
|
||||
(str) table name
|
||||
"""
|
||||
if self._unambiguous_columns is None:
|
||||
self._unambiguous_columns = self._get_unambiguous_columns(
|
||||
self._get_all_source_columns()
|
||||
)
|
||||
return self._unambiguous_columns.get(column_name)
|
||||
|
||||
@property
|
||||
def all_columns(self):
|
||||
"""All available columns of all sources in this scope"""
|
||||
if self._all_columns is None:
|
||||
self._all_columns = set(
|
||||
column
|
||||
for columns in self._get_all_source_columns().values()
|
||||
for column in columns
|
||||
)
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name):
|
||||
"""Resolve the source columns for a given source `name`"""
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
|
||||
source = self.scope.sources[name]
|
||||
|
||||
# If referencing a table, return the columns from the schema
|
||||
if isinstance(source, exp.Table):
|
||||
try:
|
||||
return self.schema.column_names(source)
|
||||
except Exception as e:
|
||||
raise OptimizeError(str(e)) from e
|
||||
|
||||
# Otherwise, if referencing another scope, return that scope's named selects
|
||||
return source.expression.named_selects
|
||||
|
||||
def _get_all_source_columns(self):
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {
|
||||
k: self.get_source_columns(k) for k in self.scope.selected_sources
|
||||
}
|
||||
return self._source_columns
|
||||
|
||||
def _get_unambiguous_columns(self, source_columns):
|
||||
"""
|
||||
Find all the unambiguous columns in sources.
|
||||
|
||||
Args:
|
||||
source_columns (dict): Mapping of names to source columns
|
||||
Returns:
|
||||
dict: Mapping of column name to source name
|
||||
"""
|
||||
if not source_columns:
|
||||
return {}
|
||||
|
||||
source_columns = list(source_columns.items())
|
||||
|
||||
first_table, first_columns = source_columns[0]
|
||||
unambiguous_columns = {
|
||||
col: first_table for col in self._find_unique_columns(first_columns)
|
||||
}
|
||||
all_columns = set(unambiguous_columns)
|
||||
|
||||
for table, columns in source_columns[1:]:
|
||||
unique = self._find_unique_columns(columns)
|
||||
ambiguous = set(all_columns).intersection(unique)
|
||||
all_columns.update(columns)
|
||||
for column in ambiguous:
|
||||
unambiguous_columns.pop(column, None)
|
||||
for column in unique.difference(ambiguous):
|
||||
unambiguous_columns[column] = table
|
||||
|
||||
return unambiguous_columns
|
||||
|
||||
@staticmethod
|
||||
def _find_unique_columns(columns):
|
||||
"""
|
||||
Find the unique columns in a list of columns.
|
||||
|
||||
Example:
|
||||
>>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
|
||||
['a', 'c']
|
||||
|
||||
This is necessary because duplicate column names are ambiguous.
|
||||
"""
|
||||
counts = {}
|
||||
for column in columns:
|
||||
counts[column] = counts.get(column, 0) + 1
|
||||
return {column for column, count in counts.items() if count == 1}
|
54
sqlglot/optimizer/qualify_tables.py
Normal file
54
sqlglot/optimizer/qualify_tables.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
|
||||
def qualify_tables(expression, db=None, catalog=None):
    """
    Rewrite sqlglot AST to have fully qualified tables.

    Derived tables and CTEs without an alias get a generated `_q_<n>` alias.
    Tables get the default `db` / `catalog` where missing and are wrapped in
    an alias when they do not already have one.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        db (str): Database name
        catalog (str): Catalog name
    Returns:
        sqlglot.Expression: qualified expression
    """
    counter = itertools.count()

    for scope in traverse_scope(expression):
        for derived in scope.ctes + scope.derived_tables:
            if derived.args.get("alias"):
                continue
            generated = f"_q_{next(counter)}"
            derived.set("alias", exp.TableAlias(this=exp.to_identifier(generated)))
            scope.rename_source(None, generated)

        for source in scope.sources.values():
            if not isinstance(source, exp.Table):
                continue

            # Only plain identifiers can be qualified with db / catalog.
            named = isinstance(source.this, exp.Identifier)

            if named:
                if not source.args.get("db"):
                    source.set("db", exp.to_identifier(db))
                if not source.args.get("catalog"):
                    source.set("catalog", exp.to_identifier(catalog))

            if isinstance(source.parent, exp.Alias):
                continue

            aliased = alias(
                source.copy(),
                source.this if named else f"_q_{next(counter)}",
                table=True,
            )
            source.replace(aliased)

    return expression
|
25
sqlglot/optimizer/quote_identities.py
Normal file
25
sqlglot/optimizer/quote_identities.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from sqlglot import exp
|
||||
|
||||
|
||||
def quote_identities(expression):
    """
    Rewrite sqlglot AST to ensure all identities are quoted.

    The tree is transformed without copying (`copy=False`).

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
        >>> quote_identities(expression).sql()
        'SELECT "x"."a" AS "a" FROM "db"."x"'

    Args:
        expression (sqlglot.Expression): expression to quote
    Returns:
        sqlglot.Expression: quoted expression
    """

    def _quote(node):
        # Flag identifiers as quoted; every other node passes through as is.
        if isinstance(node, exp.Identifier):
            node.set("quoted", True)
        return node

    return expression.transform(_quote, copy=False)
|
129
sqlglot/optimizer/schema.py
Normal file
129
sqlglot/optimizer/schema.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import abc
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import csv_reader
|
||||
|
||||
|
||||
class Schema(abc.ABC):
    """Abstract interface for looking up the columns of a table."""

    @abc.abstractmethod
    def column_names(self, table):
        """
        Return the names of the columns of `table`.

        Args:
            table (sqlglot.expressions.Table): Table expression instance
        Returns:
            list[str]: list of column names
        """
|
||||
|
||||
|
||||
class MappingSchema(Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            `None` and an empty mapping are accepted and describe a schema
            that knows no tables.

    Raises:
        OptimizeError: if the mapping has any other nesting depth
    """

    # Fixed order for validating table parts, so that the part named in an
    # error message does not depend on set iteration order.
    _TABLE_PARTS = ("catalog", "db", "this")

    def __init__(self, schema):
        self.schema = schema

        depth = _dict_depth(schema)

        if not depth or not schema:
            # None (depth 0) or {} (depth 1 but empty): nothing to look up.
            self.supported_table_args = ()
        elif depth == 2:  # {table: {col: type}}
            self.supported_table_args = ("this",)
        elif depth == 3:  # {db: {table: {col: type}}}
            self.supported_table_args = ("db", "this")
        elif depth == 4:  # {catalog: {db: {table: {col: type}}}}
            self.supported_table_args = ("catalog", "db", "this")
        else:
            raise OptimizeError(f"Invalid schema shape. Depth: {depth}")

        self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)

    def column_names(self, table):
        """
        Get the column names for a table.

        Tables whose `this` is not an identifier are delegated to `fs_get`.

        Args:
            table (sqlglot.expressions.Table): Table expression instance
        Returns:
            list[str]: list of column names
        Raises:
            ValueError: if the table uses a part (catalog/db/this) the
                mapping has no level for, or if a part is not found
        """
        if not isinstance(table.this, exp.Identifier):
            return fs_get(table)

        args = tuple(table.text(p) for p in self.supported_table_args)

        for forbidden in self._TABLE_PARTS:
            if forbidden in self.forbidden_args and table.text(forbidden):
                raise ValueError(
                    f"Schema doesn't support {forbidden}. Received: {table.sql()}"
                )
        return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
|
||||
|
||||
|
||||
def ensure_schema(schema):
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: a `Schema` instance, or anything `MappingSchema` accepts
    Returns:
        Schema: `schema` itself if it already is one, else a `MappingSchema`
            wrapping it
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema)
|
||||
|
||||
|
||||
def fs_get(table):
    """
    Read column names for a table that is backed by a file-reading function.

    Only `READ_CSV` (matched case-insensitively) is handled; the first item
    produced by `csv_reader(table)` is returned.

    Raises:
        ValueError: for any other function name
    """
    function_name = table.this.name.upper()

    if function_name != "READ_CSV":
        raise ValueError(f"Cannot read schema for {table}")

    with csv_reader(table) as reader:
        return next(reader)
|
||||
|
||||
|
||||
def _nested_get(d, *path):
|
||||
"""
|
||||
Get a value for a nested dictionary.
|
||||
|
||||
Args:
|
||||
d (dict): dictionary
|
||||
*path (tuple[str, str]): tuples of (name, key)
|
||||
`key` is the key in the dictionary to get.
|
||||
`name` is a string to use in the error if `key` isn't found.
|
||||
"""
|
||||
for name, key in path:
|
||||
d = d.get(key)
|
||||
if d is None:
|
||||
name = "table" if name == "this" else name
|
||||
raise ValueError(f"Unknown {name}")
|
||||
return d
|
||||
|
||||
|
||||
def _dict_depth(d):
|
||||
"""
|
||||
Get the nesting depth of a dictionary.
|
||||
|
||||
For example:
|
||||
>>> _dict_depth(None)
|
||||
0
|
||||
>>> _dict_depth({})
|
||||
1
|
||||
>>> _dict_depth({"a": "b"})
|
||||
1
|
||||
>>> _dict_depth({"a": {}})
|
||||
2
|
||||
>>> _dict_depth({"a": {"b": {}}})
|
||||
3
|
||||
|
||||
Args:
|
||||
d (dict): dictionary
|
||||
Returns:
|
||||
int: depth
|
||||
"""
|
||||
try:
|
||||
return 1 + _dict_depth(next(iter(d.values())))
|
||||
except AttributeError:
|
||||
# d doesn't have attribute "values"
|
||||
return 0
|
||||
except StopIteration:
|
||||
# d.values() returns an empty sequence
|
||||
return 1
|
438
sqlglot/optimizer/scope.py
Normal file
438
sqlglot/optimizer/scope.py
Normal file
|
@ -0,0 +1,438 @@
|
|||
from copy import copy
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
|
||||
|
||||
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Default for a scope created without an explicit type.
    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    # Left or right operand of a UNION.
    UNION = auto()
    UNNEST = auto()
|
||||
|
||||
|
||||
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries.
            This does not include derived tables or CTEs.
        union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
            a tuple of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.union = None
        self.clear_cache()

    def clear_cache(self):
        """Forget everything that was lazily derived from `expression`."""
        self._collected = False
        for cached in (
            "_raw_columns",
            "_derived_tables",
            "_tables",
            "_ctes",
            "_subqueries",
            "_selected_sources",
            "_columns",
            "_external_columns",
        ):
            setattr(self, cached, None)

    def branch(self, expression, scope_type, add_sources=None, **kwargs):
        """
        Branch from the current scope to a new, inner scope.

        The child starts with a shallow copy of this scope's sources, plus
        `add_sources` if given. `expression` is unnested first.
        """
        inherited = copy(self.sources)
        if add_sources:
            inherited.update(add_sources)

        return Scope(
            expression=expression.unnest(),
            sources=inherited,
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk the expression once and bucket the nodes that belong to this scope."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._raw_columns = []

        # State shared with the dfs generator through the closure below:
        # setting it to True excludes a subtree from traversal.
        skip_subtree = False

        for node, parent, _ in self.expression.dfs(prune=lambda *_: skip_subtree):
            skip_subtree = False

            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table):
                self._tables.append(node)
            elif isinstance(node, (exp.Unnest, exp.Lateral)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
                skip_subtree = True
            elif isinstance(node, exp.Subquery) and isinstance(
                parent, (exp.From, exp.Join)
            ):
                self._derived_tables.append(node)
                skip_subtree = True
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)
                skip_subtree = True

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless it already ran since the last cache clear."""
        if not self._collected:
            self._collect()

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Columns under a QUALIFY or ORDER node whose name matches one of this
        scope's output names are left out.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()

            own = self._raw_columns
            from_subqueries = [
                column
                for child in self.subquery_scopes
                for column in child.external_columns
            ]
            output_names = {e.alias_or_name for e in self.expression.expressions}

            def refers_to_output(column):
                return (
                    column.find_ancestor(exp.Qualify, exp.Order)
                    and column.name in output_names
                )

            self._columns = [
                column
                for column in own + from_subqueries
                if not refers_to_output(column)
            ]
        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
        """
        if self._selected_sources is None:
            candidates = [
                (
                    table.parent.alias
                    if isinstance(table.parent, exp.Alias)
                    else table.name,
                    table,
                )
                for table in self.tables
            ]
            candidates.extend(
                (derived.alias, derived.unnest()) for derived in self.derived_tables
            )

            # Later candidates with the same name overwrite earlier ones.
            self._selected_sources = {
                name: (node, self.sources[name])
                for name, node in candidates
                if name in self.sources
            }
        return self._selected_sources

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.
        A Union scope has no selects of its own.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return []
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                column
                for column in self.columns
                if column.table not in self.selected_sources
            ]
        return self._external_columns

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_unnest(self):
        """Determine if this scope is an unnest"""
        return self.scope_type == ScopeType.UNNEST

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A falsy `old_name` refers to the "" key. If the old key is missing,
        `new_name` is mapped to an empty list.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
|
||||
|
||||
|
||||
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        List[Scope]: scope instances
    """
    root = Scope(expression)
    return list(_traverse_scope(root))
|
||||
|
||||
|
||||
def _traverse_scope(scope):
    """
    Yield every scope nested in `scope`, then `scope` itself.

    Raises:
        OptimizeError: if the scope's expression is not a Select, Union,
            Lateral, Unnest or Subquery
    """
    node = scope.expression

    if isinstance(node, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(node, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(node, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif not isinstance(node, (exp.Lateral, exp.Unnest)):
        # Lateral / Unnest scopes are yielded without traversing into them.
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")

    yield scope
|
||||
|
||||
|
||||
def _traverse_select(scope):
    """
    Yield the child scopes of a Select scope and register its table sources.

    Children are yielded in this order: CTEs, subqueries, derived tables.
    Table sources are added once all children have been traversed.
    """
    yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
    yield from _traverse_subqueries(scope)
    yield from _traverse_derived_tables(
        scope.derived_tables, scope, ScopeType.DERIVED_TABLE
    )

    _add_table_sources(scope)
|
||||
|
||||
|
||||
def _traverse_union(scope):
    """
    Yield the child scopes of a Union scope.

    CTEs come first, then everything under the left operand, then everything
    under the right operand. `scope.union` is set to the pair of top-most
    scopes of the two operands.
    """
    yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)

    tops = []
    for side in ("left", "right"):
        # The last scope yielded for an operand is its top-most scope.
        top = None
        operand_scope = scope.branch(
            getattr(scope.expression, side), scope_type=ScopeType.UNION
        )
        for top in _traverse_scope(operand_scope):
            yield top
        tops.append(top)

    scope.union = tuple(tops)
|
||||
|
||||
|
||||
def _traverse_derived_tables(derived_tables, scope, scope_type):
    """
    Yield the scopes of `derived_tables` and register them as sources of `scope`.

    Args:
        derived_tables (list): CTE, Subquery, Unnest or Lateral nodes
        scope (Scope): scope the derived tables belong to
        scope_type (ScopeType): type for the child scopes; Unnest nodes
            always get `ScopeType.UNNEST`
    """
    collected = {}

    for derived in derived_tables:
        if isinstance(derived, (exp.Unnest, exp.Lateral)):
            target = derived
        else:
            target = derived.this

        child_type = ScopeType.UNNEST if isinstance(derived, exp.Unnest) else scope_type

        branched = scope.branch(
            target,
            # Only CTEs see the sources collected from their earlier siblings.
            add_sources=collected if scope_type == ScopeType.CTE else None,
            outer_column_list=derived.alias_column_names,
            scope_type=child_type,
        )

        for child in _traverse_scope(branched):
            yield child
            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            collected[derived.alias] = child

    scope.sources.update(collected)
|
||||
|
||||
|
||||
def _add_table_sources(scope):
    """
    Register the tables of `scope` in `scope.sources`.

    Raises:
        OptimizeError: if a table's source name is already a source of the
            scope and the table's own name is not
    """
    new_sources = {}

    for table in scope.tables:
        name = table.name
        parent = table.parent
        key = parent.alias if isinstance(parent, exp.Alias) else name

        if name in scope.sources:
            # This is a reference to a parent source (e.g. a CTE), not an actual table.
            scope.sources[key] = scope.sources[name]
            continue

        if key in scope.sources:
            raise OptimizeError(f"Duplicate table name: {key}")

        new_sources[key] = table

    scope.sources.update(new_sources)
|
||||
|
||||
|
||||
def _traverse_subqueries(scope):
    """
    Yield every scope under each subquery of `scope`.

    The last scope yielded for a subquery is appended to
    `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        branched = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
        last = None
        for last in _traverse_scope(branched):
            yield last
        scope.subquery_scopes.append(last)
|
383
sqlglot/optimizer/simplify.py
Normal file
383
sqlglot/optimizer/simplify.py
Normal file
|
@ -0,0 +1,383 @@
|
|||
import datetime
|
||||
import functools
|
||||
import itertools
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.expressions import FALSE, NULL, TRUE
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.helper import while_changing
|
||||
|
||||
GENERATOR = Generator(normalize=True, identify=True)
|
||||
|
||||
|
||||
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite rules are applied bottom-up and repeated via `while_changing`.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    # Rules that run after the children of a node have been simplified.
    boolean_rules = (simplify_not, flatten, simplify_connectors, remove_compliments)

    def _simplify(expression, root=True):
        node = absorb_and_eliminate(uniq_sort(expression))
        exp.replace_children(node, lambda e: _simplify(e, False))

        for rule in boolean_rules:
            node = rule(node)

        # The rules above may return a fresh node; give it the original's parent.
        node.parent = expression.parent
        node = simplify_parens(simplify_literals(node))

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
|
||||
|
||||
|
||||
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT of an always-true operand to FALSE, NOT FALSE to TRUE,
    and removes double negation.
    """
    if not isinstance(expression, exp.Not):
        return expression

    if isinstance(expression.this, exp.Paren):
        condition = expression.this.unnest()
        if isinstance(condition, exp.And):
            return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
        if isinstance(condition, exp.Or):
            return exp.and_(exp.not_(condition.left), exp.not_(condition.right))

    if always_true(expression.this):
        return FALSE
    if expression.this == FALSE:
        return TRUE
    if isinstance(expression.this, exp.Not):
        # double negation
        # NOT NOT x -> x
        return expression.this.this

    return expression
|
||||
|
||||
|
||||
def flatten(expression):
    """
    Drop parentheses around an operand of the same connector type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
|
||||
|
||||
|
||||
def simplify_connectors(expression):
    """
    Fold an AND / OR whose operands are equal or constant.

    Returns the replacement node, or `expression` unchanged when no rule
    applies.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    left, right = expression.left, expression.right

    if left == right:
        return left

    if isinstance(expression, exp.And):
        if NULL in (left, right):
            return NULL
        if FALSE in (left, right):
            return FALSE

        left_true = always_true(left)
        if left_true and always_true(right):
            return TRUE
        if left_true:
            return right
        if always_true(right):
            return left
    elif isinstance(expression, exp.Or):
        if always_true(left) or always_true(right):
            return TRUE
        if left == FALSE and right == FALSE:
            return FALSE

        both_null = left == NULL and right == NULL
        null_false = left == NULL and right == FALSE
        false_null = left == FALSE and right == NULL
        if both_null or null_false or false_null:
            return NULL

        if left == FALSE:
            return right
        if right == FALSE:
            return left

    return expression
|
||||
|
||||
|
||||
def remove_compliments(expression):
    """
    Collapse a connector that contains an operand and its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression

    collapsed = FALSE if isinstance(expression, exp.And) else TRUE
    pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(a, b) for a, b in pairs):
        return collapsed
    return expression
|
||||
|
||||
|
||||
def uniq_sort(expression):
    """
    Uniq and sort a connector.

    Operands are compared and ordered by their generated SQL text.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by SQL text, so equal operands collapse (the last one wins).
    by_sql = {GENERATOR.generate(operand): operand for operand in operands}
    keys = list(by_sql)
    ordered = sorted(keys)

    if keys != ordered:
        # A AND C AND B -> A AND B AND C
        return combine(*(by_sql[key] for key in ordered))

    if len(by_sql) < len(operands):
        # already sorted, but duplicates were dropped
        return combine(*by_sql.values())

    return expression
|
||||
|
||||
|
||||
def absorb_and_eliminate(expression):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced inside the tree; `expression` itself is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    # `kind` is the connector type of the nested operands: the opposite of
    # the outer connector.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_constant = exp.TRUE if kind == exp.And else exp.FALSE
    outer_constant = exp.FALSE if kind == exp.And else exp.TRUE

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(inner_constant)
        elif is_complement(b, ab):
            ab.replace(inner_constant)
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(outer_constant)
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression
|
||||
|
||||
|
||||
def simplify_literals(expression):
    """
    Fold literal operands of a binary expression, or a negated number.

    For `Add`, `Mul` and connectors, every pair of flattened operands is a
    folding candidate, regardless of position. All other binary operators
    are only folded on their direct left/right operands, because regrouping
    or reordering operands is not valid for them (`x - 3 - 1` must not
    become `x - 2`).

    Args:
        expression (exp.Expression): node to simplify
    Returns:
        exp.Expression: the folded node, or `expression` if nothing folded
    """
    if isinstance(expression, exp.Binary):
        if not isinstance(expression, (exp.Add, exp.Mul, exp.Connector)):
            left = expression.args.get("this")
            right = expression.args.get("expression")
            if left is None or right is None:
                return expression
            folded = _simplify_binary(expression, left, right)
            return folded if folded else expression

        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = _simplify_binary(expression, a, b)

                if result:
                    # Merge a and b into one operand and try it against the rest.
                    queue.remove(b)
                    queue.append(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            # Toggle the sign on the literal's text.
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression
|
||||
|
||||
|
||||
def _simplify_binary(expression, a, b):
    """
    Try to fold two operands of a binary expression into a single node.

    Handles NULL propagation, `IS [NOT] NULL` on literals, equality of
    identical operands, arithmetic and comparisons on number literals,
    comparisons on string literals and date +/- interval arithmetic.

    Args:
        expression (sqlglot.Expression): the binary expression being simplified;
            only its type is used to decide which operation to apply
        a (sqlglot.Expression): left operand
        b (sqlglot.Expression): right operand
    Returns:
        sqlglot.Expression: the folded node, or None when the pair cannot be
        folded (the caller then keeps the operands as they are)
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if c == NULL:
            if isinstance(a, exp.Literal):
                return TRUE if not_ else FALSE
            if a == NULL:
                return FALSE if not_ else TRUE
    elif NULL in (a, b):
        return NULL

    if isinstance(expression, exp.EQ) and a == b:
        return TRUE

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # Leave a division by zero in the query instead of raising
            # ZeroDivisionError while simplifying.
            if b == 0:
                return None
            # NOTE(review): Python's // floors toward negative infinity; confirm
            # this matches the intended integer-division semantics for
            # negative operands.
            if isinstance(a, int) and isinstance(b, int):
                return exp.Literal.number(a // b)
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        # Compare the literal text, mirroring how numbers are compared by value
        # above; ordering comparisons are not meaningful on the nodes themselves.
        boolean = eval_boolean(expression, a.name, b.name)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        # both extractions can yield None (non-DATE cast, unsupported interval)
        if a is not None and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b is not None and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
|
||||
|
||||
|
||||
def simplify_parens(expression):
    """
    Drop a pair of parentheses when the code below deems them removable.

    Parentheses are kept around a SELECT. Otherwise they are removed when the
    parent is neither a condition nor a binary expression, or when the wrapped
    expression is an IS / LIKE, or when it is not a binary expression at all.

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: the wrapped expression if the parentheses were
        removed, otherwise `expression` unchanged
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(expression.parent, (exp.Condition, exp.Binary))
        or isinstance(inner, (exp.Is, exp.Like))
        or not isinstance(inner, exp.Binary)
    )
    return inner if removable else expression
|
||||
|
||||
|
||||
def remove_where_true(expression):
    """
    Strip conditions that `always_true` accepts from WHERE clauses and joins.

    A WHERE whose condition is always true is removed from its parent. A join
    whose ON condition is always true loses the condition and becomes a CROSS
    join. The tree is modified in place and nothing is returned.

    Args:
        expression (sqlglot.Expression): expression to clean up
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        on_condition = join_node.args.get("on")
        if always_true(on_condition):
            join_node.set("kind", "CROSS")
            join_node.set("on", None)
|
||||
|
||||
|
||||
def always_true(expression):
    """Return whether `expression` equals TRUE or is any literal."""
    matches_true = expression == TRUE
    return matches_true or isinstance(expression, exp.Literal)
|
||||
|
||||
|
||||
def is_complement(a, b):
    """Return whether `b` is the negation `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
|
||||
|
||||
|
||||
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between `a` and `b` with Python's own operators.

    Args:
        expression (sqlglot.Expression): node whose type selects the comparison
            (EQ / IS, NEQ, GT, GTE, LT, LTE)
        a: left value
        b: right value
    Returns:
        The result of `boolean_literal` for the comparison, or None when
        `expression` is not one of the supported comparison types.
    """
    # Checked in order; the comparison itself is only evaluated on a match.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
|
||||
|
||||
|
||||
def extract_date(cast):
    """
    Extract a `datetime.date` from a cast to DATE.

    Args:
        cast (sqlglot.Expression): a cast expression; its `to` argument is
            checked for the DATE type and its `name` is parsed as an ISO date
    Returns:
        datetime.date: the parsed date, or None when the cast target is not
        DATE or the text being cast is not a valid ISO-format date (for
        example when the operand is not a date literal)
    """
    if cast.args["to"].this == exp.DataType.Type.DATE:
        try:
            return datetime.date.fromisoformat(cast.name)
        except ValueError:
            # Not an ISO date string: nothing to extract, so let the caller
            # leave the expression unsimplified instead of failing.
            return None
    return None
|
||||
|
||||
|
||||
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Args:
        interval (sqlglot.Expression): interval whose `name` holds the amount
            and whose "unit" text holds the unit
    Returns:
        dateutil.relativedelta.relativedelta: the equivalent delta, or None
        when dateutil is not installed, when the amount is not an integer, or
        when the unit is not one of year, month, week or day
    """
    try:
        from dateutil.relativedelta import relativedelta
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-integer amount (e.g. fractional or free-form text): cannot be
        # folded here, so report "no interval" rather than raising.
        return None

    unit = interval.text("unit").lower()

    # SQL unit name -> relativedelta keyword argument
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(unit)

    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
|
||||
|
||||
|
||||
def date_literal(date):
    """Wrap `date` in a string literal that is cast to the DATE data type."""
    text = exp.Literal.string(date)
    target_type = exp.DataType.build("DATE")
    return exp.Cast(this=text, to=target_type)
|
||||
|
||||
|
||||
def boolean_literal(condition):
    """Map a Python truth value onto the module's TRUE / FALSE nodes."""
    if condition:
        return TRUE
    return FALSE
|
220
sqlglot/optimizer/unnest_subqueries.py
Normal file
220
sqlglot/optimizer/unnest_subqueries.py
Normal file
|
@ -0,0 +1,220 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
|
||||
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Each scope of the expression is visited: scopes that reference external
    columns are handed to `decorrelate`, all others to `unnest`.

    Convert the subquery into a group by so it is not a many to many left join.
    Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
    Unnesting non correlated subqueries only happens on IN statements or = ANY statements.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
 AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # one shared counter so every generated alias is distinct
    alias_counter = itertools.count()

    for scope in traverse_scope(expression):
        subquery = scope.expression
        outer_select = subquery.parent_select
        correlated_columns = scope.external_columns

        if correlated_columns:
            decorrelate(subquery, outer_select, correlated_columns, alias_counter)
        else:
            unnest(subquery, outer_select, alias_counter)

    return expression
|
||||
|
||||
|
||||
def unnest(select, parent_select, sequence):
    """
    Turn an uncorrelated `IN` / `= ANY` subquery into a LEFT JOIN.

    Nothing is changed when the subquery is not inside such a predicate that
    belongs directly to `parent_select`, when it projects more than one
    column, or when it contains LIMIT or OFFSET.

    Args:
        select (sqlglot.Expression): the subquery to unnest
        parent_select (sqlglot.Expression): the select the join is added to
        sequence (Iterator[int]): counter used to generate the join alias
    """

    def outside_parent(node):
        # true when there is no node or it belongs to a different select
        return not node or parent_select is not node.parent_select

    predicate = select.find_ancestor(exp.In, exp.Any)
    if outside_parent(predicate):
        return

    if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # for ANY, the predicate to rewrite is the enclosing equality
        predicate = predicate.find_ancestor(exp.EQ)
        if outside_parent(predicate):
            return

    column = _other_operand(predicate)
    value = select.selects[0]
    alias = _alias(sequence)

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')

    # the original predicate becomes a check that the join found a match
    _replace(predicate, f"NOT {on.right} IS NULL")

    grouped = select.group_by(value.this, copy=False)
    parent_select.join(
        grouped,
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
|
||||
|
||||
|
||||
def decorrelate(select, parent_select, external_columns, sequence):
    """
    Rewrite a correlated subquery as a LEFT JOIN on `parent_select`.

    The predicates in the subquery's WHERE clause that reference external
    columns are replaced by TRUE and turned into join keys; the subquery is
    grouped by those keys and joined to `parent_select`, and the predicate in
    the parent that contained the subquery is rewritten to use the joined
    columns. Both trees are modified in place.

    The function returns early, before modifying the tree, when:
      * the subquery has no WHERE, its WHERE contains an OR, or the subquery
        contains LIMIT / OFFSET,
      * an external column (or its enclosing predicate) is not located in the
        subquery's own WHERE clause,
      * an external column's enclosing predicate is not a binary expression,
      * none of the collected predicates is an equality.
    Note that one alias has already been drawn from `sequence` by then.

    Args:
        select (sqlglot.Expression): the correlated subquery
        parent_select (sqlglot.Expression): the select the join is added to
        external_columns (list): columns in `select` that refer to an outer scope
        sequence (Iterator[int]): counter used to generate aliases
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = _alias(sequence)
    # (key, column, predicate) triples: `column` is the external column,
    # `predicate` the binary predicate containing it and `key` the operand on
    # the other side of that predicate
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # pick the side of the predicate that does not contain the column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is needed to have something to join on
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate in the outer query that contains the subquery (may be None)
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # non-grouped keys are collected into an array per group
            select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)

    # from here on `alias` is the joined column holding the subquery's value
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # rewrite the outer predicate (or the subquery itself) in terms of the
    # joined column; `parent_predicate` is rebound to the replacement node
    if isinstance(parent_predicate, exp.Exists):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
        else:
            parent_predicate = _replace(parent_predicate, "TRUE")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # any other context: substitute the joined column for the subquery node
        # NOTE(review): parent_predicate may be None on this path (no enclosing
        # predicate); the loop below would then call _replace on None — confirm
        # whether that case can occur.
        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlated predicate is neutralised inside the subquery ...
        predicate.replace(exp.TRUE)
        nested = exp.column(key_aliases[key], table_alias)

        # ... and re-expressed on the outer query via the joined key column
        if key in group_by:
            key.replace(nested)
            parent_predicate = _replace(
                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
            )
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # only the equality predicates become join conditions
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
|
||||
|
||||
|
||||
def _alias(sequence):
|
||||
return f"_u_{next(sequence)}"
|
||||
|
||||
|
||||
def _replace(expression, condition):
    """
    Replace `expression` in its tree with the condition built from `condition`.

    Args:
        expression (sqlglot.Expression): node to replace
        condition (str): text passed to `exp.condition`
    Returns:
        Whatever `expression.replace` returns for the new condition node.
    """
    swap = expression.replace
    return swap(exp.condition(condition))
|
||||
|
||||
|
||||
def _other_operand(expression):
    """
    Return the operand of a predicate that the subquery is compared against.

    For an IN predicate this is its `this` argument. For a binary expression
    it is the right operand when `expression.arg_key` is "this" and the left
    operand otherwise. Any other input yields None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if not isinstance(expression, exp.Binary):
        return None

    # NOTE(review): arg_key presumably names the slot this node occupies in its
    # own parent rather than which side holds the subquery — confirm this test
    # selects the intended operand.
    if expression.arg_key == "this":
        return expression.right
    return expression.left
|
Loading…
Add table
Add a link
Reference in a new issue