1
0
Fork 0
sqlglot/sqlglot/optimizer/simplify.py
Daniel Baumann e47608846f
Adding upstream version 23.13.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 21:33:03 +01:00

1497 lines
46 KiB
Python

from __future__ import annotations
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal
from functools import reduce
import sqlglot
from sqlglot import Dialect, exp
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
DateTruncBinaryTransform = t.Callable[
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]
# Final means that an expression should not be simplified
FINAL = "final"
# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(Exception):
pass
def simplify(
expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
"""
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Args:
expression (sqlglot.Expression): expression to simplify
constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
"""
dialect = Dialect.get_or_raise(dialect)
def _simplify(expression, root=True):
if expression.meta.get(FINAL):
return expression
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
group = expression.args.get("group")
if group and hasattr(expression, "selects"):
groups = set(group.expressions)
group.meta[FINAL] = True
for e in expression.selects:
for node in e.walk():
if node in groups:
e.meta[FINAL] = True
break
having = expression.args.get("having")
if having:
for node in having.walk():
if node in groups:
having.meta[FINAL] = True
break
# Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
if constant_propagation:
node = propagate_constants(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
# Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
node = remove_complements(node, root)
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
node = simplify_datetrunc(node, dialect)
node = sort_comparison(node)
node = simplify_startswith(node)
if root:
expression.replace(node)
return node
expression = while_changing(expression, _simplify)
remove_where_true(expression)
return expression
def catch(*exceptions):
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""
def decorator(func):
def wrapped(expression, *args, **kwargs):
try:
return func(expression, *args, **kwargs)
except exceptions:
return expression
return wrapped
return decorator
def rewrite_between(expression: exp.Expression) -> exp.Expression:
"""Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
"""
if isinstance(expression, exp.Between):
negate = isinstance(expression.parent, exp.Not)
expression = exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
if negate:
expression = exp.paren(expression, copy=False)
return expression
COMPLEMENT_COMPARISONS = {
exp.LT: exp.GTE,
exp.GT: exp.LTE,
exp.LTE: exp.GT,
exp.GTE: exp.LT,
exp.EQ: exp.NEQ,
exp.NEQ: exp.EQ,
}
def simplify_not(expression):
"""
Demorgan's Law
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
this = expression.this
if is_null(this):
return exp.null()
if this.__class__ in COMPLEMENT_COMPARISONS:
return COMPLEMENT_COMPARISONS[this.__class__](
this=this.this, expression=this.expression
)
if isinstance(this, exp.Paren):
condition = this.unnest()
if isinstance(condition, exp.And):
return exp.paren(
exp.or_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
)
if isinstance(condition, exp.Or):
return exp.paren(
exp.and_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
)
if is_null(condition):
return exp.null()
if always_true(this):
return exp.false()
if is_false(this):
return exp.true()
if isinstance(this, exp.Not):
# double negation
# NOT NOT x -> x
return this.this
return expression
def flatten(expression):
"""
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
"""
if isinstance(expression, exp.Connector):
for node in expression.args.values():
child = node.unnest()
if isinstance(child, expression.__class__):
node.replace(child)
return expression
def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
if left == right:
if isinstance(expression, exp.Xor):
return exp.false()
return left
if isinstance(expression, exp.And):
if is_false(left) or is_false(right):
return exp.false()
if is_null(left) or is_null(right):
return exp.null()
if always_true(left) and always_true(right):
return exp.true()
if always_true(left):
return right
if always_true(right):
return left
return _simplify_comparison(expression, left, right)
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return exp.true()
if is_false(left) and is_false(right):
return exp.false()
if (
(is_null(left) and is_null(right))
or (is_null(left) and is_false(right))
or (is_false(left) and is_null(right))
):
return exp.null()
if is_false(left):
return right
if is_false(right):
return left
return _simplify_comparison(expression, left, right, or_=True)
if isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_connectors, root)
return expression
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)
COMPARISONS = (
*LT_LTE,
*GT_GTE,
exp.EQ,
exp.NEQ,
exp.Is,
)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.LT: exp.GT,
exp.GT: exp.LT,
exp.LTE: exp.GTE,
exp.GTE: exp.LTE,
}
NONDETERMINISTIC = (exp.Rand, exp.Randn)
def _simplify_comparison(expression, left, right, or_=False):
if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
ll, lr = left.args.values()
rl, rr = right.args.values()
largs = {ll, lr}
rargs = {rl, rr}
matching = largs & rargs
columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
if matching and columns:
try:
l = first(largs - columns)
r = first(rargs - columns)
except StopIteration:
return expression
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
elif l.is_string and r.is_string:
l = l.name
r = r.name
else:
l = extract_date(l)
if not l:
return None
r = extract_date(r)
if not r:
return None
# python won't compare date and datetime, but many engines will upcast
l, r = cast_as_datetime(l), cast_as_datetime(r)
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
return left if (av > bv if or_ else av <= bv) else right
if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
return left if (av < bv if or_ else av >= bv) else right
# we can't ever shortcut to true because the column could be null
if not or_:
if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
if av <= bv:
return exp.false()
elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
if av >= bv:
return exp.false()
elif isinstance(a, exp.EQ):
if isinstance(b, exp.LT):
return exp.false() if av >= bv else a
if isinstance(b, exp.LTE):
return exp.false() if av > bv else a
if isinstance(b, exp.GT):
return exp.false() if av <= bv else a
if isinstance(b, exp.GTE):
return exp.false() if av < bv else a
if isinstance(b, exp.NEQ):
return exp.false() if av == bv else a
return None
def remove_complements(expression, root=True):
"""
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
complement = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
return complement
return expression
def uniq_sort(expression, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
flattened = tuple(expression.flatten())
if isinstance(expression, exp.Xor):
result_func = exp.xor
# Do not deduplicate XOR as A XOR A != A if A == True
deduped = None
arr = tuple((gen(e), e) for e in flattened)
else:
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
expression = result_func(*(e for _, e in sorted(arr)), copy=False)
break
else:
# we didn't have to sort but maybe we need to dedup
if deduped and len(deduped) < len(flattened):
expression = result_func(*deduped.values(), copy=False)
return expression
def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
if isinstance(a, kind):
aa, ab = a.unnest_operands()
# absorb
if is_complement(b, aa):
aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
ba, bb = rhs
if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
a.replace(aa)
b.replace(aa)
elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
a.replace(ab)
b.replace(ab)
return expression
def propagate_constants(expression, root=True):
"""
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
"""
if (
isinstance(expression, exp.And)
and (root or not expression.same_parent)
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
):
constant_mapping = {}
for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
if isinstance(expr, exp.EQ):
l, r = expr.left, expr.right
# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
constant_mapping[l] = (id(l), r)
if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
parent = column.parent
column_id, constant = constant_mapping.get(column) or (None, None)
if (
column_id is not None
and id(column) != column_id
and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
):
column.replace(constant.copy())
return expression
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
exp.DatetimeAdd: exp.Sub,
exp.DatetimeSub: exp.Add,
}
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
**INVERSE_DATE_OPS,
exp.Add: exp.Sub,
exp.Sub: exp.Add,
}
def _is_number(expression: exp.Expression) -> bool:
return expression.is_number
def _is_interval(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
"""
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =
Here's how we reference all the operands in the code below:
l r
x + 1 = 3
a b
"""
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
if l.__class__ not in INVERSE_OPS:
return expression
if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
elif _is_date_literal(r):
a_predicate = _is_date_literal
b_predicate = _is_interval
else:
return expression
if l.__class__ in INVERSE_DATE_OPS:
l = t.cast(exp.IntervalOp, l)
a = l.this
b = l.interval()
else:
l = t.cast(exp.Binary, l)
a, b = l.left, l.right
if not a_predicate(a) and b_predicate(b):
pass
elif not a_predicate(b) and b_predicate(a):
a, b = b, a
else:
return expression
return expression.__class__(
this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
)
return expression
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
if isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
if value[0] == "-":
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
if type(expression) in INVERSE_DATE_OPS:
return _simplify_binary(expression, expression.this, expression.interval()) or expression
return expression
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
this = _simplify_integer_cast(expr.this)
else:
this = expr.this
if isinstance(expr, exp.Cast) and this.is_int:
num = int(this.name)
# Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
# engine-dependent
if (
TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
) or (
UTINYINT_MIN <= num <= UTINYINT_MAX
and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
):
return this
return expr
def _simplify_binary(expression, a, b):
if isinstance(expression, COMPARISONS):
a = _simplify_integer_cast(a)
b = _simplify_integer_cast(b)
if isinstance(expression, exp.Is):
if isinstance(b, exp.Not):
c = b.this
not_ = True
else:
c = b
not_ = False
if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, NULL_OK):
return None
elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
num_a = int(a.name) if a.is_int else Decimal(a.name)
num_b = int(b.name) if b.is_int else Decimal(b.name)
if isinstance(expression, exp.Add):
return exp.Literal.number(num_a + num_b)
if isinstance(expression, exp.Mul):
return exp.Literal.number(num_a * num_b)
# We only simplify Sub, Div if a and b have the same parent because they're not associative
if isinstance(expression, exp.Sub):
return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
if isinstance(expression, exp.Div):
# engines have differing int div behavior so intdiv is not safe
if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
return None
return exp.Literal.number(num_a / num_b)
boolean = eval_boolean(expression, num_a, num_b)
if boolean:
return boolean
elif a.is_string and b.is_string:
boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
elif _is_date_literal(a) and isinstance(b, exp.Interval):
date, b = extract_date(a), extract_interval(b)
if date and b:
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(date + b, extract_type(a))
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(date - b, extract_type(a))
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, date = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + date, extract_type(b))
elif _is_date_literal(a) and _is_date_literal(b):
if isinstance(expression, exp.Predicate):
a, b = extract_date(a), extract_date(b)
boolean = eval_boolean(expression, a, b)
if boolean:
return boolean
return None
def simplify_parens(expression):
if not isinstance(expression, exp.Paren):
return expression
this = expression.this
parent = expression.parent
parent_is_predicate = isinstance(parent, exp.Predicate)
if (
not isinstance(this, exp.Select)
and not isinstance(parent, exp.SubqueryPredicate)
and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
)
):
return this
return expression
def _is_nonnull_constant(expression: exp.Expression) -> bool:
return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
def _is_constant(expression: exp.Expression) -> bool:
return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
def simplify_coalesce(expression):
# COALESCE(x) -> x
if (
isinstance(expression, exp.Coalesce)
and (not expression.expressions or _is_nonnull_constant(expression.this))
# COALESCE is also used as a Spark partitioning hint
and not isinstance(expression.parent, exp.Hint)
):
return expression.this
if not isinstance(expression, COMPARISONS):
return expression
if isinstance(expression.left, exp.Coalesce):
coalesce = expression.left
other = expression.right
elif isinstance(expression.right, exp.Coalesce):
coalesce = expression.right
other = expression.left
else:
return expression
# This transformation is valid for non-constants,
# but it really only does anything if they are both constants.
if not _is_constant(other):
return expression
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
if _is_constant(arg):
break
else:
return expression
coalesce.set("expressions", coalesce.expressions[:arg_index])
# Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
# since we already remove COALESCE at the top of this function.
coalesce = coalesce if coalesce.expressions else coalesce.this
# This expression is more complex than when we started, but it will get simplified further
return exp.paren(
exp.or_(
exp.and_(
coalesce.is_(exp.null()).not_(copy=False),
expression.copy(),
copy=False,
),
exp.and_(
coalesce.is_(exp.null()),
type(expression)(this=arg.copy(), expression=other.copy()),
copy=False,
),
copy=False,
)
)
CONCATS = (exp.Concat, exp.DPipe)
def simplify_concat(expression):
"""Reduces all groups that contain string literals by concatenating them."""
if not isinstance(expression, CONCATS) or (
# We can't reduce a CONCAT_WS call if we don't statically know the separator
isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
):
return expression
if isinstance(expression, exp.ConcatWs):
sep_expr, *expressions = expression.expressions
sep = sep_expr.name
concat_type = exp.ConcatWs
args = {}
else:
expressions = expression.expressions
sep = ""
concat_type = exp.Concat
args = {
"safe": expression.args.get("safe"),
"coalesce": expression.args.get("coalesce"),
}
new_args = []
for is_string_group, group in itertools.groupby(
expressions or expression.flatten(), lambda e: e.is_string
):
if is_string_group:
new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
else:
new_args.extend(group)
if len(new_args) == 1 and new_args[0].is_string:
return new_args[0]
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args
elif isinstance(expression, exp.DPipe):
return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
return concat_type(expressions=new_args, **args)
def simplify_conditionals(expression):
"""Simplifies expressions like IF, CASE if their condition is statically known."""
if isinstance(expression, exp.Case):
this = expression.this
for case in expression.args["ifs"]:
cond = case.this
if this:
# Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
cond = cond.replace(this.pop().eq(cond))
if always_true(cond):
return case.args["true"]
if always_false(cond):
case.pop()
if not expression.args["ifs"]:
return expression.args.get("default") or exp.null()
elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
if always_true(expression.this):
return expression.args["true"]
if always_false(expression.this):
return expression.args.get("false") or exp.null()
return expression
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
"""
Reduces a prefix check to either TRUE or FALSE if both the string and the
prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
"""
if (
isinstance(expression, exp.StartsWith)
and expression.this.is_string
and expression.expression.is_string
):
return exp.convert(expression.name.startswith(expression.expression.name))
return expression
DateRange = t.Tuple[datetime.date, datetime.date]
def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
"""
Get the date range for a DATE_TRUNC equality comparison:
Example:
_datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
Returns:
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
"""
floor = date_floor(date, unit, dialect)
if date != floor:
# This will always be False, except for NULL values.
return None
return floor, floor + interval(unit)
def _datetrunc_eq_expression(
left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
) -> exp.Expression:
"""Get the logical expression for a date range"""
return exp.and_(
left >= date_literal(drange[0], target_type),
left < date_literal(drange[1], target_type),
copy=False,
)
def _datetrunc_eq(
left: exp.Expression,
date: datetime.date,
unit: str,
dialect: Dialect,
target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
return _datetrunc_eq_expression(left, drange, target_type)
def _datetrunc_neq(
left: exp.Expression,
date: datetime.date,
unit: str,
dialect: Dialect,
target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
return exp.and_(
left < date_literal(drange[0], target_type),
left >= date_literal(drange[1], target_type),
copy=False,
)
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
exp.LT: lambda l, dt, u, d, t: l
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
return isinstance(left, DATETRUNCS) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
comparison = expression.__class__
if isinstance(expression, DATETRUNCS):
this = expression.this
trunc_type = extract_type(this)
date = extract_date(this)
if date and expression.unit:
return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
elif comparison not in DATETRUNC_COMPARISONS:
return expression
if isinstance(expression, exp.Binary):
l, r = expression.left, expression.right
if not _is_datetrunc_predicate(l, r):
return expression
l = t.cast(exp.DateTrunc, l)
trunc_arg = l.this
unit = l.unit.name.lower()
date = extract_date(r)
if not date:
return expression
return (
DATETRUNC_BINARY_COMPARISONS[comparison](
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
)
or expression
)
if isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
ranges = []
for r in rs:
date = extract_date(r)
if not date:
return expression
drange = _datetrunc_range(date, unit, dialect)
if drange:
ranges.append(drange)
if not ranges:
return expression
ranges = merge_ranges(ranges)
target_type = extract_type(l, *rs)
return exp.or_(
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
)
return expression
def sort_comparison(expression: exp.Expression) -> exp.Expression:
if expression.__class__ in COMPLEMENT_COMPARISONS:
l, r = expression.this, expression.expression
l_column = isinstance(l, exp.Column)
r_column = isinstance(r, exp.Column)
l_const = _is_constant(l)
r_const = _is_constant(r)
if (l_column and not r_column) or (r_const and not l_const):
return expression
if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
this=r, expression=l
)
return expression
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
("", ""),
("", "INNER"),
("RIGHT", ""),
("RIGHT", "OUTER"),
}
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
where.pop()
for join in expression.find_all(exp.Join):
if (
always_true(join.args.get("on"))
and not join.args.get("using")
and not join.args.get("method")
and (join.side, join.kind) in JOINS
):
join.args["on"].pop()
join.set("side", None)
join.set("kind", "CROSS")
def always_true(expression):
return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
expression, exp.Literal
)
def always_false(expression):
return is_false(expression) or is_null(expression)
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
def is_false(a: exp.Expression) -> bool:
return type(a) is exp.Boolean and not a.this
def is_null(a: exp.Expression) -> bool:
return type(a) is exp.Null
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
if isinstance(expression, exp.NEQ):
return boolean_literal(a != b)
if isinstance(expression, exp.GT):
return boolean_literal(a > b)
if isinstance(expression, exp.GTE):
return boolean_literal(a >= b)
if isinstance(expression, exp.LT):
return boolean_literal(a < b)
if isinstance(expression, exp.LTE):
return boolean_literal(a <= b)
return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
if isinstance(value, datetime.datetime):
return value.date()
if isinstance(value, datetime.date):
return value
try:
return datetime.datetime.fromisoformat(value).date()
except ValueError:
return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
return datetime.datetime(year=value.year, month=value.month, day=value.day)
try:
return datetime.datetime.fromisoformat(value)
except ValueError:
return None
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
if not value:
return None
if to.is_type(exp.DataType.Type.DATE):
return cast_as_date(value)
if to.is_type(*exp.DataType.TEMPORAL_TYPES):
return cast_as_datetime(value)
return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
if isinstance(cast, exp.Cast):
to = cast.to
elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
to = exp.DataType.build(exp.DataType.Type.DATE)
else:
return None
if isinstance(cast.this, exp.Literal):
value: t.Any = cast.this.name
elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
value = extract_date(cast.this)
else:
return None
return cast_value(value, to)
def _is_date_literal(expression: exp.Expression) -> bool:
return extract_date(expression) is not None
def extract_interval(expression):
try:
n = int(expression.name)
unit = expression.text("unit").lower()
return interval(unit, n)
except (UnsupportedUnit, ModuleNotFoundError, ValueError):
return None
def extract_type(*expressions):
target_type = None
for expression in expressions:
target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
if target_type:
break
return target_type
def date_literal(date, target_type=None):
if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
target_type = (
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE
)
return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
from dateutil.relativedelta import relativedelta
if unit == "year":
return relativedelta(years=1 * n)
if unit == "quarter":
return relativedelta(months=3 * n)
if unit == "month":
return relativedelta(months=1 * n)
if unit == "week":
return relativedelta(weeks=1 * n)
if unit == "day":
return relativedelta(days=1 * n)
if unit == "hour":
return relativedelta(hours=1 * n)
if unit == "minute":
return relativedelta(minutes=1 * n)
if unit == "second":
return relativedelta(seconds=1 * n)
raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
if unit == "year":
return d.replace(month=1, day=1)
if unit == "quarter":
if d.month <= 3:
return d.replace(month=1, day=1)
elif d.month <= 6:
return d.replace(month=4, day=1)
elif d.month <= 9:
return d.replace(month=7, day=1)
else:
return d.replace(month=10, day=1)
if unit == "month":
return d.replace(month=d.month, day=1)
if unit == "week":
# Assuming week starts on Monday (0) and ends on Sunday (6)
return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
if unit == "day":
return d
raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
floor = date_floor(d, unit, dialect)
if floor == d:
return d
return floor + interval(unit)
def boolean_literal(condition):
return exp.true() if condition else exp.false()
def _flat_simplify(expression, simplifier, root=True):
if root or not expression.same_parent:
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
while queue:
a = queue.popleft()
for b in queue:
result = simplifier(expression, a, b)
if result and result is not expression:
queue.remove(b)
queue.appendleft(result)
break
else:
operands.append(a)
if len(operands) < size:
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression
def gen(expression: t.Any) -> str:
"""Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual
generator is expensive so we have a bare minimum sql generator here.
"""
return Gen().gen(expression)
class Gen:
def __init__(self):
self.stack = []
self.sqls = []
def gen(self, expression: exp.Expression) -> str:
self.stack = [expression]
self.sqls.clear()
while self.stack:
node = self.stack.pop()
if isinstance(node, exp.Expression):
exp_handler_name = f"{node.key}_sql"
if hasattr(self, exp_handler_name):
getattr(self, exp_handler_name)(node)
elif isinstance(node, exp.Func):
self._function(node)
else:
key = node.key.upper()
self.stack.append(f"{key} " if self._args(node) else key)
elif type(node) is list:
for n in reversed(node):
if n is not None:
self.stack.extend((n, ","))
if node:
self.stack.pop()
else:
if node is not None:
self.sqls.append(str(node))
return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
self._binary(e, " + ")
def alias_sql(self, e: exp.Alias) -> None:
self.stack.extend(
(
e.args.get("alias"),
" AS ",
e.args.get("this"),
)
)
def and_sql(self, e: exp.And) -> None:
self._binary(e, " AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
this = e.this
if isinstance(this, str):
name = this.upper()
elif isinstance(this, exp.Identifier):
name = this.this
name = f'"{name}"' if this.quoted else name.upper()
else:
raise ValueError(
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
)
self.stack.extend(
(
")",
e.expressions,
"(",
name,
)
)
def between_sql(self, e: exp.Between) -> None:
self.stack.extend(
(
e.args.get("high"),
" AND ",
e.args.get("low"),
" BETWEEN ",
e.this,
)
)
def boolean_sql(self, e: exp.Boolean) -> None:
self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: exp.Bracket) -> None:
self.stack.extend(
(
"]",
e.expressions,
"[",
e.this,
)
)
def column_sql(self, e: exp.Column) -> None:
for p in reversed(e.parts):
self.stack.extend((p, "."))
self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
self._args(e, 1)
self.stack.append(f"{e.this.name} ")
def div_sql(self, e: exp.Div) -> None:
self._binary(e, " / ")
def dot_sql(self, e: exp.Dot) -> None:
self._binary(e, ".")
def eq_sql(self, e: exp.EQ) -> None:
self._binary(e, " = ")
def from_sql(self, e: exp.From) -> None:
self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: exp.GT) -> None:
self._binary(e, " > ")
def gte_sql(self, e: exp.GTE) -> None:
self._binary(e, " >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: exp.ILike) -> None:
self._binary(e, " ILIKE ")
def in_sql(self, e: exp.In) -> None:
self.stack.append(")")
self._args(e, 1)
self.stack.extend(
(
"(",
" IN ",
e.this,
)
)
def intdiv_sql(self, e: exp.IntDiv) -> None:
self._binary(e, " DIV ")
def is_sql(self, e: exp.Is) -> None:
self._binary(e, " IS ")
def like_sql(self, e: exp.Like) -> None:
self._binary(e, " Like ")
def literal_sql(self, e: exp.Literal) -> None:
self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: exp.LT) -> None:
self._binary(e, " < ")
def lte_sql(self, e: exp.LTE) -> None:
self._binary(e, " <= ")
def mod_sql(self, e: exp.Mod) -> None:
self._binary(e, " % ")
def mul_sql(self, e: exp.Mul) -> None:
self._binary(e, " * ")
def neg_sql(self, e: exp.Neg) -> None:
self._unary(e, "-")
def neq_sql(self, e: exp.NEQ) -> None:
self._binary(e, " <> ")
def not_sql(self, e: exp.Not) -> None:
self._unary(e, "NOT ")
def null_sql(self, e: exp.Null) -> None:
self.stack.append("NULL")
def or_sql(self, e: exp.Or) -> None:
self._binary(e, " OR ")
def paren_sql(self, e: exp.Paren) -> None:
self.stack.extend(
(
")",
e.this,
"(",
)
)
def sub_sql(self, e: exp.Sub) -> None:
self._binary(e, " - ")
def subquery_sql(self, e: exp.Subquery) -> None:
self._args(e, 2)
alias = e.args.get("alias")
if alias:
self.stack.append(alias)
self.stack.extend((")", e.this, "("))
def table_sql(self, e: exp.Table) -> None:
self._args(e, 4)
alias = e.args.get("alias")
if alias:
self.stack.append(alias)
for p in reversed(e.parts):
self.stack.extend((p, "."))
self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
columns = e.columns
if columns:
self.stack.extend((")", columns, "("))
self.stack.extend((e.this, " AS "))
def var_sql(self, e: exp.Var) -> None:
self.stack.append(e.this)
def _binary(self, e: exp.Binary, op: str) -> None:
self.stack.extend((e.expression, op, e.this))
def _unary(self, e: exp.Unary, op: str) -> None:
self.stack.extend((e.this, op))
def _function(self, e: exp.Func) -> None:
self.stack.extend(
(
")",
list(e.args.values()),
"(",
e.sql_name(),
)
)
def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
kvs = []
arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
for k in arg_types or arg_types:
v = node.args.get(k)
if v is not None:
kvs.append([f":{k}", v])
if kvs:
self.stack.append(kvs)
return True
return False