Merging upstream version 25.7.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
dba379232c
commit
aa0eae236a
102 changed files with 52995 additions and 52070 deletions
|
@ -10,10 +10,10 @@ from sqlglot.helper import (
|
|||
is_iso_date,
|
||||
is_iso_datetime,
|
||||
seq_get,
|
||||
subclasses,
|
||||
)
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import B, E
|
||||
|
@ -24,12 +24,15 @@ if t.TYPE_CHECKING:
|
|||
BinaryCoercionFunc,
|
||||
]
|
||||
|
||||
from sqlglot.dialects.dialect import DialectType, AnnotatorsType
|
||||
|
||||
|
||||
def annotate_types(
|
||||
expression: E,
|
||||
schema: t.Optional[t.Dict | Schema] = None,
|
||||
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
|
||||
annotators: t.Optional[AnnotatorsType] = None,
|
||||
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
||||
dialect: t.Optional[DialectType] = None,
|
||||
) -> E:
|
||||
"""
|
||||
Infers the types of an expression, annotating its AST accordingly.
|
||||
|
@ -54,11 +57,7 @@ def annotate_types(
|
|||
|
||||
schema = ensure_schema(schema)
|
||||
|
||||
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
|
||||
|
||||
|
||||
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
|
||||
return lambda self, e: self._annotate_with_type(e, data_type)
|
||||
return TypeAnnotator(schema, annotators, coerces_to, dialect=dialect).annotate(expression)
|
||||
|
||||
|
||||
def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
|
||||
|
@ -133,168 +132,6 @@ class _TypeAnnotator(type):
|
|||
|
||||
|
||||
class TypeAnnotator(metaclass=_TypeAnnotator):
|
||||
TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
|
||||
exp.DataType.Type.BIGINT: {
|
||||
exp.ApproxDistinct,
|
||||
exp.ArraySize,
|
||||
exp.Count,
|
||||
exp.Length,
|
||||
},
|
||||
exp.DataType.Type.BOOLEAN: {
|
||||
exp.Between,
|
||||
exp.Boolean,
|
||||
exp.In,
|
||||
exp.RegexpLike,
|
||||
},
|
||||
exp.DataType.Type.DATE: {
|
||||
exp.CurrentDate,
|
||||
exp.Date,
|
||||
exp.DateFromParts,
|
||||
exp.DateStrToDate,
|
||||
exp.DiToDate,
|
||||
exp.StrToDate,
|
||||
exp.TimeStrToDate,
|
||||
exp.TsOrDsToDate,
|
||||
},
|
||||
exp.DataType.Type.DATETIME: {
|
||||
exp.CurrentDatetime,
|
||||
exp.Datetime,
|
||||
exp.DatetimeAdd,
|
||||
exp.DatetimeSub,
|
||||
},
|
||||
exp.DataType.Type.DOUBLE: {
|
||||
exp.ApproxQuantile,
|
||||
exp.Avg,
|
||||
exp.Div,
|
||||
exp.Exp,
|
||||
exp.Ln,
|
||||
exp.Log,
|
||||
exp.Pow,
|
||||
exp.Quantile,
|
||||
exp.Round,
|
||||
exp.SafeDivide,
|
||||
exp.Sqrt,
|
||||
exp.Stddev,
|
||||
exp.StddevPop,
|
||||
exp.StddevSamp,
|
||||
exp.Variance,
|
||||
exp.VariancePop,
|
||||
},
|
||||
exp.DataType.Type.INT: {
|
||||
exp.Ceil,
|
||||
exp.DatetimeDiff,
|
||||
exp.DateDiff,
|
||||
exp.TimestampDiff,
|
||||
exp.TimeDiff,
|
||||
exp.DateToDi,
|
||||
exp.Floor,
|
||||
exp.Levenshtein,
|
||||
exp.Sign,
|
||||
exp.StrPosition,
|
||||
exp.TsOrDiToDi,
|
||||
},
|
||||
exp.DataType.Type.JSON: {
|
||||
exp.ParseJSON,
|
||||
},
|
||||
exp.DataType.Type.TIME: {
|
||||
exp.Time,
|
||||
},
|
||||
exp.DataType.Type.TIMESTAMP: {
|
||||
exp.CurrentTime,
|
||||
exp.CurrentTimestamp,
|
||||
exp.StrToTime,
|
||||
exp.TimeAdd,
|
||||
exp.TimeStrToTime,
|
||||
exp.TimeSub,
|
||||
exp.TimestampAdd,
|
||||
exp.TimestampSub,
|
||||
exp.UnixToTime,
|
||||
},
|
||||
exp.DataType.Type.TINYINT: {
|
||||
exp.Day,
|
||||
exp.Month,
|
||||
exp.Week,
|
||||
exp.Year,
|
||||
exp.Quarter,
|
||||
},
|
||||
exp.DataType.Type.VARCHAR: {
|
||||
exp.ArrayConcat,
|
||||
exp.Concat,
|
||||
exp.ConcatWs,
|
||||
exp.DateToDateStr,
|
||||
exp.GroupConcat,
|
||||
exp.Initcap,
|
||||
exp.Lower,
|
||||
exp.Substring,
|
||||
exp.TimeToStr,
|
||||
exp.TimeToTimeStr,
|
||||
exp.Trim,
|
||||
exp.TsOrDsToDateStr,
|
||||
exp.UnixToStr,
|
||||
exp.UnixToTimeStr,
|
||||
exp.Upper,
|
||||
},
|
||||
}
|
||||
|
||||
ANNOTATORS: t.Dict = {
|
||||
**{
|
||||
expr_type: lambda self, e: self._annotate_unary(e)
|
||||
for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
|
||||
},
|
||||
**{
|
||||
expr_type: lambda self, e: self._annotate_binary(e)
|
||||
for expr_type in subclasses(exp.__name__, exp.Binary)
|
||||
},
|
||||
**{
|
||||
expr_type: _annotate_with_type_lambda(data_type)
|
||||
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
|
||||
for expr_type in expressions
|
||||
},
|
||||
exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
|
||||
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
|
||||
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Bracket: lambda self, e: self._annotate_bracket(e),
|
||||
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
|
||||
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
|
||||
exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.DateSub: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Div: lambda self, e: self._annotate_div(e),
|
||||
exp.Dot: lambda self, e: self._annotate_dot(e),
|
||||
exp.Explode: lambda self, e: self._annotate_explode(e),
|
||||
exp.Extract: lambda self, e: self._annotate_extract(e),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
|
||||
e, exp.DataType.build("ARRAY<DATE>")
|
||||
),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
|
||||
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Literal: lambda self, e: self._annotate_literal(e),
|
||||
exp.Map: lambda self, e: self._annotate_map(e),
|
||||
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
|
||||
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
|
||||
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Struct: lambda self, e: self._annotate_struct(e),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.Timestamp: lambda self, e: self._annotate_with_type(
|
||||
e,
|
||||
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
|
||||
),
|
||||
exp.ToMap: lambda self, e: self._annotate_to_map(e),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Unnest: lambda self, e: self._annotate_unnest(e),
|
||||
exp.VarMap: lambda self, e: self._annotate_map(e),
|
||||
}
|
||||
|
||||
NESTED_TYPES = {
|
||||
exp.DataType.Type.ARRAY,
|
||||
}
|
||||
|
@ -335,12 +172,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
def __init__(
|
||||
self,
|
||||
schema: Schema,
|
||||
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
|
||||
annotators: t.Optional[AnnotatorsType] = None,
|
||||
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
|
||||
binary_coercions: t.Optional[BinaryCoercions] = None,
|
||||
dialect: t.Optional[DialectType] = None,
|
||||
) -> None:
|
||||
self.schema = schema
|
||||
self.annotators = annotators or self.ANNOTATORS
|
||||
self.annotators = annotators or Dialect.get_or_raise(dialect).ANNOTATORS
|
||||
self.coerces_to = coerces_to or self.COERCES_TO
|
||||
self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
|
||||
|
||||
|
@ -483,7 +321,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
|
||||
def _annotate_with_type(
|
||||
self, expression: E, target_type: exp.DataType | exp.DataType.Type
|
||||
) -> E:
|
||||
self._set_type(expression, target_type)
|
||||
return self._annotate_args(expression)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue