1
0
Fork 0

Merging upstream version 18.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:03:38 +01:00
parent 77523b6777
commit d1b976f442
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
96 changed files with 59037 additions and 52828 deletions

View file

@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.Timestamp,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
@ -177,6 +178,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Initcap,
exp.Lower,
exp.SafeConcat,
exp.SafeDPipe,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
@ -242,6 +244,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
self._visited: t.Set[int] = set()
def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
expression.type = target_type
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
for scope in traverse_scope(expression):
selects = {}
@ -279,9 +288,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
self._set_type(col, self.schema.get_column_type(source, col))
elif source and col.table in selects and col.name in selects[col.table]:
col.type = selects[col.table][col.name].type
self._set_type(col, selects[col.table][col.name].type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@ -289,7 +298,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression: E) -> E:
if expression.type:
if id(expression) in self._visited:
return expression # We've already inferred the expression's type
annotator = self.annotators.get(expression.__class__)
@ -338,17 +347,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
self._set_type(expression, exp.DataType.Type.NULL)
elif exp.DataType.Type.NULL in (left_type, right_type):
expression.type = exp.DataType.build(
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
self._set_type(
expression,
exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
)
else:
expression.type = exp.DataType.Type.BOOLEAN
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif isinstance(expression, exp.Predicate):
expression.type = exp.DataType.Type.BOOLEAN
self._set_type(expression, exp.DataType.Type.BOOLEAN)
else:
expression.type = self._maybe_coerce(left_type, right_type)
self._set_type(expression, self._maybe_coerce(left_type, right_type))
return expression
@ -357,26 +367,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
expression.type = exp.DataType.Type.BOOLEAN
self._set_type(expression, exp.DataType.Type.BOOLEAN)
else:
expression.type = expression.this.type
self._set_type(expression, expression.this.type)
return expression
@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
self._set_type(expression, exp.DataType.Type.VARCHAR)
elif expression.is_int:
expression.type = exp.DataType.Type.INT
self._set_type(expression, exp.DataType.Type.INT)
else:
expression.type = exp.DataType.Type.DOUBLE
self._set_type(expression, exp.DataType.Type.DOUBLE)
return expression
@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
expression.type = target_type
self._set_type(expression, target_type)
return self._annotate_args(expression)
@t.no_type_check
@ -394,17 +404,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for expr in expressions:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
if promote:
if expression.type.this in exp.DataType.INTEGER_TYPES:
expression.type = exp.DataType.Type.BIGINT
self._set_type(expression, exp.DataType.Type.BIGINT)
elif expression.type.this in exp.DataType.FLOAT_TYPES:
expression.type = exp.DataType.Type.DOUBLE
self._set_type(expression, exp.DataType.Type.DOUBLE)
if array:
expression.type = exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
),
)
return expression

View file

@ -17,9 +17,11 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
expression = remove_ascending_order(expression)
return expression
@ -30,6 +32,14 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
return node
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
return node
def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
@ -63,6 +73,14 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
return expression
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
# Convert ORDER BY a ASC to ORDER BY a
expression.set("desc", None)
return expression
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if (
@ -75,10 +93,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
def _replace_cast(node: exp.Expression, to: str) -> None:
data_type = exp.DataType.build(to)
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
node.replace(exp.cast(node.copy(), to=to))
def _replace_int_predicate(expression: exp.Expression) -> None:

View file

@ -1,17 +1,22 @@
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
from sqlglot.helper import first, merge_ranges, while_changing
# Final means that an expression should not be simplified
FINAL = "final"
class UnsupportedUnit(Exception):
pass
def simplify(expression):
"""
Rewrite sqlglot AST to simplify expressions.
@ -72,7 +77,9 @@ def simplify(expression):
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
node = simplify_datetrunc_predicate(node)
if root:
expression.replace(node)
@ -84,6 +91,21 @@ def simplify(expression):
return expression
def catch(*exceptions):
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""
def decorator(func):
def wrapped(expression, *args, **kwargs):
try:
return func(expression, *args, **kwargs)
except exceptions:
return expression
return wrapped
return decorator
def rewrite_between(expression: exp.Expression) -> exp.Expression:
"""Rewrite x between y and z to x >= y AND x <= z.
@ -196,7 +218,7 @@ COMPARISONS = (
exp.Is,
)
INVERSE_COMPARISONS = {
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.LT: exp.GT,
exp.GT: exp.LT,
exp.LTE: exp.GTE,
@ -347,6 +369,87 @@ def absorb_and_eliminate(expression, root=True):
return expression
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
exp.DatetimeAdd: exp.Sub,
exp.DatetimeSub: exp.Add,
}
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
**INVERSE_DATE_OPS,
exp.Add: exp.Sub,
exp.Sub: exp.Add,
}
def _is_number(expression: exp.Expression) -> bool:
return expression.is_number
def _is_date(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Cast) and extract_date(expression) is not None
def _is_interval(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
"""
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =
Here's how we reference all the operands in the code below:
l r
x + 1 = 3
a b
"""
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
if l.__class__ in INVERSE_OPS:
pass
elif r.__class__ in INVERSE_OPS:
l, r = r, l
else:
return expression
if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
elif _is_date(r):
a_predicate = _is_date
b_predicate = _is_interval
else:
return expression
if l.__class__ in INVERSE_DATE_OPS:
a = l.this
b = exp.Interval(
this=l.expression.copy(),
unit=l.unit.copy(),
)
else:
a, b = l.left, l.right
if not a_predicate(a) and b_predicate(b):
pass
elif not a_predicate(b) and b_predicate(a):
a, b = b, a
else:
return expression
return expression.__class__(
this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
)
return expression
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
@ -530,6 +633,123 @@ def simplify_concat(expression):
return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
DateRange = t.Tuple[datetime.date, datetime.date]
def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
"""
Get the date range for a DATE_TRUNC equality comparison:
Example:
_datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
Returns:
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
"""
floor = date_floor(date, unit)
if date != floor:
# This will always be False, except for NULL values.
return None
return floor, floor + interval(unit)
def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
"""Get the logical expression for a date range"""
return exp.and_(
left >= date_literal(drange[0]),
left < date_literal(drange[1]),
copy=False,
)
def _datetrunc_eq(
left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit)
if not drange:
return None
return _datetrunc_eq_expression(left, drange)
def _datetrunc_neq(
left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit)
if not drange:
return None
return exp.and_(
left < date_literal(drange[0]),
left >= date_literal(drange[1]),
copy=False,
)
DateTruncBinaryTransform = t.Callable[
[exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
return (
isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
and isinstance(right, exp.Cast)
and right.is_type(*exp.DataType.TEMPORAL_TYPES)
)
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
comparison = expression.__class__
if comparison not in DATETRUNC_COMPARISONS:
return expression
if isinstance(expression, exp.Binary):
l, r = expression.left, expression.right
if _is_datetrunc_predicate(l, r):
pass
elif _is_datetrunc_predicate(r, l):
comparison = INVERSE_COMPARISONS.get(comparison, comparison)
l, r = r, l
else:
return expression
unit = l.unit.name.lower()
date = extract_date(r)
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
if all(_is_datetrunc_predicate(l, r) for r in rs):
unit = l.unit.name.lower()
ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
if not ranges:
return expression
ranges = merge_ranges(ranges)
return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
return expression
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
@ -603,25 +823,15 @@ def extract_date(cast):
return None
def extract_interval(interval):
def extract_interval(expression):
n = int(expression.name)
unit = expression.text("unit").lower()
try:
from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return interval(unit, n)
except (UnsupportedUnit, ModuleNotFoundError):
return None
n = int(interval.name)
unit = interval.text("unit").lower()
if unit == "year":
return relativedelta(years=n)
if unit == "month":
return relativedelta(months=n)
if unit == "week":
return relativedelta(weeks=n)
if unit == "day":
return relativedelta(days=n)
return None
def date_literal(date):
return exp.cast(
@ -630,6 +840,61 @@ def date_literal(date):
)
def interval(unit: str, n: int = 1):
from dateutil.relativedelta import relativedelta
if unit == "year":
return relativedelta(years=1 * n)
if unit == "quarter":
return relativedelta(months=3 * n)
if unit == "month":
return relativedelta(months=1 * n)
if unit == "week":
return relativedelta(weeks=1 * n)
if unit == "day":
return relativedelta(days=1 * n)
if unit == "hour":
return relativedelta(hours=1 * n)
if unit == "minute":
return relativedelta(minutes=1 * n)
if unit == "second":
return relativedelta(seconds=1 * n)
raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
if unit == "year":
return d.replace(month=1, day=1)
if unit == "quarter":
if d.month <= 3:
return d.replace(month=1, day=1)
elif d.month <= 6:
return d.replace(month=4, day=1)
elif d.month <= 9:
return d.replace(month=7, day=1)
else:
return d.replace(month=10, day=1)
if unit == "month":
return d.replace(month=d.month, day=1)
if unit == "week":
# Assuming week starts on Monday (0) and ends on Sunday (6)
return d - datetime.timedelta(days=d.weekday())
if unit == "day":
return d
raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
floor = date_floor(d, unit)
if floor == d:
return d
return floor + interval(unit)
def boolean_literal(condition):
return exp.true() if condition else exp.false()

View file

@ -43,7 +43,11 @@ def unnest(select, parent_select, next_alias_name):
predicate = select.find_ancestor(exp.Condition)
alias = next_alias_name()
if not predicate or parent_select is not predicate.parent_select:
if (
not predicate
or parent_select is not predicate.parent_select
or not parent_select.args.get("from")
):
return
# This subquery returns a scalar and can just be converted to a cross join