Adding upstream version 6.0.4.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
d01130b3f1
commit
527597d2af
122 changed files with 23162 additions and 0 deletions
383
sqlglot/optimizer/simplify.py
Normal file
383
sqlglot/optimizer/simplify.py
Normal file
|
@ -0,0 +1,383 @@
|
|||
import datetime
|
||||
import functools
|
||||
import itertools
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.expressions import FALSE, NULL, TRUE
|
||||
from sqlglot.generator import Generator
|
||||
from sqlglot.helper import while_changing
|
||||
|
||||
# Shared SQL generator instance. uniq_sort() uses the text it generates as the
# key for de-duplicating and ordering the operands of a connector.
GENERATOR = Generator(normalize=True, identify=True)
|
||||
|
||||
|
||||
def simplify(expression):
    """
    Simplify a sqlglot AST using the rewrite rules defined in this module.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify

    Returns:
        sqlglot.Expression: simplified expression
    """
    # The rules run in exactly this order; the three groups only mark the
    # points where a non-rule step has to happen in between.
    before_children = (uniq_sort, absorb_and_eliminate)
    after_children = (simplify_not, flatten, simplify_connectors, remove_compliments)
    # simplify_parens inspects `parent`, so these run after re-parenting.
    after_reparent = (simplify_literals, simplify_parens)

    def _simplify(expression, root=True):
        current = expression

        for rule in before_children:
            current = rule(current)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(current, lambda child: _simplify(child, False))

        for rule in after_children:
            current = rule(current)

        # A rule may have produced a brand new node: attach it where the
        # original node was before running the parent-sensitive rules.
        current.parent = expression.parent

        for rule in after_reparent:
            current = rule(current)

        if root:
            expression.replace(current)
        return current

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
|
||||
|
||||
|
||||
def simplify_not(expression):
    """
    Simplify a NOT expression; any other expression is returned unchanged.

    De Morgan's laws (applied when the operand is parenthesized):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Constants and double negation:
        NOT <always true> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, exp.And):
            return exp.or_(exp.not_(inner.left), exp.not_(inner.right))
        if isinstance(inner, exp.Or):
            return exp.and_(exp.not_(inner.left), exp.not_(inner.right))

    if always_true(operand):
        return FALSE
    if operand == FALSE:
        return TRUE
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
|
||||
|
||||
|
||||
def flatten(expression):
    """
    Remove wrappers around operands that are the same connector type.

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C

    The connector is modified in place and returned; non-connectors are
    returned untouched.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only unwrap when the nested node is the same kind of connector.
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
|
||||
|
||||
|
||||
def simplify_connectors(expression):
    """
    Fold an AND/OR connector whose operands are equal or constant.

    Rules, applied in this order (``<true>`` means ``always_true()`` holds):
        x AND x, x OR x      -> x
        FALSE AND x          -> FALSE (including x = NULL)
        NULL AND x           -> NULL
        <true> AND <true>    -> TRUE
        <true> AND x         -> x
        <true> OR x          -> TRUE
        FALSE OR FALSE       -> FALSE
        NULL OR NULL / FALSE -> NULL
        FALSE OR x           -> x

    Args:
        expression: any expression node; only connectors are rewritten.

    Returns:
        The simplified node, or `expression` itself when no rule applies.
    """
    if isinstance(expression, exp.Connector):
        left = expression.left
        right = expression.right

        if left == right:
            return left

        if isinstance(expression, exp.And):
            # FALSE must be checked before NULL: under three-valued logic
            # FALSE AND NULL is FALSE, not NULL.
            if FALSE in (left, right):
                return FALSE
            if NULL in (left, right):
                return NULL
            if always_true(left) and always_true(right):
                return TRUE
            if always_true(left):
                return right
            if always_true(right):
                return left
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return TRUE
            if left == FALSE and right == FALSE:
                return FALSE
            if (
                (left == NULL and right == NULL)
                or (left == NULL and right == FALSE)
                or (left == FALSE and right == NULL)
            ):
                return NULL
            if left == FALSE:
                return right
            if right == FALSE:
                return left
    return expression
|
||||
|
||||
|
||||
def remove_compliments(expression):
    """
    Collapse a connector that contains an operand together with its complement.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE

    Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return FALSE if isinstance(expression, exp.And) else TRUE
    return expression
|
||||
|
||||
|
||||
def uniq_sort(expression):
    """
    De-duplicate and sort the operands of a connector.

        C AND A AND B AND B -> A AND B AND C

    Operands are compared through the SQL text produced by GENERATOR. The
    connector is rebuilt only when its operands are out of order or contain
    duplicates; otherwise the input is returned as is.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Duplicate operands collapse onto a single SQL key.
    by_sql = {GENERATOR.generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        return combine(*(by_sql[key] for key in sorted(by_sql)))

    # Already sorted, but duplicates may still have to be dropped.
    if len(by_sql) < len(operands):
        return combine(*by_sql.values())
    return expression
|
||||
|
||||
|
||||
def absorb_and_eliminate(expression):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Sub-expressions are replaced with boolean constants or with the shared
    operand; the (mutated) input expression is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    # Nested connectors of interest are of the opposite kind to the outer one.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    nested_is_and = kind == exp.And
    # Constant that replaces a single operand of the nested connector.
    operand_constant = exp.TRUE if nested_is_and else exp.FALSE
    # Constant that replaces the whole nested connector.
    nested_constant = exp.FALSE if nested_is_and else exp.TRUE

    for nested, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(nested, kind):
            continue

        first, second = nested.unnest_operands()

        # absorb
        if is_complement(other, first):
            first.replace(operand_constant)
        elif is_complement(other, second):
            second.replace(operand_constant)
        elif (set(other.flatten()) if isinstance(other, kind) else {other}) < set(
            nested.flatten()
        ):
            nested.replace(nested_constant)
        elif isinstance(other, kind):
            # eliminate
            other_operands = other.unnest_operands()
            other_first, other_second = other_operands

            if first in other_operands and (
                is_complement(second, other_first) or is_complement(second, other_second)
            ):
                nested.replace(first)
                other.replace(first)
            elif second in other_operands and (
                is_complement(first, other_first) or is_complement(first, other_second)
            ):
                nested.replace(second)
                other.replace(second)

    return expression
|
||||
|
||||
|
||||
def simplify_literals(expression):
    """
    Fold literal operands of a binary expression and negated number literals.

    For a binary expression, operands are combined pairwise with
    _simplify_binary until no pair can be folded; the expression is rebuilt
    only if at least one fold happened. For a negation of a number literal,
    the sign is folded into the literal. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary):
        kept = []
        pending = deque(expression.flatten(unnest=False))
        original_count = len(pending)

        while pending:
            head = pending.popleft()

            for candidate in pending:
                folded = _simplify_binary(expression, head, candidate)

                if folded:
                    # Replace the pair with its folded result and retry later.
                    pending.remove(candidate)
                    pending.append(folded)
                    break
            else:
                # Nothing combines with `head`, so it stays as an operand.
                kept.append(head)

        if len(kept) < original_count:
            node_type = expression.__class__
            return functools.reduce(
                lambda lhs, rhs: node_type(this=lhs, expression=rhs), kept
            )
        return expression

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            if digits[0] == "-":
                return exp.Literal.number(digits[1:])
            return exp.Literal.number(f"-{digits}")

    return expression
|
||||
|
||||
|
||||
def _simplify_binary(expression, a, b):
    """
    Try to fold the operand pair (a, b) of a binary expression into one node.

    Handles IS [NOT] NULL checks, NULL propagation, equality of identical
    operands, arithmetic and comparisons on number literals, comparisons on
    string literals, and DATE cast +/- interval arithmetic.

    Args:
        expression: the binary expression the operands belong to; only its
            type is inspected.
        a: left operand node.
        b: right operand node.

    Returns:
        The folded node, or None when the pair cannot be simplified
        (including division by zero and dates/intervals that cannot be
        extracted).
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if c == NULL:
            if isinstance(a, exp.Literal):
                return TRUE if not_ else FALSE
            if a == NULL:
                return FALSE if not_ else TRUE
    elif NULL in (a, b):
        return NULL

    if isinstance(expression, exp.EQ) and a == b:
        return TRUE

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            if b == 0:
                # Leave division by zero in the query instead of raising
                # inside the optimizer; the engine decides what it means.
                return None
            # NOTE(review): integer literals use floor division; confirm this
            # matches the target engine for negative operands.
            if isinstance(a, int) and isinstance(b, int):
                return exp.Literal.number(a // b)
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        # Compare the literal text rather than the nodes themselves:
        # ordering operators (<, >, ...) need plain Python values.
        boolean = eval_boolean(expression, a.name, b.name)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        # Either extraction may fail (returns None); fold only when both work.
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
|
||||
|
||||
|
||||
def simplify_parens(expression):
    """
    Strip a parenthesis node when one of the conditions below holds.

    A Paren is replaced by its content unless it wraps a Select. Otherwise it
    is unwrapped when its parent is neither a Condition nor a Binary, or when
    its content is an Is/Like, or when its content is not a Binary at all.
    Everything else is returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    if (
        not isinstance(expression.parent, (exp.Condition, exp.Binary))
        or isinstance(inner, (exp.Is, exp.Like))
        or not isinstance(inner, exp.Binary)
    ):
        return inner
    return expression
|
||||
|
||||
|
||||
def remove_where_true(expression):
    """
    Drop filters that are always true, mutating `expression` in place.

    A WHERE whose condition is always true is unset on its parent, and a JOIN
    whose ON condition is always true becomes a CROSS join without a
    condition.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_clause in expression.find_all(exp.Join):
        condition = join_clause.args.get("on")
        if always_true(condition):
            join_clause.set("kind", "CROSS")
            join_clause.set("on", None)
|
||||
|
||||
|
||||
def always_true(expression):
    """Return whether `expression` is the TRUE constant or any literal node."""
    if expression == TRUE:
        return True
    return isinstance(expression, exp.Literal)
|
||||
|
||||
|
||||
def is_complement(a, b):
    """Return whether `b` is the negation of `a`, i.e. `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
|
||||
|
||||
|
||||
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between two Python values.

    The type of `expression` selects the operator (=, IS, <>, >, >=, <, <=).
    Returns the TRUE/FALSE expression for the outcome, or None when
    `expression` is not one of the supported comparisons.
    """
    # Checked in order; the comparison is deferred until a type matches.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
|
||||
|
||||
|
||||
def extract_date(cast):
    """
    Extract a `datetime.date` from a cast to DATE.

    Args:
        cast: a cast expression; its target type and `name` are inspected.

    Returns:
        The parsed date, or None when the cast does not target DATE or its
        operand is not an ISO formatted date (for example when something
        other than a date literal is being cast).
    """
    if cast.args["to"].this != exp.DataType.Type.DATE:
        return None

    try:
        return datetime.date.fromisoformat(cast.name)
    except ValueError:
        # Not a date we can parse; report "cannot extract" instead of
        # crashing the caller.
        return None
|
||||
|
||||
|
||||
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Args:
        interval: interval expression; `name` holds the amount and the
            "unit" text holds the unit.

    Returns:
        A relativedelta for year/month/week/day units, or None when dateutil
        is not installed, the amount is not an integer, or the unit is not
        supported.
    """
    try:
        from dateutil.relativedelta import relativedelta
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-integer amounts (e.g. '1.5' or a non-literal) cannot be folded.
        return None

    unit = interval.text("unit").lower()

    # Maps the SQL unit to the matching relativedelta keyword argument.
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(unit)

    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
|
||||
|
||||
|
||||
def date_literal(date):
    """Wrap `date` in a string literal that is cast to the DATE data type."""
    date_type = exp.DataType.build("DATE")
    return exp.Cast(this=exp.Literal.string(date), to=date_type)
|
||||
|
||||
|
||||
def boolean_literal(condition):
    """Map a Python truth value onto the TRUE / FALSE expression constants."""
    if condition:
        return TRUE
    return FALSE
|
Loading…
Add table
Add a link
Reference in a new issue