Merging upstream version 16.2.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

This commit is contained in:
parent c12f551e31
commit 718a80b164

106 changed files with 41940 additions and 40162 deletions
@@ -1,13 +1,25 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
@@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps expression type to set of types that it can be coerced into.

    Returns:
        sqlglot.Expression: expression annotated with types
        The expression annotated with types.
    """

    schema = ensure_schema(schema)
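For context, a minimal usage sketch of the entry point above. The schema, table and column names ("t", "x", "y") are invented for illustration; as in the module's own docstring example, a single-table query can be annotated directly.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# Hypothetical one-table schema, illustration only
schema = {"t": {"x": "INT"}}
expression = sqlglot.parse_one("SELECT x + 1 AS y FROM t")
annotated = annotate_types(expression, schema=schema)

# The aliased projection is expected to carry an inferred integer type
print(annotated.expressions[0].type)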
@@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


class TypeAnnotator:
    ANNOTATORS = {
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.BIGINT
        ),
        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATETIME
        ),
        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATETIME
        ),
        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATETIME
        ),
        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATE
        ),
        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
        exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DOUBLE
        ),
        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.BOOLEAN
        ),
        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.StrToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATE
        ),
        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.VariancePop: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DOUBLE
        ),
        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
    }
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,

class _TypeAnnotator(type):
    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
        },
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass
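The metaclass loop above derives the coercion table from each precedence chain: every type ends up mapping to the set of types listed before it (the "wider" ones). A small standalone sketch of that derivation, using plain strings instead of exp.DataType.Type members:

# Standalone illustration only; the real code iterates over exp.DataType.Type members
precedence = ("TEXT", "NVARCHAR", "VARCHAR", "NCHAR", "CHAR")  # highest to lowest

coerces_to = {}
seen = set()
for data_type in precedence:
    coerces_to[data_type] = set(seen)  # everything higher in the chain so far
    seen |= {data_type}

assert coerces_to["VARCHAR"] == {"TEXT", "NVARCHAR"}
assert coerces_to["TEXT"] == set()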
class TypeAnnotator(metaclass=_TypeAnnotator):
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    ANNOTATORS = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }
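One detail worth noting about the new ANNOTATORS construction: the separate _annotate_with_type_lambda factory exists, presumably, because a lambda written inline in the dict comprehension would close over the loop variable and see only its final value. A minimal, self-contained illustration of that Python closure pitfall:

# All lambdas share the comprehension's loop variable, so they all see its last value
funcs_bad = {n: (lambda: n) for n in (1, 2, 3)}
assert [f() for f in funcs_bad.values()] == [3, 3, 3]

# A factory function (the role _annotate_with_type_lambda plays above) binds the value per call
def make(n):
    return lambda: n

funcs_good = {n: make(n) for n in (1, 2, 3)}
assert [f() for f in funcs_good.values()] == [1, 2, 3]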
    def __init__(self, schema=None, annotators=None, coerces_to=None):
    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []
    def annotate(self, expression: E) -> E:
        for scope in traverse_scope(expression):
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }
            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type
            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)
                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
    def _maybe_annotate(self, expression: E) -> E:
        if expression.type:
            return expression  # We've already inferred the expression's type
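To see the scope walk above end to end, a hedged usage sketch (the table, column and alias names are invented): the query is qualified first, as the module's docstring requires, and the outer projection is then expected to pick up its type from the subquery's select.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"t": {"x": "INT"}}
sql = "SELECT y FROM (SELECT x AS y FROM t) AS sub"

expression = qualify_columns(sqlglot.parse_one(sql), schema=schema)
annotated = annotate_types(expression, schema=schema)

# The outer column y should inherit the inner alias's inferred INT type
print(annotated.selects[0].type)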
@@ -312,13 +290,15 @@ class TypeAnnotator:
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
    def _annotate_args(self, expression: E) -> E:
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(self, type1, type2):
    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType.Type:
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this

@@ -330,9 +310,14 @@ class TypeAnnotator:
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
        return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore

    def _annotate_binary(self, expression):
    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        self._annotate_args(expression)

        left_type = expression.left.type.this
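A hedged illustration of the coercion rule above (type2 wins only if type1 can coerce to it, otherwise type1 is kept), assuming the TypeAnnotator and ensure_schema entry points shown in this diff:

from sqlglot import exp
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.schema import ensure_schema

annotator = TypeAnnotator(ensure_schema({}))
t = exp.DataType.Type

# INT sits below DOUBLE in the numeric precedence chain, so it coerces upwards...
assert annotator._maybe_coerce(t.INT, t.DOUBLE) == t.DOUBLE
# ...but DOUBLE does not coerce back down, so the left-hand type is kept
assert annotator._maybe_coerce(t.DOUBLE, t.INT) == t.DOUBLE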
@@ -354,7 +339,8 @@ class TypeAnnotator:

        return expression

    def _annotate_unary(self, expression):
    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):

@@ -364,7 +350,8 @@ class TypeAnnotator:

        return expression

    def _annotate_literal(self, expression):
    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:

@@ -374,13 +361,16 @@ class TypeAnnotator:

        return expression

    def _annotate_with_type(self, expression, target_type):
    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
    @t.no_type_check
    def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
        self._annotate_args(expression)
        expressions = []

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
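A hedged end-to-end check of _annotate_by_args as wired up for COALESCE above: its type should be the coercion of its arguments' types, so mixing an integer and a decimal literal is expected to yield a DOUBLE.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expression = annotate_types(sqlglot.parse_one("SELECT COALESCE(1, 2.5)"))

# Literal 1 is an INT, 2.5 a DOUBLE; coercion is expected to pick the wider numeric type
print(expression.selects[0].type)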
@@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:

def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
        node = exp.Concat(this=node.this, expression=node.expression)
        node = exp.Concat(expressions=[node.left, node.right])
    return node
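For context, a sketch of the behaviour this canonicalization rule targets (hedged; the exact rendered SQL may vary by dialect): once types are annotated, + over text operands is rewritten to a CONCAT call, now built with the expressions= form.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

expression = annotate_types(sqlglot.parse_one("SELECT 'a' + 'b'"))

# The text-typed addition is expected to come back as something like CONCAT('a', 'b')
print(canonicalize(expression).sql())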
@@ -32,7 +32,7 @@ def eliminate_joins(expression):

        # Reverse the joins so we can remove chains of unused joins
        for join in reversed(joins):
            alias = join.this.alias_or_name
            alias = join.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)

@@ -126,7 +126,7 @@ def join_condition(join):
        tuple[list[str], list[str], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.this.alias_or_name
    name = join.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []
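Several hunks in this commit replace join.this.alias_or_name with join.alias_or_name, which suggests exp.Join now resolves alias_or_name through its wrapped source itself. A small hedged check:

import sqlglot

join = sqlglot.parse_one("SELECT * FROM x JOIN y AS z ON x.a = z.a").args["joins"][0]

# Expected to resolve to the joined source's alias, i.e. 'z'
print(join.alias_or_name)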
@@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
        source.replace(
            exp.select("*")
            .from_(
                alias(source, source.name or source.alias, table=True),
                alias(source, source.alias_or_name, table=True),
                copy=False,
            )
            .subquery(source.alias, copy=False)

@@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.this.alias_or_name
        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:

@@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
    """

    new_joins = []
    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
    for subquery in comma_joins:
        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:

@@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
            if source == from_or_join.alias_or_name:
                break

        if set(exp.column_table_names(where.this)) <= sources:
        if exp.column_table_names(where.this) <= sources:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)
    expression.set("where", expression.args.get("where"))


def _merge_order(outer_scope, inner_scope):


@@ -1,3 +1,7 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.helper import tsort


@@ -13,25 +17,28 @@ def optimize_joins(expression):
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """

    for select in expression.find_all(exp.Select):
        references = {}
        cross_joins = []

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            tables = other_table_names(join, name)
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references[table] = references.get(table, []) + [join]
            else:
                cross_joins.append((name, join))
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                if isinstance(on, exp.Connector):
                    if len(other_table_names(dep)) < 2:
                        continue

                    for predicate in on.flatten():
                        if name in exp.column_table_names(predicate):
                            predicate.replace(exp.true())

@@ -47,17 +54,12 @@ def reorder_joins(expression):
    Reorder joins by topological sort order based on predicate references.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.this
        parent = from_.parent
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
        dag = {head.alias_or_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
        dag = {name: other_table_names(join) for name, join in joins.items()}
        parent.set(
            "joins",
            [joins[name] for name in tsort(dag) if name != head.alias_or_name],
            [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
        )
    return expression


@@ -75,9 +77,6 @@ def normalize(expression):
    return expression


def other_table_names(join, exclude):
    return [
        name
        for name in (exp.column_table_names(join.args.get("on") or exp.true()))
        if name != exclude
    ]
def other_table_names(join: exp.Join) -> t.Set[str]:
    on = join.args.get("on")
    return exp.column_table_names(on, join.alias_or_name) if on else set()
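The rewritten other_table_names above leans on exp.column_table_names now returning a set and, per this diff, accepting an exclusion argument for the join's own alias. A hedged sanity check:

import sqlglot
from sqlglot import exp

join = sqlglot.parse_one("SELECT * FROM x JOIN y ON x.a = y.a").args["joins"][0]
on = join.args["on"]

# Expected: {'x'} -- the join's own alias 'y' is excluded from the result set
print(exp.column_table_names(on, "y"))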
@@ -78,7 +78,7 @@ def optimize(
        "schema": schema,
        "dialect": dialect,
        "isolate_tables": True,  # needed for other optimizations to perform well
        "quote_identifiers": False,  # this happens in canonicalize
        "quote_identifiers": False,
        **kwargs,
    }


@@ -41,7 +41,7 @@ def pushdown_predicates(expression):
            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.this.alias_or_name
                name = join.alias_or_name
                pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

    return expression

@@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
    pushdown_tables = set()

    for a in predicates:
        a_tables = set(exp.column_table_names(a))
        a_tables = exp.column_table_names(a)

        for b in predicates:
            a_tables &= set(exp.column_table_names(b))
            a_tables &= exp.column_table_names(b)

        pushdown_tables.update(a_tables)
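Since exp.column_table_names now returns a set directly, the DNF pushdown above can intersect table sets without the extra set(...) wrapping. A hedged standalone illustration:

import sqlglot
from sqlglot import exp

predicates = [sqlglot.parse_one(sql) for sql in ("x.a = 1 AND y.b = 2", "x.a = 3")]

# Tables common to every predicate; expected: {'x'}
common = set.intersection(*(exp.column_table_names(p) for p in predicates))
print(common)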
@@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in tables:
    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down

@@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema

def qualify_columns(
    expression: exp.Expression,
    schema: dict | Schema,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:

@@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables):

def _expand_using(scope, resolver):
    joins = list(scope.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).

@@ -105,7 +105,7 @@ def _expand_using(scope, resolver):
        if not using:
            continue

        join_table = join.this.alias_or_name
        join_table = join.alias_or_name

        columns = {}


@@ -91,11 +91,13 @@ def qualify_tables(
                )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                if not table_alias.name:
                    table_alias.set("this", next_alias_name())
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, e in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
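The qualify_tables change above wraps the generated alias in exp.to_identifier, so unaliased UDTF-style sources (for example a bare VALUES) end up with a proper identifier alias and _col_N column names. A hedged sketch of the intent; the exact generated alias is an implementation detail:

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

expression = sqlglot.parse_one("SELECT * FROM (VALUES (1, 2))")

# The VALUES source is expected to gain a generated alias and _col_0/_col_1 columns
print(qualify_tables(expression).sql())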
@@ -620,7 +620,7 @@ def _traverse_tables(scope):
        table_name = expression.name
        source_name = expression.alias_or_name

        if table_name in scope.sources:
        if table_name in scope.sources and not expression.db:
            # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
            # it is pivoted, because then we get back a new table and hence a new source.
            pivots = expression.args.get("pivots")
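The added "and not expression.db" guard above means a db-qualified table is no longer mistaken for a CTE that happens to share its name. A hedged sketch of the scenario it targets (the CTE/table names and the build_scope call are illustrative):

import sqlglot
from sqlglot.optimizer.scope import build_scope

# "t" is both a CTE name and a real table referenced as db.t
scope = build_scope(sqlglot.parse_one("WITH t AS (SELECT 1 AS a) SELECT * FROM db.t"))

# With the guard, the source named 't' is expected to be the table db.t, not the CTE scope
print(scope.sources)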