Merging upstream version 10.2.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
40155883c5
commit
17f6b2c749
36 changed files with 1281 additions and 493 deletions
|
@ -7,7 +7,7 @@ from decimal import Decimal
|
|||
from sqlglot import exp
|
||||
from sqlglot.expressions import FALSE, NULL, TRUE
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.helper import while_changing
|
||||
from sqlglot.helper import first, while_changing
|
||||
|
||||
GENERATOR = Generator(normalize=True, identify=True)
|
||||
|
||||
|
@ -30,6 +30,7 @@ def simplify(expression):
|
|||
|
||||
def _simplify(expression, root=True):
|
||||
node = expression
|
||||
node = rewrite_between(node)
|
||||
node = uniq_sort(node)
|
||||
node = absorb_and_eliminate(node)
|
||||
exp.replace_children(node, lambda e: _simplify(e, False))
|
||||
|
@ -49,6 +50,19 @@ def simplify(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def rewrite_between(expression: exp.Expression) -> exp.Expression:
|
||||
"""Rewrite x between y and z to x >= y AND x <= z.
|
||||
|
||||
This is done because comparison simplification is only done on lt/lte/gt/gte.
|
||||
"""
|
||||
if isinstance(expression, exp.Between):
|
||||
return exp.and_(
|
||||
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
|
||||
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
|
||||
)
|
||||
return expression
|
||||
|
||||
|
||||
def simplify_not(expression):
|
||||
"""
|
||||
Demorgan's Law
|
||||
|
@ -57,7 +71,7 @@ def simplify_not(expression):
|
|||
"""
|
||||
if isinstance(expression, exp.Not):
|
||||
if isinstance(expression.this, exp.Null):
|
||||
return NULL
|
||||
return exp.null()
|
||||
if isinstance(expression.this, exp.Paren):
|
||||
condition = expression.this.unnest()
|
||||
if isinstance(condition, exp.And):
|
||||
|
@ -65,11 +79,11 @@ def simplify_not(expression):
|
|||
if isinstance(condition, exp.Or):
|
||||
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
if isinstance(condition, exp.Null):
|
||||
return NULL
|
||||
return exp.null()
|
||||
if always_true(expression.this):
|
||||
return FALSE
|
||||
return exp.false()
|
||||
if expression.this == FALSE:
|
||||
return TRUE
|
||||
return exp.true()
|
||||
if isinstance(expression.this, exp.Not):
|
||||
# double negation
|
||||
# NOT NOT x -> x
|
||||
|
@ -91,40 +105,119 @@ def flatten(expression):
|
|||
|
||||
|
||||
def simplify_connectors(expression):
|
||||
if isinstance(expression, exp.Connector):
|
||||
left = expression.left
|
||||
right = expression.right
|
||||
|
||||
if left == right:
|
||||
return left
|
||||
|
||||
if isinstance(expression, exp.And):
|
||||
if FALSE in (left, right):
|
||||
return FALSE
|
||||
if NULL in (left, right):
|
||||
return NULL
|
||||
if always_true(left) and always_true(right):
|
||||
return TRUE
|
||||
if always_true(left):
|
||||
return right
|
||||
if always_true(right):
|
||||
def _simplify_connectors(expression, left, right):
|
||||
if isinstance(expression, exp.Connector):
|
||||
if left == right:
|
||||
return left
|
||||
elif isinstance(expression, exp.Or):
|
||||
if always_true(left) or always_true(right):
|
||||
return TRUE
|
||||
if left == FALSE and right == FALSE:
|
||||
return FALSE
|
||||
if (
|
||||
(left == NULL and right == NULL)
|
||||
or (left == NULL and right == FALSE)
|
||||
or (left == FALSE and right == NULL)
|
||||
):
|
||||
return NULL
|
||||
if left == FALSE:
|
||||
return right
|
||||
if right == FALSE:
|
||||
return left
|
||||
return expression
|
||||
if isinstance(expression, exp.And):
|
||||
if FALSE in (left, right):
|
||||
return exp.false()
|
||||
if NULL in (left, right):
|
||||
return exp.null()
|
||||
if always_true(left) and always_true(right):
|
||||
return exp.true()
|
||||
if always_true(left):
|
||||
return right
|
||||
if always_true(right):
|
||||
return left
|
||||
return _simplify_comparison(expression, left, right)
|
||||
elif isinstance(expression, exp.Or):
|
||||
if always_true(left) or always_true(right):
|
||||
return exp.true()
|
||||
if left == FALSE and right == FALSE:
|
||||
return exp.false()
|
||||
if (
|
||||
(left == NULL and right == NULL)
|
||||
or (left == NULL and right == FALSE)
|
||||
or (left == FALSE and right == NULL)
|
||||
):
|
||||
return exp.null()
|
||||
if left == FALSE:
|
||||
return right
|
||||
if right == FALSE:
|
||||
return left
|
||||
return _simplify_comparison(expression, left, right, or_=True)
|
||||
return None
|
||||
|
||||
return _flat_simplify(expression, _simplify_connectors)
|
||||
|
||||
|
||||
LT_LTE = (exp.LT, exp.LTE)
|
||||
GT_GTE = (exp.GT, exp.GTE)
|
||||
|
||||
COMPARISONS = (
|
||||
*LT_LTE,
|
||||
*GT_GTE,
|
||||
exp.EQ,
|
||||
exp.NEQ,
|
||||
)
|
||||
|
||||
INVERSE_COMPARISONS = {
|
||||
exp.LT: exp.GT,
|
||||
exp.GT: exp.LT,
|
||||
exp.LTE: exp.GTE,
|
||||
exp.GTE: exp.LTE,
|
||||
}
|
||||
|
||||
|
||||
def _simplify_comparison(expression, left, right, or_=False):
|
||||
if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
|
||||
ll, lr = left.args.values()
|
||||
rl, rr = right.args.values()
|
||||
|
||||
largs = {ll, lr}
|
||||
rargs = {rl, rr}
|
||||
|
||||
matching = largs & rargs
|
||||
columns = {m for m in matching if isinstance(m, exp.Column)}
|
||||
|
||||
if matching and columns:
|
||||
try:
|
||||
l = first(largs - columns)
|
||||
r = first(rargs - columns)
|
||||
except StopIteration:
|
||||
return expression
|
||||
|
||||
# make sure the comparison is always of the form x > 1 instead of 1 < x
|
||||
if left.__class__ in INVERSE_COMPARISONS and l == ll:
|
||||
left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
|
||||
if right.__class__ in INVERSE_COMPARISONS and r == rl:
|
||||
right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
|
||||
|
||||
if l.is_number and r.is_number:
|
||||
l = float(l.name)
|
||||
r = float(r.name)
|
||||
elif l.is_string and r.is_string:
|
||||
l = l.name
|
||||
r = r.name
|
||||
else:
|
||||
return None
|
||||
|
||||
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
|
||||
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
|
||||
return left if (av > bv if or_ else av <= bv) else right
|
||||
if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
|
||||
return left if (av < bv if or_ else av >= bv) else right
|
||||
|
||||
# we can't ever shortcut to true because the column could be null
|
||||
if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
|
||||
if not or_ and av <= bv:
|
||||
return exp.false()
|
||||
elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
|
||||
if not or_ and av >= bv:
|
||||
return exp.false()
|
||||
elif isinstance(a, exp.EQ):
|
||||
if isinstance(b, exp.LT):
|
||||
return exp.false() if av >= bv else a
|
||||
if isinstance(b, exp.LTE):
|
||||
return exp.false() if av > bv else a
|
||||
if isinstance(b, exp.GT):
|
||||
return exp.false() if av <= bv else a
|
||||
if isinstance(b, exp.GTE):
|
||||
return exp.false() if av < bv else a
|
||||
if isinstance(b, exp.NEQ):
|
||||
return exp.false() if av == bv else a
|
||||
return None
|
||||
|
||||
|
||||
def remove_compliments(expression):
|
||||
|
@ -135,7 +228,7 @@ def remove_compliments(expression):
|
|||
A OR NOT A -> TRUE
|
||||
"""
|
||||
if isinstance(expression, exp.Connector):
|
||||
compliment = FALSE if isinstance(expression, exp.And) else TRUE
|
||||
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
|
||||
|
||||
for a, b in itertools.permutations(expression.flatten(), 2):
|
||||
if is_complement(a, b):
|
||||
|
@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):
|
|||
|
||||
def simplify_literals(expression):
|
||||
if isinstance(expression, exp.Binary):
|
||||
operands = []
|
||||
queue = deque(expression.flatten(unnest=False))
|
||||
size = len(queue)
|
||||
|
||||
while queue:
|
||||
a = queue.popleft()
|
||||
|
||||
for b in queue:
|
||||
result = _simplify_binary(expression, a, b)
|
||||
|
||||
if result:
|
||||
queue.remove(b)
|
||||
queue.append(result)
|
||||
break
|
||||
else:
|
||||
operands.append(a)
|
||||
|
||||
if len(operands) < size:
|
||||
return functools.reduce(
|
||||
lambda a, b: expression.__class__(this=a, expression=b), operands
|
||||
)
|
||||
return _flat_simplify(expression, _simplify_binary)
|
||||
elif isinstance(expression, exp.Neg):
|
||||
this = expression.this
|
||||
if this.is_number:
|
||||
|
@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
|
|||
|
||||
if c == NULL:
|
||||
if isinstance(a, exp.Literal):
|
||||
return TRUE if not_ else FALSE
|
||||
return exp.true() if not_ else exp.false()
|
||||
if a == NULL:
|
||||
return FALSE if not_ else TRUE
|
||||
elif isinstance(expression, exp.NullSafeEQ):
|
||||
if a == b:
|
||||
return TRUE
|
||||
elif isinstance(expression, exp.NullSafeNEQ):
|
||||
if a == b:
|
||||
return FALSE
|
||||
return exp.false() if not_ else exp.true()
|
||||
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
|
||||
return None
|
||||
elif NULL in (a, b):
|
||||
return NULL
|
||||
|
||||
if isinstance(expression, exp.EQ) and a == b:
|
||||
return TRUE
|
||||
return exp.null()
|
||||
|
||||
if a.is_number and b.is_number:
|
||||
a = int(a.name) if a.is_int else Decimal(a.name)
|
||||
|
@ -388,4 +454,27 @@ def date_literal(date):
|
|||
|
||||
|
||||
def boolean_literal(condition):
|
||||
return TRUE if condition else FALSE
|
||||
return exp.true() if condition else exp.false()
|
||||
|
||||
|
||||
def _flat_simplify(expression, simplifier):
|
||||
operands = []
|
||||
queue = deque(expression.flatten(unnest=False))
|
||||
size = len(queue)
|
||||
|
||||
while queue:
|
||||
a = queue.popleft()
|
||||
|
||||
for b in queue:
|
||||
result = simplifier(expression, a, b)
|
||||
|
||||
if result:
|
||||
queue.remove(b)
|
||||
queue.append(result)
|
||||
break
|
||||
else:
|
||||
operands.append(a)
|
||||
|
||||
if len(operands) < size:
|
||||
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
|
||||
return expression
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue