1
0
Fork 0

Merging upstream version 25.0.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:37:40 +01:00
parent 03b67e2ec9
commit 021892b3ff
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
84 changed files with 33016 additions and 31040 deletions

View file

@@ -36,7 +36,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
original = node.copy()
node.transform(rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf)
distance = normalization_distance(node, dnf=dnf, max_=max_distance)
if distance > max_distance:
logger.info(
@@ -85,7 +85,9 @@ def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
)
def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
def normalization_distance(
expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
"""
The difference in the number of predicates between a given expression and its normalized form.
@@ -101,33 +103,47 @@ def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int
expression: The expression to compute the normalization distance for.
dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
max_: stop early if count exceeds this.
Returns:
The normalization distance.
"""
return sum(_predicate_lengths(expression, dnf)) - (
sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
for length in _predicate_lengths(expression, dnf, max_):
total += length
if total > max_:
return total
return total
def _predicate_lengths(expression, dnf):
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
"""
Returns a list of predicate lengths when expanded to normalized form.
(A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
"""
if depth > max_:
yield depth
return
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
return (1,)
yield 1
return
depth += 1
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
)
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
for a in _predicate_lengths(left, dnf, max_, depth):
for b in _predicate_lengths(right, dnf, max_, depth):
yield a + b
else:
yield from _predicate_lengths(left, dnf, max_, depth)
yield from _predicate_lengths(right, dnf, max_, depth)
def distributive_law(expression, dnf, max_distance):
@@ -138,7 +154,7 @@ def distributive_law(expression, dnf, max_distance):
if normalized(expression, dnf=dnf):
return expression
distance = normalization_distance(expression, dnf=dnf)
distance = normalization_distance(expression, dnf=dnf, max_=max_distance)
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

View file

@@ -80,7 +80,7 @@ def qualify_columns(
)
qualify_outputs(scope)
_expand_group_by(scope)
_expand_group_by(scope, dialect)
_expand_order_by(scope, resolver)
if dialect == "bigquery":
@@ -266,13 +266,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
scope.clear_cache()
def _expand_group_by(scope: Scope) -> None:
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
expression = scope.expression
group = expression.args.get("group")
if not group:
return
group.set("expressions", _expand_positional_references(scope, group.expressions))
group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
expression.set("group", group)
@@ -284,7 +284,9 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
_expand_positional_references(scope, (o.this for o in ordereds), alias=True),
_expand_positional_references(
scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):
@@ -307,9 +309,11 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
def _expand_positional_references(
scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
new_nodes: t.List[exp.Expression] = []
ambiguous_projections = None
for node in expressions:
if node.is_int:
select = _select_by_pos(scope, t.cast(exp.Literal, node))
@@ -319,7 +323,28 @@ def _expand_positional_references(
else:
select = select.this
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
if dialect == "bigquery":
if ambiguous_projections is None:
# When a projection name is also a source name and it is referenced in the
# GROUP BY clause, BQ can't understand what the identifier corresponds to
ambiguous_projections = {
s.alias_or_name
for s in scope.expression.selects
if s.alias_or_name in scope.selected_sources
}
ambiguous = any(
column.parts[0].name in ambiguous_projections
for column in select.find_all(exp.Column)
)
else:
ambiguous = False
if (
isinstance(select, exp.CONSTANTS)
or select.find(exp.Explode, exp.Unnest)
or ambiguous
):
new_nodes.append(node)
else:
new_nodes.append(select.copy())

View file

@@ -1,10 +1,11 @@
from __future__ import annotations
import datetime
import logging
import functools
import itertools
import typing as t
from collections import deque
from collections import deque, defaultdict
from decimal import Decimal
from functools import reduce
@@ -20,6 +21,8 @@ if t.TYPE_CHECKING:
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]
logger = logging.getLogger("sqlglot")
# Final means that an expression should not be simplified
FINAL = "final"
@@ -35,7 +38,10 @@ class UnsupportedUnit(Exception):
def simplify(
expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
expression: exp.Expression,
constant_propagation: bool = False,
dialect: DialectType = None,
max_depth: t.Optional[int] = None,
):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -47,9 +53,9 @@ def simplify(
'TRUE'
Args:
expression (sqlglot.Expression): expression to simplify
expression: expression to simplify
constant_propagation: whether the constant propagation rule should be used
max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
Returns:
sqlglot.Expression: simplified expression
"""
@@ -57,6 +63,18 @@ def simplify(
dialect = Dialect.get_or_raise(dialect)
def _simplify(expression, root=True):
if (
max_depth
and isinstance(expression, exp.Connector)
and not isinstance(expression.parent, exp.Connector)
):
depth = connector_depth(expression)
if depth > max_depth:
logger.info(
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
)
return expression
if expression.meta.get(FINAL):
return expression
@@ -118,6 +136,33 @@ def simplify(
return expression
def connector_depth(expression: exp.Expression) -> int:
"""
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
"""
stack = deque([(expression, 0)])
max_depth = 0
while stack:
expression, depth = stack.pop()
if not isinstance(expression, exp.Connector):
continue
depth += 1
max_depth = max(depth, max_depth)
stack.append((expression.left, depth))
stack.append((expression.right, depth))
return max_depth
def catch(*exceptions):
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""
@@ -280,6 +325,7 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
}
NONDETERMINISTIC = (exp.Rand, exp.Randn)
AND_OR = (exp.And, exp.Or)
def _simplify_comparison(expression, left, right, or_=False):
@@ -351,12 +397,12 @@ def remove_complements(expression, root=True):
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
complement = exp.false() if isinstance(expression, exp.And) else exp.true()
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
ops = set(expression.flatten())
for op in ops:
if isinstance(op, exp.Not) and op.this in ops:
return exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
return complement
return expression
@@ -404,31 +450,63 @@ def absorb_and_eliminate(expression, root=True):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
if isinstance(a, kind):
aa, ab = a.unnest_operands()
ops = tuple(expression.flatten())
# absorb
if is_complement(b, aa):
aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
ba, bb = rhs
# Initialize lookup tables:
# Set of all operands, used to find complements for absorption.
op_set = set()
# Sub-operands, used to find subsets for absorption.
subops = defaultdict(list)
# Pairs of complements, used for elimination.
pairs = defaultdict(list)
if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
a.replace(aa)
b.replace(aa)
elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
a.replace(ab)
b.replace(ab)
# Populate the lookup tables
for op in ops:
op_set.add(op)
if not isinstance(op, kind):
# In cases like: A OR (A AND B)
# Subop will be: ^
subops[op].append({op})
continue
# In cases like: (A AND B) OR (A AND B AND C)
# Subops will be: ^ ^
subset = set(op.flatten())
for i in subset:
subops[i].append(subset)
a, b = op.unnest_operands()
if isinstance(a, exp.Not):
pairs[frozenset((a.this, b))].append((op, b))
if isinstance(b, exp.Not):
pairs[frozenset((a, b.this))].append((op, a))
for op in ops:
if not isinstance(op, kind):
continue
a, b = op.unnest_operands()
# Absorb
if isinstance(a, exp.Not) and a.this in op_set:
a.replace(exp.true() if kind == exp.And else exp.false())
continue
if isinstance(b, exp.Not) and b.this in op_set:
b.replace(exp.true() if kind == exp.And else exp.false())
continue
superset = set(op.flatten())
if any(any(subset < superset for subset in subops[i]) for i in superset):
op.replace(exp.false() if kind == exp.And else exp.true())
continue
# Eliminate
for other, complement in pairs[frozenset((a, b))]:
op.replace(complement)
other.replace(complement)
return expression