sqlglot.optimizer.annotate_types
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)


def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema (a `Schema`, a plain dict, or None; normalized via `ensure_schema`).
        annotators: Maps each expression class to its annotation function.
        coerces_to: Maps each data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    # A factory binds `data_type` eagerly; an inline lambda inside the dict comprehension
    # that builds ANNOTATORS would late-bind the loop variable instead.
    return lambda self, e: self._annotate_with_type(e, data_type)


class _TypeAnnotator(type):
    """Metaclass that fills in `COERCES_TO` on the class it creates."""

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        # Within each family, a type can be coerced into every type listed before it
        # (i.e. every higher-precedence type of the same family).
        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass


class TypeAnnotator(metaclass=_TypeAnnotator):
    """Infers data types for the nodes of an expression tree and stores them on `.type`."""

    # Expression classes whose result type is fixed, regardless of their arguments.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            # NOTE: this ArrayConcat entry has no effect on ANNOTATORS, because the
            # explicit exp.ArrayConcat entry below is inserted later and overrides it.
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps expression classes to annotation functions. Insertion order matters: later
    # entries override earlier ones. The TYPE_TO_EXPRESSIONS spread comes after the
    # Unary/Binary spreads, so e.g. exp.Pow (a Binary) is typed DOUBLE. The explicit
    # entries come last and win over everything above them.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types returned unchanged by _maybe_coerce, so their inner type parameters are kept
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled by the metaclass)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to resolve the types of table columns.
            annotators: Replaces ANNOTATORS entirely when non-empty.
            coerces_to: Replaces COERCES_TO entirely when non-empty.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assign `target_type` to `expression.type` and mark the node as visited."""
        expression.type = target_type
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` and all of its sub-expressions in place.

        Each scope's qualified column references are typed first, using the schema for
        table sources or the already-annotated select expressions of derived sources.
        The rest of the scope is annotated afterwards.

        Returns:
            The same expression object, annotated.
        """
        for scope in traverse_scope(expression):
            # Maps derived-source name -> {output column name -> expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate `expression` unless already visited; unknown classes become UNKNOWN."""
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every direct child expression of `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Return the type that results from combining `type1` with `type2`.

        NULL and then UNKNOWN take priority. A nested type (see NESTED_TYPES) is returned
        as-is. Otherwise `type2` is returned if `type1` can be coerced into it, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        # COERCES_TO values are sets, so default to an empty set (not a dict)
        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Type a binary node from its operands.

        Connectors (e.g. AND/OR) become BOOLEAN, or NULL / NULLABLE(BOOLEAN) when operands
        are NULL. Predicates become BOOLEAN. Everything else gets the coerced operand type.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Type a unary node (or Alias): NOT is BOOLEAN, others inherit the operand type."""
        self._annotate_args(expression)

        # Only logical negation yields a boolean. Testing for exp.Condition (minus Paren)
        # was too broad: exp.Unary nodes are Conditions in sqlglot's hierarchy (hence the
        # old Paren exception), so arithmetic negation ("-x") and bitwise NOT ("~x") were
        # wrongly typed BOOLEAN instead of taking their operand's type.
        if isinstance(expression, exp.Not):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Strings become VARCHAR, integer literals INT, other numbers DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Give `expression` a fixed type, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Type `expression` by coercing together the types of the named argument(s).

        Args:
            expression: Node to annotate.
            *args: Names of the arguments whose (possibly list-valued) expressions are combined.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in an ARRAY type.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate; its nodes are annotated in place.
        schema: Database schema, given as a `Schema`, a plain dict, or None
            (normalized through `ensure_schema`).
        annotators: Maps each expression class to its annotation function.
        coerces_to: Maps each data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps each expression class to its annotation function.
- coerces_to: Maps each data type to the set of data types it can be coerced into.
Returns:
The expression annotated with types.
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """Infers data types for the nodes of an expression tree and stores them on `.type`."""

    # Expression classes whose result type is fixed, regardless of their arguments.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            # NOTE: this ArrayConcat entry has no effect on ANNOTATORS, because the
            # explicit exp.ArrayConcat entry below is inserted later and overrides it.
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps expression classes to annotation functions. Insertion order matters: later
    # entries override earlier ones. The TYPE_TO_EXPRESSIONS spread comes after the
    # Unary/Binary spreads, so e.g. exp.Pow (a Binary) is typed DOUBLE. The explicit
    # entries come last and win over everything above them.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types returned unchanged by _maybe_coerce, so their inner type parameters are kept
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled by the metaclass)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to resolve the types of table columns.
            annotators: Replaces ANNOTATORS entirely when non-empty.
            coerces_to: Replaces COERCES_TO entirely when non-empty.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assign `target_type` to `expression.type` and mark the node as visited."""
        expression.type = target_type
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` and all of its sub-expressions in place.

        Each scope's qualified column references are typed first, using the schema for
        table sources or the already-annotated select expressions of derived sources.
        The rest of the scope is annotated afterwards.

        Returns:
            The same expression object, annotated.
        """
        for scope in traverse_scope(expression):
            # Maps derived-source name -> {output column name -> expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate `expression` unless already visited; unknown classes become UNKNOWN."""
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every direct child expression of `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Return the type that results from combining `type1` with `type2`.

        NULL and then UNKNOWN take priority. A nested type (see NESTED_TYPES) is returned
        as-is. Otherwise `type2` is returned if `type1` can be coerced into it, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        # COERCES_TO values are sets, so default to an empty set (not a dict)
        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Type a binary node from its operands.

        Connectors (e.g. AND/OR) become BOOLEAN, or NULL / NULLABLE(BOOLEAN) when operands
        are NULL. Predicates become BOOLEAN. Everything else gets the coerced operand type.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Type a unary node (or Alias): NOT is BOOLEAN, others inherit the operand type."""
        self._annotate_args(expression)

        # Only logical negation yields a boolean. Testing for exp.Condition (minus Paren)
        # was too broad: exp.Unary nodes are Conditions in sqlglot's hierarchy (hence the
        # old Paren exception), so arithmetic negation ("-x") and bitwise NOT ("~x") were
        # wrongly typed BOOLEAN instead of taking their operand's type.
        if isinstance(expression, exp.Not):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Strings become VARCHAR, integer literals INT, other numbers DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Give `expression` a fixed type, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Type `expression` by coercing together the types of the named argument(s).

        Args:
            expression: Node to annotate.
            *args: Names of the arguments whose (possibly list-valued) expressions are combined.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in an ARRAY type.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """
    Set up an annotator bound to `schema`.

    Args:
        schema: Schema used to resolve the types of table columns.
        annotators: Replacement for the class-level ANNOTATORS; ignored when None or empty.
        coerces_to: Replacement for the class-level COERCES_TO; ignored when None or empty.
    """
    self.schema = schema
    # Empty or missing overrides fall back to the class defaults
    self.annotators = self.ANNOTATORS if not annotators else annotators
    self.coerces_to = self.COERCES_TO if not coerces_to else coerces_to

    # ids of expressions already annotated, so each node is visited only once
    self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.ArraySize'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Boolean'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateFromParts'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Quantile'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, 
<class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Levenshtein'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimestampAdd'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Day'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.Upper'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.CHAR: 'CHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
def annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Infer types for ``expression`` one scope at a time.

    For every scope yielded by ``traverse_scope``:

    1. A lookup is built from each child-scope source name to the
       expressions that source projects. Ordinary subqueries contribute
       their ``selects``. UDTF sources contribute their alias column names
       paired with value expressions.
    2. Column references qualified with a table name are then typed,
       either from the schema (when the source is an ``exp.Table``) or
       from the ``.type`` of the matching projected expression.
    3. Finally, ``self._maybe_annotate`` is applied to the scope's own
       expression.

    After the loop, one last ``self._maybe_annotate`` call on the root
    handles expressions that scope traversal does not visit.

    Args:
        expression: the root expression to annotate.

    Returns:
        The value of ``self._maybe_annotate(expression)``.
    """
    for current in traverse_scope(expression):
        # source name -> {projected column name: expression producing it}
        projections = {}

        for source_name, child in current.sources.items():
            # Only child scopes project columns; plain tables are typed
            # through the schema further below.
            if not isinstance(child, Scope):
                continue

            child_expr = child.expression
            if not isinstance(child_expr, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection
                    for projection in child_expr.selects
                }
                continue

            # UDTF source: choose the value expressions to pair with the
            # alias column names.
            if not isinstance(child_expr, exp.Lateral):
                # Non-LATERAL UDTFs: use the items of the first argument.
                produced = child_expr.expressions[0].expressions
            elif isinstance(child_expr.this, exp.Explode):
                # LATERAL ... EXPLODE(x): the exploded argument itself.
                produced = [child_expr.this.this]
            else:
                produced = []

            # Sources with nothing to pair are left out of the lookup.
            if produced:
                projections[source_name] = dict(
                    zip(child_expr.alias_column_names, produced)
                )

        # Type the scope's table-qualified column references first ...
        for column in current.columns:
            table_name = column.table
            if not table_name:
                continue

            origin = current.sources.get(table_name)
            if isinstance(origin, exp.Table):
                self._set_type(column, self.schema.get_column_type(origin, column))
                continue

            if origin and column.name in projections.get(table_name, {}):
                self._set_type(column, projections[table_name][column.name].type)

        # ... then (possibly) annotate the remaining expressions in the scope.
        self._maybe_annotate(current.expression)

    # This takes care of non-traversable expressions.
    return self._maybe_annotate(expression)