1
0
Fork 0

Merging upstream version 16.2.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 16:00:51 +01:00
parent c12f551e31
commit 718a80b164
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
106 changed files with 41940 additions and 40162 deletions

View file

@ -1,13 +1,25 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
"""
Recursively infer & annotate types in an expression syntax tree against a schema.
Assumes that we've already executed the optimizer's qualify_columns step.
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
<Type.DOUBLE: 'DOUBLE'>
Args:
expression (sqlglot.Expression): Expression to annotate.
schema (dict|sqlglot.optimizer.Schema): Database schema.
annotators (dict): Maps expression type to corresponding annotation function.
coerces_to (dict): Maps expression type to set of types that it can be coerced into.
expression: Expression to annotate.
schema: Database schema.
annotators: Maps expression type to corresponding annotation function.
coerces_to: Maps expression type to set of types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
The expression annotated with types.
"""
schema = ensure_schema(schema)
@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
class TypeAnnotator:
ANNOTATORS = {
**{
expr_type: lambda self, expr: self._annotate_unary(expr)
for expr_type in subclasses(exp.__name__, exp.Unary)
},
**{
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.Sum: lambda self, expr: self._annotate_by_args(
expr, "this", "expressions", promote=True
),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BOOLEAN
),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.StrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
return lambda self, e: self._annotate_with_type(e, data_type)
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
COERCES_TO = {
# CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
class _TypeAnnotator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
# Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
# https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
text_precedence = (
exp.DataType.Type.TEXT,
},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
# TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.CHAR,
)
numeric_precedence = (
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.INT: {
exp.DataType.Type.FLOAT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.SMALLINT: {
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.TINYINT: {
exp.DataType.Type.SMALLINT,
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TINYINT,
)
timelike_precedence = (
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
exp.DataType.Type.DATETIME,
exp.DataType.Type.DATE,
)
for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
coerces_to = set()
for data_type in type_precedence:
klass.COERCES_TO[data_type] = coerces_to.copy()
coerces_to |= {data_type}
return klass
class TypeAnnotator(metaclass=_TypeAnnotator):
TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
exp.DataType.Type.BIGINT: {
exp.ApproxDistinct,
exp.ArraySize,
exp.Count,
exp.Length,
},
exp.DataType.Type.BOOLEAN: {
exp.Between,
exp.Boolean,
exp.In,
exp.RegexpLike,
},
exp.DataType.Type.DATE: {
exp.DataType.Type.DATETIME,
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
exp.CurrentDate,
exp.Date,
exp.DateAdd,
exp.DateStrToDate,
exp.DateSub,
exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
exp.TsOrDsToDate,
},
exp.DataType.Type.DATETIME: {
exp.CurrentDatetime,
exp.DatetimeAdd,
exp.DatetimeSub,
},
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
exp.Exp,
exp.Ln,
exp.Log,
exp.Log2,
exp.Log10,
exp.Pow,
exp.Quantile,
exp.Round,
exp.SafeDivide,
exp.Sqrt,
exp.Stddev,
exp.StddevPop,
exp.StddevSamp,
exp.Variance,
exp.VariancePop,
},
exp.DataType.Type.INT: {
exp.Ceil,
exp.DateDiff,
exp.DatetimeDiff,
exp.Extract,
exp.TimestampDiff,
exp.TimeDiff,
exp.DateToDi,
exp.Floor,
exp.Levenshtein,
exp.StrPosition,
exp.TsOrDiToDi,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
exp.StrToTime,
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
exp.Day,
exp.Month,
exp.Week,
exp.Year,
},
exp.DataType.Type.VARCHAR: {
exp.ArrayConcat,
exp.Concat,
exp.ConcatWs,
exp.DateToDateStr,
exp.GroupConcat,
exp.Initcap,
exp.Lower,
exp.SafeConcat,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
exp.Trim,
exp.TsOrDsToDateStr,
exp.UnixToStr,
exp.UnixToTimeStr,
exp.Upper,
},
}
TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
ANNOTATORS = {
**{
expr_type: lambda self, e: self._annotate_unary(e)
for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
**{
expr_type: lambda self, e: self._annotate_binary(e)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
**{
expr_type: _annotate_with_type_lambda(data_type)
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
def __init__(self, schema=None, annotators=None, coerces_to=None):
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
def __init__(
self,
schema: Schema,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
def annotate(self, expression: E) -> E:
for scope in traverse_scope(expression):
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
values = source.expression.expressions[0].expressions
if not values:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
elif source and col.table in selects and col.name in selects[col.table]:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
elif source and col.table in selects and col.name in selects[col.table]:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
def _maybe_annotate(self, expression: E) -> E:
if expression.type:
return expression # We've already inferred the expression's type
@ -312,13 +290,15 @@ class TypeAnnotator:
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
)
def _annotate_args(self, expression):
def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
def _maybe_coerce(self, type1, type2):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
@ -330,9 +310,14 @@ class TypeAnnotator:
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
return type2 if type2 in self.coerces_to.get(type1, {}) else type1
return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
def _annotate_binary(self, expression):
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
# This is a known mypy issue: https://github.com/python/mypy/issues/3004
@t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left_type = expression.left.type.this
@ -354,7 +339,8 @@ class TypeAnnotator:
return expression
def _annotate_unary(self, expression):
@t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
@ -364,7 +350,8 @@ class TypeAnnotator:
return expression
def _annotate_literal(self, expression):
@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
@ -374,13 +361,16 @@ class TypeAnnotator:
return expression
def _annotate_with_type(self, expression, target_type):
@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
expression.type = target_type
return self._annotate_args(expression)
def _annotate_by_args(self, expression, *args, promote=False):
@t.no_type_check
def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
self._annotate_args(expression)
expressions = []
expressions: t.List[exp.Expression] = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

View file

@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
node = exp.Concat(expressions=[node.left, node.right])
return node

View file

@ -32,7 +32,7 @@ def eliminate_joins(expression):
# Reverse the joins so we can remove chains of unused joins
for join in reversed(joins):
alias = join.this.alias_or_name
alias = join.alias_or_name
if _should_eliminate_join(scope, join, alias):
join.pop()
scope.remove_source(alias)
@ -126,7 +126,7 @@ def join_condition(join):
tuple[list[str], list[str], exp.Expression]:
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
name = join.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
source_key = []
join_key = []

View file

@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
alias(source, source.name or source.alias, table=True),
alias(source, source.alias_or_name, table=True),
copy=False,
)
.subquery(source.alias, copy=False)

View file

@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
if not isinstance(from_or_join, exp.Join):
return False
alias = from_or_join.this.alias_or_name
alias = from_or_join.alias_or_name
on = from_or_join.args.get("on")
if not on:
@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
"""
new_joins = []
comma_joins = inner_scope.expression.args.get("from").expressions[1:]
for subquery in comma_joins:
new_joins.append(exp.Join(this=subquery, kind="CROSS"))
outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
joins = inner_scope.expression.args.get("joins") or []
for join in joins:
@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if source == from_or_join.alias_or_name:
break
if set(exp.column_table_names(where.this)) <= sources:
if exp.column_table_names(where.this) <= sources:
from_or_join.on(where.this, copy=False)
from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):

View file

@ -1,3 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.helper import tsort
@ -13,25 +17,28 @@ def optimize_joins(expression):
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
"""
for select in expression.find_all(exp.Select):
references = {}
cross_joins = []
for join in select.args.get("joins", []):
name = join.this.alias_or_name
tables = other_table_names(join, name)
tables = other_table_names(join)
if tables:
for table in tables:
references[table] = references.get(table, []) + [join]
else:
cross_joins.append((name, join))
cross_joins.append((join.alias_or_name, join))
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
if isinstance(on, exp.Connector):
if len(other_table_names(dep)) < 2:
continue
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
@ -47,17 +54,12 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
head = from_.this
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
for name, join in joins.items():
dag[name] = other_table_names(join, name)
joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {name: other_table_names(join) for name, join in joins.items()}
parent.set(
"joins",
[joins[name] for name in tsort(dag) if name != head.alias_or_name],
[joins[name] for name in tsort(dag) if name != from_.alias_or_name],
)
return expression
@ -75,9 +77,6 @@ def normalize(expression):
return expression
def other_table_names(join, exclude):
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]
def other_table_names(join: exp.Join) -> t.Set[str]:
on = join.args.get("on")
return exp.column_table_names(on, join.alias_or_name) if on else set()

View file

@ -78,7 +78,7 @@ def optimize(
"schema": schema,
"dialect": dialect,
"isolate_tables": True, # needed for other optimizations to perform well
"quote_identifiers": False, # this happens in canonicalize
"quote_identifiers": False,
**kwargs,
}

View file

@ -41,7 +41,7 @@ def pushdown_predicates(expression):
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
name = join.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
pushdown_tables = set()
for a in predicates:
a_tables = set(exp.column_table_names(a))
a_tables = exp.column_table_names(a)
for b in predicates:
a_tables &= set(exp.column_table_names(b))
a_tables &= exp.column_table_names(b)
pushdown_tables.update(a_tables)
@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
tables = exp.column_table_names(predicate)
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
for table in tables:
for table in sorted(tables):
node, source = sources.get(table) or (None, None)
# if the predicate is in a where statement we can try to push it down

View file

@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema
def qualify_columns(
expression: exp.Expression,
schema: dict | Schema,
schema: t.Dict | Schema,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to an ordered set of source names (dict).
@ -105,7 +105,7 @@ def _expand_using(scope, resolver):
if not using:
continue
join_table = join.this.alias_or_name
join_table = join.alias_or_name
columns = {}

View file

@ -91,11 +91,13 @@ def qualify_tables(
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
table_alias = udtf.args.get("alias") or exp.TableAlias(
this=exp.to_identifier(next_alias_name())
)
udtf.set("alias", table_alias)
if not table_alias.name:
table_alias.set("this", next_alias_name())
table_alias.set("this", exp.to_identifier(next_alias_name()))
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

View file

@ -620,7 +620,7 @@ def _traverse_tables(scope):
table_name = expression.name
source_name = expression.alias_or_name
if table_name in scope.sources:
if table_name in scope.sources and not expression.db:
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")