Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot._typing import E
  7from sqlglot.helper import ensure_list, subclasses
  8from sqlglot.optimizer.scope import Scope, traverse_scope
  9from sqlglot.schema import Schema, ensure_schema
 10
 11if t.TYPE_CHECKING:
 12    B = t.TypeVar("B", bound=exp.Binary)
 13
 14
 15def annotate_types(
 16    expression: E,
 17    schema: t.Optional[t.Dict | Schema] = None,
 18    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 19    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 20) -> E:
 21    """
 22    Infers the types of an expression, annotating its AST accordingly.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> schema = {"y": {"cola": "SMALLINT"}}
 27        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 28        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 29        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 30        <Type.DOUBLE: 'DOUBLE'>
 31
 32    Args:
 33        expression: Expression to annotate.
 34        schema: Database schema.
 35        annotators: Maps expression type to corresponding annotation function.
 36        coerces_to: Maps expression type to set of types that it can be coerced into.
 37
 38    Returns:
 39        The expression annotated with types.
 40    """
 41
 42    schema = ensure_schema(schema)
 43
 44    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 45
 46
 47def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 48    return lambda self, e: self._annotate_with_type(e, data_type)
 49
 50
 51class _TypeAnnotator(type):
 52    def __new__(cls, clsname, bases, attrs):
 53        klass = super().__new__(cls, clsname, bases, attrs)
 54
 55        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
 56        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
 57        text_precedence = (
 58            exp.DataType.Type.TEXT,
 59            exp.DataType.Type.NVARCHAR,
 60            exp.DataType.Type.VARCHAR,
 61            exp.DataType.Type.NCHAR,
 62            exp.DataType.Type.CHAR,
 63        )
 64        numeric_precedence = (
 65            exp.DataType.Type.DOUBLE,
 66            exp.DataType.Type.FLOAT,
 67            exp.DataType.Type.DECIMAL,
 68            exp.DataType.Type.BIGINT,
 69            exp.DataType.Type.INT,
 70            exp.DataType.Type.SMALLINT,
 71            exp.DataType.Type.TINYINT,
 72        )
 73        timelike_precedence = (
 74            exp.DataType.Type.TIMESTAMPLTZ,
 75            exp.DataType.Type.TIMESTAMPTZ,
 76            exp.DataType.Type.TIMESTAMP,
 77            exp.DataType.Type.DATETIME,
 78            exp.DataType.Type.DATE,
 79        )
 80
 81        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
 82            coerces_to = set()
 83            for data_type in type_precedence:
 84                klass.COERCES_TO[data_type] = coerces_to.copy()
 85                coerces_to |= {data_type}
 86
 87        return klass
 88
 89
 90class TypeAnnotator(metaclass=_TypeAnnotator):
 91    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
 92        exp.DataType.Type.BIGINT: {
 93            exp.ApproxDistinct,
 94            exp.ArraySize,
 95            exp.Count,
 96            exp.Length,
 97        },
 98        exp.DataType.Type.BOOLEAN: {
 99            exp.Between,
100            exp.Boolean,
101            exp.In,
102            exp.RegexpLike,
103        },
104        exp.DataType.Type.DATE: {
105            exp.CurrentDate,
106            exp.Date,
107            exp.DateAdd,
108            exp.DateFromParts,
109            exp.DateStrToDate,
110            exp.DateSub,
111            exp.DateTrunc,
112            exp.DiToDate,
113            exp.StrToDate,
114            exp.TimeStrToDate,
115            exp.TsOrDsToDate,
116        },
117        exp.DataType.Type.DATETIME: {
118            exp.CurrentDatetime,
119            exp.DatetimeAdd,
120            exp.DatetimeSub,
121        },
122        exp.DataType.Type.DOUBLE: {
123            exp.ApproxQuantile,
124            exp.Avg,
125            exp.Exp,
126            exp.Ln,
127            exp.Log,
128            exp.Log2,
129            exp.Log10,
130            exp.Pow,
131            exp.Quantile,
132            exp.Round,
133            exp.SafeDivide,
134            exp.Sqrt,
135            exp.Stddev,
136            exp.StddevPop,
137            exp.StddevSamp,
138            exp.Variance,
139            exp.VariancePop,
140        },
141        exp.DataType.Type.INT: {
142            exp.Ceil,
143            exp.DateDiff,
144            exp.DatetimeDiff,
145            exp.Extract,
146            exp.TimestampDiff,
147            exp.TimeDiff,
148            exp.DateToDi,
149            exp.Floor,
150            exp.Levenshtein,
151            exp.StrPosition,
152            exp.TsOrDiToDi,
153        },
154        exp.DataType.Type.TIMESTAMP: {
155            exp.CurrentTime,
156            exp.CurrentTimestamp,
157            exp.StrToTime,
158            exp.TimeAdd,
159            exp.TimeStrToTime,
160            exp.TimeSub,
161            exp.TimestampAdd,
162            exp.TimestampSub,
163            exp.UnixToTime,
164        },
165        exp.DataType.Type.TINYINT: {
166            exp.Day,
167            exp.Month,
168            exp.Week,
169            exp.Year,
170        },
171        exp.DataType.Type.VARCHAR: {
172            exp.ArrayConcat,
173            exp.Concat,
174            exp.ConcatWs,
175            exp.DateToDateStr,
176            exp.GroupConcat,
177            exp.Initcap,
178            exp.Lower,
179            exp.SafeConcat,
180            exp.Substring,
181            exp.TimeToStr,
182            exp.TimeToTimeStr,
183            exp.Trim,
184            exp.TsOrDsToDateStr,
185            exp.UnixToStr,
186            exp.UnixToTimeStr,
187            exp.Upper,
188        },
189    }
190
191    ANNOTATORS: t.Dict = {
192        **{
193            expr_type: lambda self, e: self._annotate_unary(e)
194            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
195        },
196        **{
197            expr_type: lambda self, e: self._annotate_binary(e)
198            for expr_type in subclasses(exp.__name__, exp.Binary)
199        },
200        **{
201            expr_type: _annotate_with_type_lambda(data_type)
202            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
203            for expr_type in expressions
204        },
205        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
206        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
207        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
208        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
209        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
210        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
211        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
212        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
213        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
214        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
215        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
216        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
217        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
218        exp.Literal: lambda self, e: self._annotate_literal(e),
219        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
220        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
221        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
222        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
223        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
224        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
225        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
226    }
227
228    NESTED_TYPES = {
229        exp.DataType.Type.ARRAY,
230    }
231
232    # Specifies what types a given type can be coerced into (autofilled)
233    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
234
235    def __init__(
236        self,
237        schema: Schema,
238        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
239        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
240    ) -> None:
241        self.schema = schema
242        self.annotators = annotators or self.ANNOTATORS
243        self.coerces_to = coerces_to or self.COERCES_TO
244
245        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
246        self._visited: t.Set[int] = set()
247
248    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
249        expression.type = target_type
250        self._visited.add(id(expression))
251
252    def annotate(self, expression: E) -> E:
253        for scope in traverse_scope(expression):
254            selects = {}
255            for name, source in scope.sources.items():
256                if not isinstance(source, Scope):
257                    continue
258                if isinstance(source.expression, exp.UDTF):
259                    values = []
260
261                    if isinstance(source.expression, exp.Lateral):
262                        if isinstance(source.expression.this, exp.Explode):
263                            values = [source.expression.this.this]
264                    else:
265                        values = source.expression.expressions[0].expressions
266
267                    if not values:
268                        continue
269
270                    selects[name] = {
271                        alias: column
272                        for alias, column in zip(
273                            source.expression.alias_column_names,
274                            values,
275                        )
276                    }
277                else:
278                    selects[name] = {
279                        select.alias_or_name: select for select in source.expression.selects
280                    }
281
282            # First annotate the current scope's column references
283            for col in scope.columns:
284                if not col.table:
285                    continue
286
287                source = scope.sources.get(col.table)
288                if isinstance(source, exp.Table):
289                    self._set_type(col, self.schema.get_column_type(source, col))
290                elif source and col.table in selects and col.name in selects[col.table]:
291                    self._set_type(col, selects[col.table][col.name].type)
292
293            # Then (possibly) annotate the remaining expressions in the scope
294            self._maybe_annotate(scope.expression)
295
296        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
297
298    def _maybe_annotate(self, expression: E) -> E:
299        if id(expression) in self._visited:
300            return expression  # We've already inferred the expression's type
301
302        annotator = self.annotators.get(expression.__class__)
303
304        return (
305            annotator(self, expression)
306            if annotator
307            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
308        )
309
310    def _annotate_args(self, expression: E) -> E:
311        for _, value in expression.iter_expressions():
312            self._maybe_annotate(value)
313
314        return expression
315
316    def _maybe_coerce(
317        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
318    ) -> exp.DataType | exp.DataType.Type:
319        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
320        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
321
322        # We propagate the NULL / UNKNOWN types upwards if found
323        if exp.DataType.Type.NULL in (type1_value, type2_value):
324            return exp.DataType.Type.NULL
325        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
326            return exp.DataType.Type.UNKNOWN
327
328        if type1_value in self.NESTED_TYPES:
329            return type1
330        if type2_value in self.NESTED_TYPES:
331            return type2
332
333        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
334
335    # Note: the following "no_type_check" decorators were added because mypy was yelling due
336    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
337    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
338
339    @t.no_type_check
340    def _annotate_binary(self, expression: B) -> B:
341        self._annotate_args(expression)
342
343        left_type = expression.left.type.this
344        right_type = expression.right.type.this
345
346        if isinstance(expression, exp.Connector):
347            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
348                self._set_type(expression, exp.DataType.Type.NULL)
349            elif exp.DataType.Type.NULL in (left_type, right_type):
350                self._set_type(
351                    expression,
352                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
353                )
354            else:
355                self._set_type(expression, exp.DataType.Type.BOOLEAN)
356        elif isinstance(expression, exp.Predicate):
357            self._set_type(expression, exp.DataType.Type.BOOLEAN)
358        else:
359            self._set_type(expression, self._maybe_coerce(left_type, right_type))
360
361        return expression
362
363    @t.no_type_check
364    def _annotate_unary(self, expression: E) -> E:
365        self._annotate_args(expression)
366
367        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
368            self._set_type(expression, exp.DataType.Type.BOOLEAN)
369        else:
370            self._set_type(expression, expression.this.type)
371
372        return expression
373
374    @t.no_type_check
375    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
376        if expression.is_string:
377            self._set_type(expression, exp.DataType.Type.VARCHAR)
378        elif expression.is_int:
379            self._set_type(expression, exp.DataType.Type.INT)
380        else:
381            self._set_type(expression, exp.DataType.Type.DOUBLE)
382
383        return expression
384
385    @t.no_type_check
386    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
387        self._set_type(expression, target_type)
388        return self._annotate_args(expression)
389
390    @t.no_type_check
391    def _annotate_by_args(
392        self, expression: E, *args: str, promote: bool = False, array: bool = False
393    ) -> E:
394        self._annotate_args(expression)
395
396        expressions: t.List[exp.Expression] = []
397        for arg in args:
398            arg_expr = expression.args.get(arg)
399            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
400
401        last_datatype = None
402        for expr in expressions:
403            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
404
405        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
406
407        if promote:
408            if expression.type.this in exp.DataType.INTEGER_TYPES:
409                self._set_type(expression, exp.DataType.Type.BIGINT)
410            elif expression.type.this in exp.DataType.FLOAT_TYPES:
411                self._set_type(expression, exp.DataType.Type.DOUBLE)
412
413        if array:
414            self._set_type(
415                expression,
416                exp.DataType(
417                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
418                ),
419            )
420
421        return expression
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate; the `type` of its nodes is set in place.
        schema: Database schema (or a dict); it is passed through `ensure_schema` first.
        annotators: Maps expression classes to the functions that annotate them.
            Falls back to `TypeAnnotator.ANNOTATORS` when None or empty.
        coerces_to: Maps each data type to the set of data types it can be coerced into.
            Falls back to `TypeAnnotator.COERCES_TO` when None or empty.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps each data type to the set of data types it can be coerced into.
Returns:

The expression annotated with types.

class TypeAnnotator:
 91class TypeAnnotator(metaclass=_TypeAnnotator):
 92    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
 93        exp.DataType.Type.BIGINT: {
 94            exp.ApproxDistinct,
 95            exp.ArraySize,
 96            exp.Count,
 97            exp.Length,
 98        },
 99        exp.DataType.Type.BOOLEAN: {
100            exp.Between,
101            exp.Boolean,
102            exp.In,
103            exp.RegexpLike,
104        },
105        exp.DataType.Type.DATE: {
106            exp.CurrentDate,
107            exp.Date,
108            exp.DateAdd,
109            exp.DateFromParts,
110            exp.DateStrToDate,
111            exp.DateSub,
112            exp.DateTrunc,
113            exp.DiToDate,
114            exp.StrToDate,
115            exp.TimeStrToDate,
116            exp.TsOrDsToDate,
117        },
118        exp.DataType.Type.DATETIME: {
119            exp.CurrentDatetime,
120            exp.DatetimeAdd,
121            exp.DatetimeSub,
122        },
123        exp.DataType.Type.DOUBLE: {
124            exp.ApproxQuantile,
125            exp.Avg,
126            exp.Exp,
127            exp.Ln,
128            exp.Log,
129            exp.Log2,
130            exp.Log10,
131            exp.Pow,
132            exp.Quantile,
133            exp.Round,
134            exp.SafeDivide,
135            exp.Sqrt,
136            exp.Stddev,
137            exp.StddevPop,
138            exp.StddevSamp,
139            exp.Variance,
140            exp.VariancePop,
141        },
142        exp.DataType.Type.INT: {
143            exp.Ceil,
144            exp.DateDiff,
145            exp.DatetimeDiff,
146            exp.Extract,
147            exp.TimestampDiff,
148            exp.TimeDiff,
149            exp.DateToDi,
150            exp.Floor,
151            exp.Levenshtein,
152            exp.StrPosition,
153            exp.TsOrDiToDi,
154        },
155        exp.DataType.Type.TIMESTAMP: {
156            exp.CurrentTime,
157            exp.CurrentTimestamp,
158            exp.StrToTime,
159            exp.TimeAdd,
160            exp.TimeStrToTime,
161            exp.TimeSub,
162            exp.TimestampAdd,
163            exp.TimestampSub,
164            exp.UnixToTime,
165        },
166        exp.DataType.Type.TINYINT: {
167            exp.Day,
168            exp.Month,
169            exp.Week,
170            exp.Year,
171        },
172        exp.DataType.Type.VARCHAR: {
173            exp.ArrayConcat,
174            exp.Concat,
175            exp.ConcatWs,
176            exp.DateToDateStr,
177            exp.GroupConcat,
178            exp.Initcap,
179            exp.Lower,
180            exp.SafeConcat,
181            exp.Substring,
182            exp.TimeToStr,
183            exp.TimeToTimeStr,
184            exp.Trim,
185            exp.TsOrDsToDateStr,
186            exp.UnixToStr,
187            exp.UnixToTimeStr,
188            exp.Upper,
189        },
190    }
191
192    ANNOTATORS: t.Dict = {
193        **{
194            expr_type: lambda self, e: self._annotate_unary(e)
195            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
196        },
197        **{
198            expr_type: lambda self, e: self._annotate_binary(e)
199            for expr_type in subclasses(exp.__name__, exp.Binary)
200        },
201        **{
202            expr_type: _annotate_with_type_lambda(data_type)
203            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
204            for expr_type in expressions
205        },
206        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
207        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
208        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
209        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
210        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
211        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
212        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
213        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
214        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
215        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
216        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
217        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
218        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
219        exp.Literal: lambda self, e: self._annotate_literal(e),
220        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
221        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
222        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
223        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
224        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
225        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
226        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
227    }
228
229    NESTED_TYPES = {
230        exp.DataType.Type.ARRAY,
231    }
232
233    # Specifies what types a given type can be coerced into (autofilled)
234    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
235
236    def __init__(
237        self,
238        schema: Schema,
239        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
240        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
241    ) -> None:
242        self.schema = schema
243        self.annotators = annotators or self.ANNOTATORS
244        self.coerces_to = coerces_to or self.COERCES_TO
245
246        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
247        self._visited: t.Set[int] = set()
248
249    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
250        expression.type = target_type
251        self._visited.add(id(expression))
252
253    def annotate(self, expression: E) -> E:
254        for scope in traverse_scope(expression):
255            selects = {}
256            for name, source in scope.sources.items():
257                if not isinstance(source, Scope):
258                    continue
259                if isinstance(source.expression, exp.UDTF):
260                    values = []
261
262                    if isinstance(source.expression, exp.Lateral):
263                        if isinstance(source.expression.this, exp.Explode):
264                            values = [source.expression.this.this]
265                    else:
266                        values = source.expression.expressions[0].expressions
267
268                    if not values:
269                        continue
270
271                    selects[name] = {
272                        alias: column
273                        for alias, column in zip(
274                            source.expression.alias_column_names,
275                            values,
276                        )
277                    }
278                else:
279                    selects[name] = {
280                        select.alias_or_name: select for select in source.expression.selects
281                    }
282
283            # First annotate the current scope's column references
284            for col in scope.columns:
285                if not col.table:
286                    continue
287
288                source = scope.sources.get(col.table)
289                if isinstance(source, exp.Table):
290                    self._set_type(col, self.schema.get_column_type(source, col))
291                elif source and col.table in selects and col.name in selects[col.table]:
292                    self._set_type(col, selects[col.table][col.name].type)
293
294            # Then (possibly) annotate the remaining expressions in the scope
295            self._maybe_annotate(scope.expression)
296
297        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
298
299    def _maybe_annotate(self, expression: E) -> E:
300        if id(expression) in self._visited:
301            return expression  # We've already inferred the expression's type
302
303        annotator = self.annotators.get(expression.__class__)
304
305        return (
306            annotator(self, expression)
307            if annotator
308            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
309        )
310
311    def _annotate_args(self, expression: E) -> E:
312        for _, value in expression.iter_expressions():
313            self._maybe_annotate(value)
314
315        return expression
316
317    def _maybe_coerce(
318        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
319    ) -> exp.DataType | exp.DataType.Type:
320        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
321        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
322
323        # We propagate the NULL / UNKNOWN types upwards if found
324        if exp.DataType.Type.NULL in (type1_value, type2_value):
325            return exp.DataType.Type.NULL
326        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
327            return exp.DataType.Type.UNKNOWN
328
329        if type1_value in self.NESTED_TYPES:
330            return type1
331        if type2_value in self.NESTED_TYPES:
332            return type2
333
334        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
335
336    # Note: the following "no_type_check" decorators were added because mypy was yelling due
337    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
338    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
339
340    @t.no_type_check
341    def _annotate_binary(self, expression: B) -> B:
342        self._annotate_args(expression)
343
344        left_type = expression.left.type.this
345        right_type = expression.right.type.this
346
347        if isinstance(expression, exp.Connector):
348            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
349                self._set_type(expression, exp.DataType.Type.NULL)
350            elif exp.DataType.Type.NULL in (left_type, right_type):
351                self._set_type(
352                    expression,
353                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
354                )
355            else:
356                self._set_type(expression, exp.DataType.Type.BOOLEAN)
357        elif isinstance(expression, exp.Predicate):
358            self._set_type(expression, exp.DataType.Type.BOOLEAN)
359        else:
360            self._set_type(expression, self._maybe_coerce(left_type, right_type))
361
362        return expression
363
364    @t.no_type_check
365    def _annotate_unary(self, expression: E) -> E:
366        self._annotate_args(expression)
367
368        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
369            self._set_type(expression, exp.DataType.Type.BOOLEAN)
370        else:
371            self._set_type(expression, expression.this.type)
372
373        return expression
374
375    @t.no_type_check
376    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
377        if expression.is_string:
378            self._set_type(expression, exp.DataType.Type.VARCHAR)
379        elif expression.is_int:
380            self._set_type(expression, exp.DataType.Type.INT)
381        else:
382            self._set_type(expression, exp.DataType.Type.DOUBLE)
383
384        return expression
385
386    @t.no_type_check
387    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
388        self._set_type(expression, target_type)
389        return self._annotate_args(expression)
390
@t.no_type_check
def _annotate_by_args(
    self, expression: E, *args: str, promote: bool = False, array: bool = False
) -> E:
    """Derive an expression's type from the types of some of its arguments.

    The arguments named in ``args`` are looked up on ``expression.args``. List-valued
    arguments are flattened via ``ensure_list``, and falsy entries are dropped. The
    operand types are folded left-to-right through ``self._maybe_coerce``. UNKNOWN is
    used when no operand is found.

    Args:
        expression: The node to annotate.
        *args: Names of the arguments whose types drive the result.
        promote: If set, integer result types become BIGINT and float result
            types become DOUBLE.
        array: If set, the resulting type is wrapped in a nested ARRAY type.

    Returns:
        The same expression, with its type set.
    """
    self._annotate_args(expression)

    operands: t.List[exp.Expression] = [
        operand
        for arg_name in args
        for operand in ensure_list(expression.args.get(arg_name))
        if operand
    ]

    # Fold the operand types; the first operand seeds the coercion with itself.
    result_type = None
    for operand in operands:
        result_type = self._maybe_coerce(result_type or operand.type, operand.type)

    self._set_type(expression, result_type or exp.DataType.Type.UNKNOWN)

    if promote:
        base = expression.type.this
        if base in exp.DataType.INTEGER_TYPES:
            self._set_type(expression, exp.DataType.Type.BIGINT)
        elif base in exp.DataType.FLOAT_TYPES:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

    if array:
        element_type = expression.type
        self._set_type(
            expression,
            exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[element_type], nested=True),
        )

    return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """Create a type annotator.

    Args:
        schema: Schema consulted for the types of table columns.
        annotators: Per-expression-class annotation functions. A falsy value
            (``None`` or an empty dict) falls back to the class-level ``ANNOTATORS``.
        coerces_to: For each type, the set of types it may be coerced into. A falsy
            value (``None`` or an empty dict) falls back to the class-level ``COERCES_TO``.
    """
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
    self.schema = schema

    # ids of sub-expressions already annotated, so each one is visited only once
    self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.ArraySize'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Boolean'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateFromParts'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Quantile'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Floor'>, <class 
'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Levenshtein'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimestampAdd'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Day'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.Upper'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, 
<class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.CHAR: 'CHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
schema
annotators
coerces_to
def annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """Annotate ``expression`` (and its sub-expressions) with inferred types.

    For every scope yielded by ``traverse_scope``, the method works in three steps:

    1. It builds a map from each child-scope source name to that source's output
       columns (output name -> projected expression).
    2. It types the scope's table-qualified column references. Physical tables
       are resolved through ``self.schema``; derived sources use the type of the
       matching projected expression.
    3. It annotates the scope's expression.

    Finally the whole expression is annotated, which covers parts that scope
    traversal does not reach. This relies on derived-source projections already
    being typed when an outer scope reads them. That presumably holds because
    ``traverse_scope`` yields inner scopes first (TODO confirm).

    Args:
        expression: The root expression to annotate.

    Returns:
        The annotated expression.
    """
    for scope in traverse_scope(expression):
        selects = {}

        for source_name, child in scope.sources.items():
            if not isinstance(child, Scope):
                continue

            child_expr = child.expression
            if not isinstance(child_expr, exp.UDTF):
                selects[source_name] = {
                    projection.alias_or_name: projection for projection in child_expr.selects
                }
                continue

            # Table-valued functions: pair their alias column names with value expressions.
            if isinstance(child_expr, exp.Lateral):
                exploded = child_expr.this
                values = [exploded.this] if isinstance(exploded, exp.Explode) else []
            else:
                # NOTE(review): uses the first entry's expressions (e.g. the first row of a
                # VALUES-like construct) as the representative values - confirm.
                values = child_expr.expressions[0].expressions

            if values:
                selects[source_name] = dict(zip(child_expr.alias_column_names, values))

        # First annotate the current scope's column references
        for col in scope.columns:
            table_name = col.table
            if not table_name:
                continue

            source = scope.sources.get(table_name)
            if isinstance(source, exp.Table):
                self._set_type(col, self.schema.get_column_type(source, col))
            elif source and table_name in selects and col.name in selects[table_name]:
                self._set_type(col, selects[table_name][col.name].type)

        # Then (possibly) annotate the remaining expressions in the scope
        self._maybe_annotate(scope.expression)

    # This takes care of non-traversable expressions
    return self._maybe_annotate(expression)