1
0
Fork 0

Merging upstream version 18.11.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:04:58 +01:00
parent 15b8b39545
commit c37998973e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
88 changed files with 52059 additions and 46960 deletions

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import datetime
import functools
import typing as t
from sqlglot import exp
@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
t.Tuple[exp.DataType.Type, exp.DataType.Type],
BinaryCoercionFunc,
]
# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
def annotate_types(
expression: E,
@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)
def _is_iso_date(text: str) -> bool:
try:
datetime.date.fromisoformat(text)
return True
except ValueError:
return False
def _is_iso_datetime(text: str) -> bool:
try:
datetime.datetime.fromisoformat(text)
return True
except ValueError:
return False
def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Infer the type of `<text> <op> <interval>` and cast the text operand accordingly.

    The text of `l` (its `name`) is probed as an ISO date / ISO datetime. When it
    matches, `l` is replaced in place by a CAST of a copy of itself to the inferred type.

    Args:
        l: The text operand.
        r: The interval operand; its "unit" text decides between DATE and DATETIME.

    Returns:
        DATE when the text is an ISO date and the unit is one of DATE_UNITS,
        DATETIME when the text is an ISO date or ISO datetime otherwise,
        UNKNOWN when the text is neither (in which case `l` is left untouched).
    """
    literal_text = l.name
    interval_unit = r.text("unit").lower()
    looks_like_date = _is_iso_date(literal_text)

    if looks_like_date and interval_unit in DATE_UNITS:
        target = exp.DataType.Type.DATE
    elif looks_like_date or _is_iso_datetime(literal_text):
        # An ISO date is also an ISO datetime, but not vice versa
        target = exp.DataType.Type.DATETIME
    else:
        return exp.DataType.Type.UNKNOWN

    l.replace(exp.cast(l.copy(), to=target))
    return target
def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Infer the result type of combining the operand `l` with the interval `r`.

    Args:
        l: The non-interval operand.
        r: The interval operand; only its "unit" text is inspected.

    Returns:
        DATETIME when the interval unit is not one of DATE_UNITS. Otherwise the
        type already annotated on `l` (`l.type.this`), or UNKNOWN if `l` has no type.
    """
    if r.text("unit").lower() in DATE_UNITS:
        return l.type.this if l.type else exp.DataType.Type.UNKNOWN

    # Any unit outside DATE_UNITS yields a DATETIME result
    return exp.DataType.Type.DATETIME
def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Wrap a binary coercion function so it receives its two operands in reverse order.

    Args:
        func: A callable taking `(left, right)`.

    Returns:
        A callable that, given `(l, r)`, returns `func(r, l)`. Metadata of `func`
        is preserved through `functools.wraps`.
    """

    @functools.wraps(func)
    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        return func(r, l)

    return _swapped


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """Extend a coercion table with the mirrored `(b, a)` entry for every `(a, b)` entry.

    Args:
        coercions: Mapping of `(left type, right type)` to a coercion function.

    Returns:
        A new dict holding every entry of `coercions` plus, for each `(a, b)` key,
        a `(b, a)` key mapped to the argument-swapped function. The input is not mutated.
    """
    symmetric = dict(coercions)
    for (a, b), func in coercions.items():
        # Entries supplied by the caller take precedence: a mirror is only generated
        # when the reversed key was not given explicitly, so passing both (a, b) and
        # (b, a) keeps both functions exactly as provided.
        symmetric.setdefault((b, a), swap_args(func))
    return symmetric
class _TypeAnnotator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DATE: {
exp.CurrentDate,
exp.Date,
exp.DateAdd,
exp.DateFromParts,
exp.DateStrToDate,
exp.DateSub,
exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
exp.DateSub: lambda self, e: self._annotate_dateadd(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
# Coercion functions for binary operations, keyed by the (left type, right type) pair of
# the operands. Each value is a callable that takes both sides of the binary operation
# and returns the resulting type. swap_all registers every pair in both operand orders.
BINARY_COERCIONS: BinaryCoercions = {
    # <text> combined with <interval>
    **swap_all(
        dict.fromkeys(
            (
                (text_type, exp.DataType.Type.INTERVAL)
                for text_type in exp.DataType.TEXT_TYPES
            ),
            _coerce_literal_and_interval,
        )
    ),
    # <date> combined with <interval>
    **swap_all(
        {
            (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
        }
    ),
}
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
    """Store the schema and the lookup tables this annotator works with.

    Args:
        schema: Stored as `self.schema`.
        annotators: Per-expression-class annotation callables; defaults to `self.ANNOTATORS`.
        coerces_to: Type coercion table; defaults to `self.COERCES_TO`.
        binary_coercions: Binary-operation coercion table; defaults to `self.BINARY_COERCIONS`.
    """
    self.schema = schema

    # A falsy override (None, but also an empty mapping) falls back to the class-level table
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
    self.binary_coercions = binary_coercions if binary_coercions else self.BINARY_COERCIONS

    # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
    self._visited: t.Set[int] = set()
def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
expression.type = target_type
def _set_type(
self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
) -> None:
expression.type = target_type # type: ignore
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left_type = expression.left.type.this
right_type = expression.right.type.this
left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif isinstance(expression, exp.Predicate):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif (left_type, right_type) in self.binary_coercions:
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
return expression
def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
    """Annotate a date add/sub style expression with its result type.

    The arguments are annotated first, then the result type is chosen from the type
    of `expression.this`:
      - text type: delegated to `_coerce_literal_and_interval`, which may also wrap
        the operand in a CAST;
      - DATE with a unit outside DATE_UNITS: DATETIME;
      - anything else: the operand's own type.

    Args:
        expression: The expression to annotate.

    Returns:
        The same expression, with its type set.
    """
    self._annotate_args(expression)

    operand = expression.this
    operand_type = operand.type

    if operand_type.this in exp.DataType.TEXT_TYPES:
        datatype = _coerce_literal_and_interval(operand, expression.interval())
    elif (
        operand_type.is_type(exp.DataType.Type.DATE)
        and expression.text("unit").lower() not in DATE_UNITS
    ):
        datatype = exp.DataType.Type.DATETIME
    else:
        datatype = operand_type

    self._set_type(expression, datatype)
    return expression

View file

@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
*exp.DataType.TEMPORAL_TYPES
):
_replace_cast(node.expression, exp.DataType.Type.DATETIME)
return node
@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
elif isinstance(expression, (exp.Where, exp.Having)):
elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
_replace_int_predicate(expression.this)
return expression
@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
and b.type
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")
_replace_cast(b, exp.DataType.Type.DATE)
def _replace_cast(node: exp.Expression, to: str) -> None:
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))
def _replace_int_predicate(expression: exp.Expression) -> None:
if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))

View file

@ -181,7 +181,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
isinstance(from_or_join, exp.Join)

View file

@ -22,6 +22,13 @@ def normalize_identifiers(expression, dialect=None):
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
It's possible to make this a no-op by adding a special comment next to the
identifier of interest:
SELECT a /* sqlglot.meta case_sensitive */ FROM table
In this example, the identifier `a` will not be normalized.
Note:
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
when they're quoted, so in these cases all identifiers are normalized.
@ -43,4 +50,13 @@ def normalize_identifiers(expression, dialect=None):
"""
if isinstance(expression, str):
expression = exp.to_identifier(expression)
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
dialect = Dialect.get_or_raise(dialect)
def _normalize(node: E) -> E:
    """Normalize `node` and, recursively, its children for the resolved dialect.

    A node whose `meta` has a truthy "case_sensitive" entry is returned as-is,
    and its children are not visited.
    """
    if node.meta.get("case_sensitive"):
        return node

    # Children are processed before the node itself is normalized
    exp.replace_children(node, _normalize)
    return dialect.normalize_identifier(node)
return _normalize(expression)

View file

@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool:
return expression.is_number
def _is_date(expression: exp.Expression) -> bool:
    """Report whether `expression` is a Cast from which `extract_date` can extract a date."""
    if not isinstance(expression, exp.Cast):
        return False
    return extract_date(expression) is not None
def _is_interval(expression: exp.Expression) -> bool:
    """Report whether `expression` is an Interval that `extract_interval` can evaluate."""
    if not isinstance(expression, exp.Interval):
        return False
    return extract_interval(expression) is not None
@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
elif _is_date(r):
a_predicate = _is_date
elif _is_date_literal(r):
a_predicate = _is_date_literal
b_predicate = _is_interval
else:
return expression
if l.__class__ in INVERSE_DATE_OPS:
a = l.this
b = exp.Interval(
this=l.expression.copy(),
unit=l.unit.copy(),
)
b = l.interval()
else:
a, b = l.left, l.right
@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
return (
isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
and isinstance(right, exp.Cast)
and right.is_type(*exp.DataType.TEMPORAL_TYPES)
)
return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
unit = l.unit.name.lower()
date = extract_date(r)
if not date:
return expression
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
if all(_is_datetrunc_predicate(l, r) for r in rs):
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
unit = l.unit.name.lower()
ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
ranges = []
for r in rs:
date = extract_date(r)
if not date:
return expression
drange = _datetrunc_range(date, unit)
if drange:
ranges.append(drange)
if not ranges:
return expression
@ -811,18 +811,59 @@ def eval_boolean(expression, a, b):
return None
def extract_date(cast):
# The "fromisoformat" conversion could fail if the cast is used on an identifier,
# so in that case we can't extract the date.
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, or return None when that is not possible.

    Args:
        value: A `datetime.datetime`, a `datetime.date`, or an ISO-formatted
            date/datetime string.

    Returns:
        The date part of `value`, or None if `value` is a string that is not in
        ISO format or is of an unsupported type.
    """
    # datetime.datetime is a subclass of datetime.date, so it must be tested first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: the string is not ISO formatted; TypeError: not a string at all
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, or return None when that is not possible.

    Args:
        value: A `datetime.datetime`, a `datetime.date`, or an ISO-formatted
            date/datetime string.

    Returns:
        `value` itself if it already is a datetime, midnight of the given day for a
        plain date, the parsed datetime for an ISO string, or None if `value` is a
        string that is not in ISO format or is of an unsupported type.
    """
    # datetime.datetime is a subclass of datetime.date, so it must be tested first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: the string is not ISO formatted; TypeError: not a string at all
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python temporal type that corresponds to `to`.

    Args:
        value: The value to convert; any falsy value (None, empty string) yields None.
        to: The target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is any
        other temporal type, or None when `value` is falsy, cannot be converted, or
        `to` is not a temporal type.
    """
    if not value:
        return None
    # DATE is tested before the general temporal check so that it maps to a plain date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a cast-like expression over a literal into a Python date or datetime.

    Args:
        cast: The expression to evaluate. Only `exp.Cast` and `exp.TsOrDsToDate`
            nodes are handled; the latter is treated as a cast to DATE.

    Returns:
        The value produced by `cast_value` for the innermost literal and the target
        type of `cast`, or None when `cast` is not a supported node or its operand is
        neither a literal nor another supported cast-like node.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated inside-out
        value = extract_date(operand)
    else:
        return None

    return cast_value(value, to)
def _is_date_literal(expression: exp.Expression) -> bool:
    """Report whether `extract_date` can turn `expression` into a date or datetime."""
    extracted = extract_date(expression)
    return extracted is not None
def extract_interval(expression):
n = int(expression.name)
unit = expression.text("unit").lower()
@ -836,7 +877,9 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
"DATETIME" if isinstance(date, datetime.datetime) else "DATE",
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE,
)