Merging upstream version 25.0.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
03b67e2ec9
commit
021892b3ff
84 changed files with 33016 additions and 31040 deletions
|
@ -36,7 +36,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
|
|||
original = node.copy()
|
||||
|
||||
node.transform(rewrite_between, copy=False)
|
||||
distance = normalization_distance(node, dnf=dnf)
|
||||
distance = normalization_distance(node, dnf=dnf, max_=max_distance)
|
||||
|
||||
if distance > max_distance:
|
||||
logger.info(
|
||||
|
@ -85,7 +85,9 @@ def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance.
    """
    # Seed the running total with -(connector count + 1) so that summing the
    # expanded predicate lengths directly produces the *difference* between
    # the normalized size and the current size.
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    distance = -(connector_count + 1)

    for predicate_length in _predicate_lengths(expression, dnf, max_):
        distance += predicate_length
        if distance > max_:
            # Budget exceeded — no point expanding further; callers only
            # need to know the distance is larger than max_.
            return distance

    return distance
|
||||
|
||||
|
||||
def _predicate_lengths(expression, dnf):
|
||||
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
|
||||
"""
|
||||
Returns a list of predicate lengths when expanded to normalized form.
|
||||
|
||||
(A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
|
||||
"""
|
||||
if depth > max_:
|
||||
yield depth
|
||||
return
|
||||
|
||||
expression = expression.unnest()
|
||||
|
||||
if not isinstance(expression, exp.Connector):
|
||||
return (1,)
|
||||
yield 1
|
||||
return
|
||||
|
||||
depth += 1
|
||||
left, right = expression.args.values()
|
||||
|
||||
if isinstance(expression, exp.And if dnf else exp.Or):
|
||||
return tuple(
|
||||
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
|
||||
)
|
||||
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
|
||||
for a in _predicate_lengths(left, dnf, max_, depth):
|
||||
for b in _predicate_lengths(right, dnf, max_, depth):
|
||||
yield a + b
|
||||
else:
|
||||
yield from _predicate_lengths(left, dnf, max_, depth)
|
||||
yield from _predicate_lengths(right, dnf, max_, depth)
|
||||
|
||||
|
||||
def distributive_law(expression, dnf, max_distance):
|
||||
|
@ -138,7 +154,7 @@ def distributive_law(expression, dnf, max_distance):
|
|||
if normalized(expression, dnf=dnf):
|
||||
return expression
|
||||
|
||||
distance = normalization_distance(expression, dnf=dnf)
|
||||
distance = normalization_distance(expression, dnf=dnf, max_=max_distance)
|
||||
|
||||
if distance > max_distance:
|
||||
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
|
||||
|
|
|
@ -80,7 +80,7 @@ def qualify_columns(
|
|||
)
|
||||
qualify_outputs(scope)
|
||||
|
||||
_expand_group_by(scope)
|
||||
_expand_group_by(scope, dialect)
|
||||
_expand_order_by(scope, resolver)
|
||||
|
||||
if dialect == "bigquery":
|
||||
|
@ -266,13 +266,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
scope.clear_cache()
|
||||
|
||||
|
||||
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Rewrite positional references in the GROUP BY clause in-place.

    Delegates to `_expand_positional_references` so that ordinals like
    `GROUP BY 1` are replaced by the projections they refer to. No-op when
    the scope's expression has no GROUP BY.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
|
||||
|
||||
|
||||
|
@ -284,7 +284,9 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
|
|||
ordereds = order.expressions
|
||||
for ordered, new_expression in zip(
|
||||
ordereds,
|
||||
_expand_positional_references(scope, (o.this for o in ordereds), alias=True),
|
||||
_expand_positional_references(
|
||||
scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
|
||||
),
|
||||
):
|
||||
for agg in ordered.find_all(exp.AggFunc):
|
||||
for col in agg.find_all(exp.Column):
|
||||
|
@ -307,9 +309,11 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
|
|||
|
||||
|
||||
def _expand_positional_references(
|
||||
scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
|
||||
scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
|
||||
) -> t.List[exp.Expression]:
|
||||
new_nodes: t.List[exp.Expression] = []
|
||||
ambiguous_projections = None
|
||||
|
||||
for node in expressions:
|
||||
if node.is_int:
|
||||
select = _select_by_pos(scope, t.cast(exp.Literal, node))
|
||||
|
@ -319,7 +323,28 @@ def _expand_positional_references(
|
|||
else:
|
||||
select = select.this
|
||||
|
||||
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
|
||||
if dialect == "bigquery":
|
||||
if ambiguous_projections is None:
|
||||
# When a projection name is also a source name and it is referenced in the
|
||||
# GROUP BY clause, BQ can't understand what the identifier corresponds to
|
||||
ambiguous_projections = {
|
||||
s.alias_or_name
|
||||
for s in scope.expression.selects
|
||||
if s.alias_or_name in scope.selected_sources
|
||||
}
|
||||
|
||||
ambiguous = any(
|
||||
column.parts[0].name in ambiguous_projections
|
||||
for column in select.find_all(exp.Column)
|
||||
)
|
||||
else:
|
||||
ambiguous = False
|
||||
|
||||
if (
|
||||
isinstance(select, exp.CONSTANTS)
|
||||
or select.find(exp.Explode, exp.Unnest)
|
||||
or ambiguous
|
||||
):
|
||||
new_nodes.append(node)
|
||||
else:
|
||||
new_nodes.append(select.copy())
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import functools
|
||||
import itertools
|
||||
import typing as t
|
||||
from collections import deque
|
||||
from collections import deque, defaultdict
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
|
||||
|
@ -20,6 +21,8 @@ if t.TYPE_CHECKING:
|
|||
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
|
||||
]
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
# Final means that an expression should not be simplified
|
||||
FINAL = "final"
|
||||
|
||||
|
@ -35,7 +38,10 @@ class UnsupportedUnit(Exception):
|
|||
|
||||
|
||||
def simplify(
|
||||
expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
|
||||
expression: exp.Expression,
|
||||
constant_propagation: bool = False,
|
||||
dialect: DialectType = None,
|
||||
max_depth: t.Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Rewrite sqlglot AST to simplify expressions.
|
||||
|
@ -47,9 +53,9 @@ def simplify(
|
|||
'TRUE'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to simplify
|
||||
expression: expression to simplify
|
||||
constant_propagation: whether the constant propagation rule should be used
|
||||
|
||||
max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
|
||||
Returns:
|
||||
sqlglot.Expression: simplified expression
|
||||
"""
|
||||
|
@ -57,6 +63,18 @@ def simplify(
|
|||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
def _simplify(expression, root=True):
|
||||
if (
|
||||
max_depth
|
||||
and isinstance(expression, exp.Connector)
|
||||
and not isinstance(expression.parent, exp.Connector)
|
||||
):
|
||||
depth = connector_depth(expression)
|
||||
if depth > max_depth:
|
||||
logger.info(
|
||||
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
|
||||
)
|
||||
return expression
|
||||
|
||||
if expression.meta.get(FINAL):
|
||||
return expression
|
||||
|
||||
|
@ -118,6 +136,33 @@ def simplify(
|
|||
return expression
|
||||
|
||||
|
||||
def connector_depth(expression: exp.Expression) -> int:
|
||||
"""
|
||||
Determine the maximum depth of a tree of Connectors.
|
||||
|
||||
For example:
|
||||
>>> from sqlglot import parse_one
|
||||
>>> connector_depth(parse_one("a AND b AND c AND d"))
|
||||
3
|
||||
"""
|
||||
stack = deque([(expression, 0)])
|
||||
max_depth = 0
|
||||
|
||||
while stack:
|
||||
expression, depth = stack.pop()
|
||||
|
||||
if not isinstance(expression, exp.Connector):
|
||||
continue
|
||||
|
||||
depth += 1
|
||||
max_depth = max(depth, max_depth)
|
||||
|
||||
stack.append((expression.left, depth))
|
||||
stack.append((expression.right, depth))
|
||||
|
||||
return max_depth
|
||||
|
||||
|
||||
def catch(*exceptions):
|
||||
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""
|
||||
|
||||
|
@ -280,6 +325,7 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|||
}
|
||||
|
||||
NONDETERMINISTIC = (exp.Rand, exp.Randn)
|
||||
AND_OR = (exp.And, exp.Or)
|
||||
|
||||
|
||||
def _simplify_comparison(expression, left, right, or_=False):
|
||||
|
@ -351,12 +397,12 @@ def remove_complements(expression, root=True):
|
|||
A AND NOT A -> FALSE
|
||||
A OR NOT A -> TRUE
|
||||
"""
|
||||
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
||||
complement = exp.false() if isinstance(expression, exp.And) else exp.true()
|
||||
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
|
||||
ops = set(expression.flatten())
|
||||
for op in ops:
|
||||
if isinstance(op, exp.Not) and op.this in ops:
|
||||
return exp.false() if isinstance(expression, exp.And) else exp.true()
|
||||
|
||||
for a, b in itertools.permutations(expression.flatten(), 2):
|
||||
if is_complement(a, b):
|
||||
return complement
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -404,31 +450,63 @@ def absorb_and_eliminate(expression, root=True):
|
|||
(A AND B) OR (A AND NOT B) -> A
|
||||
(A OR B) AND (A OR NOT B) -> A
|
||||
"""
|
||||
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
||||
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
|
||||
kind = exp.Or if isinstance(expression, exp.And) else exp.And
|
||||
|
||||
for a, b in itertools.permutations(expression.flatten(), 2):
|
||||
if isinstance(a, kind):
|
||||
aa, ab = a.unnest_operands()
|
||||
ops = tuple(expression.flatten())
|
||||
|
||||
# absorb
|
||||
if is_complement(b, aa):
|
||||
aa.replace(exp.true() if kind == exp.And else exp.false())
|
||||
elif is_complement(b, ab):
|
||||
ab.replace(exp.true() if kind == exp.And else exp.false())
|
||||
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
|
||||
a.replace(exp.false() if kind == exp.And else exp.true())
|
||||
elif isinstance(b, kind):
|
||||
# eliminate
|
||||
rhs = b.unnest_operands()
|
||||
ba, bb = rhs
|
||||
# Initialize lookup tables:
|
||||
# Set of all operands, used to find complements for absorption.
|
||||
op_set = set()
|
||||
# Sub-operands, used to find subsets for absorption.
|
||||
subops = defaultdict(list)
|
||||
# Pairs of complements, used for elimination.
|
||||
pairs = defaultdict(list)
|
||||
|
||||
if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
|
||||
a.replace(aa)
|
||||
b.replace(aa)
|
||||
elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
|
||||
a.replace(ab)
|
||||
b.replace(ab)
|
||||
# Populate the lookup tables
|
||||
for op in ops:
|
||||
op_set.add(op)
|
||||
|
||||
if not isinstance(op, kind):
|
||||
# In cases like: A OR (A AND B)
|
||||
# Subop will be: ^
|
||||
subops[op].append({op})
|
||||
continue
|
||||
|
||||
# In cases like: (A AND B) OR (A AND B AND C)
|
||||
# Subops will be: ^ ^
|
||||
subset = set(op.flatten())
|
||||
for i in subset:
|
||||
subops[i].append(subset)
|
||||
|
||||
a, b = op.unnest_operands()
|
||||
if isinstance(a, exp.Not):
|
||||
pairs[frozenset((a.this, b))].append((op, b))
|
||||
if isinstance(b, exp.Not):
|
||||
pairs[frozenset((a, b.this))].append((op, a))
|
||||
|
||||
for op in ops:
|
||||
if not isinstance(op, kind):
|
||||
continue
|
||||
|
||||
a, b = op.unnest_operands()
|
||||
|
||||
# Absorb
|
||||
if isinstance(a, exp.Not) and a.this in op_set:
|
||||
a.replace(exp.true() if kind == exp.And else exp.false())
|
||||
continue
|
||||
if isinstance(b, exp.Not) and b.this in op_set:
|
||||
b.replace(exp.true() if kind == exp.And else exp.false())
|
||||
continue
|
||||
superset = set(op.flatten())
|
||||
if any(any(subset < superset for subset in subops[i]) for i in superset):
|
||||
op.replace(exp.false() if kind == exp.And else exp.true())
|
||||
continue
|
||||
|
||||
# Eliminate
|
||||
for other, complement in pairs[frozenset((a, b))]:
|
||||
op.replace(complement)
|
||||
other.replace(complement)
|
||||
|
||||
return expression
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue