Merging upstream version 18.11.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
15b8b39545
commit
c37998973e
88 changed files with 52059 additions and 46960 deletions
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
|
@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema
|
|||
if t.TYPE_CHECKING:
|
||||
B = t.TypeVar("B", bound=exp.Binary)
|
||||
|
||||
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
|
||||
BinaryCoercions = t.Dict[
|
||||
t.Tuple[exp.DataType.Type, exp.DataType.Type],
|
||||
BinaryCoercionFunc,
|
||||
]
|
||||
|
||||
|
||||
# Interval units that operate on date components
|
||||
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
|
||||
|
||||
|
||||
def annotate_types(
|
||||
expression: E,
|
||||
|
@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
|
|||
return lambda self, e: self._annotate_with_type(e, data_type)
|
||||
|
||||
|
||||
def _is_iso_date(text: str) -> bool:
|
||||
try:
|
||||
datetime.date.fromisoformat(text)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _is_iso_datetime(text: str) -> bool:
|
||||
try:
|
||||
datetime.datetime.fromisoformat(text)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
|
||||
date_text = l.name
|
||||
unit = r.text("unit").lower()
|
||||
|
||||
is_iso_date = _is_iso_date(date_text)
|
||||
|
||||
if is_iso_date and unit in DATE_UNITS:
|
||||
l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
|
||||
return exp.DataType.Type.DATE
|
||||
|
||||
# An ISO date is also an ISO datetime, but not vice versa
|
||||
if is_iso_date or _is_iso_datetime(date_text):
|
||||
l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
|
||||
return exp.DataType.Type.DATETIME
|
||||
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
|
||||
def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
|
||||
unit = r.text("unit").lower()
|
||||
if unit not in DATE_UNITS:
|
||||
return exp.DataType.Type.DATETIME
|
||||
return l.type.this if l.type else exp.DataType.Type.UNKNOWN
|
||||
|
||||
|
||||
def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
|
||||
@functools.wraps(func)
|
||||
def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
|
||||
return func(r, l)
|
||||
|
||||
return _swapped
|
||||
|
||||
|
||||
def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
|
||||
return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
|
||||
|
||||
|
||||
class _TypeAnnotator(type):
|
||||
def __new__(cls, clsname, bases, attrs):
|
||||
klass = super().__new__(cls, clsname, bases, attrs)
|
||||
|
@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DataType.Type.DATE: {
|
||||
exp.CurrentDate,
|
||||
exp.Date,
|
||||
exp.DateAdd,
|
||||
exp.DateFromParts,
|
||||
exp.DateStrToDate,
|
||||
exp.DateSub,
|
||||
exp.DateTrunc,
|
||||
exp.DiToDate,
|
||||
exp.StrToDate,
|
||||
|
@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
|
||||
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
|
||||
exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
|
||||
exp.DateSub: lambda self, e: self._annotate_dateadd(e),
|
||||
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
|
@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
# Specifies what types a given type can be coerced into (autofilled)
|
||||
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
|
||||
|
||||
# Coercion functions for binary operations.
|
||||
# Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
|
||||
BINARY_COERCIONS: BinaryCoercions = {
|
||||
**swap_all(
|
||||
{
|
||||
(t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
|
||||
for t in exp.DataType.TEXT_TYPES
|
||||
}
|
||||
),
|
||||
**swap_all(
|
||||
{
|
||||
(exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: Schema,
|
||||
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
|
||||
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
||||
binary_coercions: t.Optional[BinaryCoercions] = None,
|
||||
) -> None:
|
||||
self.schema = schema
|
||||
self.annotators = annotators or self.ANNOTATORS
|
||||
self.coerces_to = coerces_to or self.COERCES_TO
|
||||
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
|
||||
|
||||
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
|
||||
self._visited: t.Set[int] = set()
|
||||
|
||||
def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
|
||||
expression.type = target_type
|
||||
def _set_type(
|
||||
self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
|
||||
) -> None:
|
||||
expression.type = target_type # type: ignore
|
||||
self._visited.add(id(expression))
|
||||
|
||||
def annotate(self, expression: E) -> E:
|
||||
|
@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
def _annotate_binary(self, expression: B) -> B:
|
||||
self._annotate_args(expression)
|
||||
|
||||
left_type = expression.left.type.this
|
||||
right_type = expression.right.type.this
|
||||
left, right = expression.left, expression.right
|
||||
left_type, right_type = left.type.this, right.type.this
|
||||
|
||||
if isinstance(expression, exp.Connector):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
|
@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif isinstance(expression, exp.Predicate):
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif (left_type, right_type) in self.binary_coercions:
|
||||
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
|
||||
else:
|
||||
self._set_type(expression, self._maybe_coerce(left_type, right_type))
|
||||
|
||||
|
@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
)
|
||||
|
||||
return expression
|
||||
|
||||
def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
|
||||
self._annotate_args(expression)
|
||||
|
||||
if expression.this.type.this in exp.DataType.TEXT_TYPES:
|
||||
datatype = _coerce_literal_and_interval(expression.this, expression.interval())
|
||||
elif (
|
||||
expression.this.type.is_type(exp.DataType.Type.DATE)
|
||||
and expression.text("unit").lower() not in DATE_UNITS
|
||||
):
|
||||
datatype = exp.DataType.Type.DATETIME
|
||||
else:
|
||||
datatype = expression.this.type
|
||||
|
||||
self._set_type(expression, datatype)
|
||||
return expression
|
||||
|
|
|
@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
|
|||
_coerce_date(node.left, node.right)
|
||||
elif isinstance(node, exp.Between):
|
||||
_coerce_date(node.this, node.args["low"])
|
||||
elif isinstance(node, exp.Extract):
|
||||
if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
|
||||
_replace_cast(node.expression, "datetime")
|
||||
elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
|
||||
*exp.DataType.TEMPORAL_TYPES
|
||||
):
|
||||
_replace_cast(node.expression, exp.DataType.Type.DATETIME)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
|
@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
|
|||
_replace_int_predicate(expression.left)
|
||||
_replace_int_predicate(expression.right)
|
||||
|
||||
elif isinstance(expression, (exp.Where, exp.Having)):
|
||||
elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
|
||||
_replace_int_predicate(expression.this)
|
||||
|
||||
return expression
|
||||
|
@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
|
|||
and b.type
|
||||
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
|
||||
):
|
||||
_replace_cast(b, "date")
|
||||
_replace_cast(b, exp.DataType.Type.DATE)
|
||||
|
||||
|
||||
def _replace_cast(node: exp.Expression, to: str) -> None:
|
||||
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
|
||||
node.replace(exp.cast(node.copy(), to=to))
|
||||
|
||||
|
||||
def _replace_int_predicate(expression: exp.Expression) -> None:
|
||||
if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
|
||||
if isinstance(expression, exp.Coalesce):
|
||||
for _, child in expression.iter_expressions():
|
||||
_replace_int_predicate(child)
|
||||
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
|
||||
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
|
||||
|
|
|
@ -181,7 +181,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
||||
and inner_select.args.get("from")
|
||||
and not outer_scope.pivots
|
||||
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
|
||||
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
|
||||
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
|
||||
and not (
|
||||
isinstance(from_or_join, exp.Join)
|
||||
|
|
|
@ -22,6 +22,13 @@ def normalize_identifiers(expression, dialect=None):
|
|||
Normalize all unquoted identifiers to either lower or upper case, depending
|
||||
on the dialect. This essentially makes those identifiers case-insensitive.
|
||||
|
||||
It's possible to make this a no-op by adding a special comment next to the
|
||||
identifier of interest:
|
||||
|
||||
SELECT a /* sqlglot.meta case_sensitive */ FROM table
|
||||
|
||||
In this example, the identifier `a` will not be normalized.
|
||||
|
||||
Note:
|
||||
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
|
||||
when they're quoted, so in these cases all identifiers are normalized.
|
||||
|
@ -43,4 +50,13 @@ def normalize_identifiers(expression, dialect=None):
|
|||
"""
|
||||
if isinstance(expression, str):
|
||||
expression = exp.to_identifier(expression)
|
||||
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
|
||||
|
||||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
def _normalize(node: E) -> E:
|
||||
if not node.meta.get("case_sensitive"):
|
||||
exp.replace_children(node, _normalize)
|
||||
node = dialect.normalize_identifier(node)
|
||||
return node
|
||||
|
||||
return _normalize(expression)
|
||||
|
|
|
@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool:
|
|||
return expression.is_number
|
||||
|
||||
|
||||
def _is_date(expression: exp.Expression) -> bool:
|
||||
return isinstance(expression, exp.Cast) and extract_date(expression) is not None
|
||||
|
||||
|
||||
def _is_interval(expression: exp.Expression) -> bool:
|
||||
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
|
||||
|
||||
|
@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
|
|||
if r.is_number:
|
||||
a_predicate = _is_number
|
||||
b_predicate = _is_number
|
||||
elif _is_date(r):
|
||||
a_predicate = _is_date
|
||||
elif _is_date_literal(r):
|
||||
a_predicate = _is_date_literal
|
||||
b_predicate = _is_interval
|
||||
else:
|
||||
return expression
|
||||
|
||||
if l.__class__ in INVERSE_DATE_OPS:
|
||||
a = l.this
|
||||
b = exp.Interval(
|
||||
this=l.expression.copy(),
|
||||
unit=l.unit.copy(),
|
||||
)
|
||||
b = l.interval()
|
||||
else:
|
||||
a, b = l.left, l.right
|
||||
|
||||
|
@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b):
|
|||
|
||||
if boolean:
|
||||
return boolean
|
||||
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
|
||||
elif _is_date_literal(a) and isinstance(b, exp.Interval):
|
||||
a, b = extract_date(a), extract_interval(b)
|
||||
if a and b:
|
||||
if isinstance(expression, exp.Add):
|
||||
return date_literal(a + b)
|
||||
if isinstance(expression, exp.Sub):
|
||||
return date_literal(a - b)
|
||||
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
|
||||
elif isinstance(a, exp.Interval) and _is_date_literal(b):
|
||||
a, b = extract_interval(a), extract_date(b)
|
||||
# you cannot subtract a date from an interval
|
||||
if a and b and isinstance(expression, exp.Add):
|
||||
|
@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
|
|||
|
||||
|
||||
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
|
||||
return (
|
||||
isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
|
||||
and isinstance(right, exp.Cast)
|
||||
and right.is_type(*exp.DataType.TEMPORAL_TYPES)
|
||||
)
|
||||
return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
|
||||
|
||||
|
||||
@catch(ModuleNotFoundError, UnsupportedUnit)
|
||||
|
@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
|||
unit = l.unit.name.lower()
|
||||
date = extract_date(r)
|
||||
|
||||
if not date:
|
||||
return expression
|
||||
|
||||
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
|
||||
elif isinstance(expression, exp.In):
|
||||
l = expression.this
|
||||
rs = expression.expressions
|
||||
|
||||
if all(_is_datetrunc_predicate(l, r) for r in rs):
|
||||
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
|
||||
unit = l.unit.name.lower()
|
||||
|
||||
ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
|
||||
ranges = []
|
||||
for r in rs:
|
||||
date = extract_date(r)
|
||||
if not date:
|
||||
return expression
|
||||
drange = _datetrunc_range(date, unit)
|
||||
if drange:
|
||||
ranges.append(drange)
|
||||
|
||||
if not ranges:
|
||||
return expression
|
||||
|
||||
|
@ -811,18 +811,59 @@ def eval_boolean(expression, a, b):
|
|||
return None
|
||||
|
||||
|
||||
def extract_date(cast):
|
||||
# The "fromisoformat" conversion could fail if the cast is used on an identifier,
|
||||
# so in that case we can't extract the date.
|
||||
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.date()
|
||||
if isinstance(value, datetime.date):
|
||||
return value
|
||||
try:
|
||||
if cast.args["to"].this == exp.DataType.Type.DATE:
|
||||
return datetime.date.fromisoformat(cast.name)
|
||||
if cast.args["to"].this == exp.DataType.Type.DATETIME:
|
||||
return datetime.datetime.fromisoformat(cast.name)
|
||||
return datetime.datetime.fromisoformat(value).date()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value
|
||||
if isinstance(value, datetime.date):
|
||||
return datetime.datetime(year=value.year, month=value.month, day=value.day)
|
||||
try:
|
||||
return datetime.datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
|
||||
if not value:
|
||||
return None
|
||||
if to.is_type(exp.DataType.Type.DATE):
|
||||
return cast_as_date(value)
|
||||
if to.is_type(*exp.DataType.TEMPORAL_TYPES):
|
||||
return cast_as_datetime(value)
|
||||
return None
|
||||
|
||||
|
||||
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
|
||||
if isinstance(cast, exp.Cast):
|
||||
to = cast.to
|
||||
elif isinstance(cast, exp.TsOrDsToDate):
|
||||
to = exp.DataType.build(exp.DataType.Type.DATE)
|
||||
else:
|
||||
return None
|
||||
|
||||
if isinstance(cast.this, exp.Literal):
|
||||
value: t.Any = cast.this.name
|
||||
elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
|
||||
value = extract_date(cast.this)
|
||||
else:
|
||||
return None
|
||||
return cast_value(value, to)
|
||||
|
||||
|
||||
def _is_date_literal(expression: exp.Expression) -> bool:
|
||||
return extract_date(expression) is not None
|
||||
|
||||
|
||||
def extract_interval(expression):
|
||||
n = int(expression.name)
|
||||
unit = expression.text("unit").lower()
|
||||
|
@ -836,7 +877,9 @@ def extract_interval(expression):
|
|||
def date_literal(date):
|
||||
return exp.cast(
|
||||
exp.Literal.string(date),
|
||||
"DATETIME" if isinstance(date, datetime.datetime) else "DATE",
|
||||
exp.DataType.Type.DATETIME
|
||||
if isinstance(date, datetime.datetime)
|
||||
else exp.DataType.Type.DATE,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue