Merging upstream version 26.8.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
d551ab0954
commit
010433ad9a
61 changed files with 43883 additions and 41898 deletions
|
@@ -195,7 +195,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
# Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the
|
||||
# exp.SetOperation is the expression of a scope source, as selecting from it multiple times
|
||||
# would reprocess the entire subtree to coerce the types of its operands' projections
|
||||
self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType.Type]] = {}
|
||||
self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {}
|
||||
|
||||
def _set_type(
|
||||
self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
|
||||
|
@@ -312,18 +312,32 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def _maybe_coerce(
|
||||
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
|
||||
) -> exp.DataType.Type:
|
||||
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
|
||||
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
|
||||
) -> exp.DataType | exp.DataType.Type:
|
||||
"""
|
||||
Returns type2 if type1 can be coerced into it, otherwise type1.
|
||||
|
||||
If either type is parameterized (e.g. DECIMAL(18, 2) contains two parameters),
|
||||
we assume type1 does not coerce into type2, so we also return it in this case.
|
||||
"""
|
||||
if isinstance(type1, exp.DataType):
|
||||
if type1.expressions:
|
||||
return type1
|
||||
type1_value = type1.this
|
||||
else:
|
||||
type1_value = type1
|
||||
|
||||
if isinstance(type2, exp.DataType):
|
||||
if type2.expressions:
|
||||
return type1
|
||||
type2_value = type2.this
|
||||
else:
|
||||
type2_value = type2
|
||||
|
||||
# We propagate the UNKNOWN type upwards if found
|
||||
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
return t.cast(
|
||||
exp.DataType.Type,
|
||||
type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value,
|
||||
)
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
|
||||
|
||||
def _annotate_binary(self, expression: B) -> B:
|
||||
self._annotate_args(expression)
|
||||
|
|
|
@@ -102,7 +102,10 @@ def qualify_columns(
|
|||
qualify_outputs(scope)
|
||||
|
||||
_expand_group_by(scope, dialect)
|
||||
_expand_order_by(scope, resolver)
|
||||
|
||||
# DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
|
||||
# https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
|
||||
_expand_order_by_and_distinct_on(scope, resolver)
|
||||
|
||||
if dialect == "bigquery":
|
||||
annotator.annotate_scope(scope)
|
||||
|
@@ -359,36 +362,41 @@ def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
|
|||
expression.set("group", group)
|
||||
|
||||
|
||||
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
|
||||
order = scope.expression.args.get("order")
|
||||
if not order:
|
||||
return
|
||||
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
|
||||
for modifier_key in ("order", "distinct"):
|
||||
modifier = scope.expression.args.get(modifier_key)
|
||||
if isinstance(modifier, exp.Distinct):
|
||||
modifier = modifier.args.get("on")
|
||||
|
||||
ordereds = order.expressions
|
||||
for ordered, new_expression in zip(
|
||||
ordereds,
|
||||
_expand_positional_references(
|
||||
scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
|
||||
),
|
||||
):
|
||||
for agg in ordered.find_all(exp.AggFunc):
|
||||
for col in agg.find_all(exp.Column):
|
||||
if not col.table:
|
||||
col.set("table", resolver.get_table(col.name))
|
||||
if not isinstance(modifier, exp.Expression):
|
||||
continue
|
||||
|
||||
ordered.set("this", new_expression)
|
||||
modifier_expressions = modifier.expressions
|
||||
if modifier_key == "order":
|
||||
modifier_expressions = [ordered.this for ordered in modifier_expressions]
|
||||
|
||||
if scope.expression.args.get("group"):
|
||||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
|
||||
for original, expanded in zip(
|
||||
modifier_expressions,
|
||||
_expand_positional_references(
|
||||
scope, modifier_expressions, resolver.schema.dialect, alias=True
|
||||
),
|
||||
):
|
||||
for agg in original.find_all(exp.AggFunc):
|
||||
for col in agg.find_all(exp.Column):
|
||||
if not col.table:
|
||||
col.set("table", resolver.get_table(col.name))
|
||||
|
||||
for ordered in ordereds:
|
||||
ordered = ordered.this
|
||||
original.replace(expanded)
|
||||
|
||||
ordered.replace(
|
||||
exp.to_identifier(_select_by_pos(scope, ordered).alias)
|
||||
if ordered.is_int
|
||||
else selects.get(ordered, ordered)
|
||||
)
|
||||
if scope.expression.args.get("group"):
|
||||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
|
||||
|
||||
for expression in modifier_expressions:
|
||||
expression.replace(
|
||||
exp.to_identifier(_select_by_pos(scope, expression).alias)
|
||||
if expression.is_int
|
||||
else selects.get(expression, expression)
|
||||
)
|
||||
|
||||
|
||||
def _expand_positional_references(
|
||||
|
|
|
@@ -97,12 +97,16 @@ def qualify_tables(
|
|||
source.alias
|
||||
)
|
||||
|
||||
_qualify(source)
|
||||
if pivots:
|
||||
if not pivots[0].alias:
|
||||
pivot_alias = next_alias_name()
|
||||
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
|
||||
|
||||
if pivots and not pivots[0].alias:
|
||||
pivots[0].set(
|
||||
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
|
||||
)
|
||||
# This case corresponds to a pivoted CTE, we don't want to qualify that
|
||||
if isinstance(scope.sources.get(source.alias_or_name), Scope):
|
||||
continue
|
||||
|
||||
_qualify(source)
|
||||
|
||||
if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
|
||||
with csv_reader(source.this) as reader:
|
||||
|
|
|
@@ -282,7 +282,14 @@ class Scope:
|
|||
self._columns = []
|
||||
for column in columns + external_columns:
|
||||
ancestor = column.find_ancestor(
|
||||
exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
|
||||
exp.Select,
|
||||
exp.Qualify,
|
||||
exp.Order,
|
||||
exp.Having,
|
||||
exp.Hint,
|
||||
exp.Table,
|
||||
exp.Star,
|
||||
exp.Distinct,
|
||||
)
|
||||
if (
|
||||
not ancestor
|
||||
|
@@ -290,9 +297,9 @@ class Scope:
|
|||
or isinstance(ancestor, exp.Select)
|
||||
or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
|
||||
or (
|
||||
isinstance(ancestor, exp.Order)
|
||||
isinstance(ancestor, (exp.Order, exp.Distinct))
|
||||
and (
|
||||
isinstance(ancestor.parent, exp.Window)
|
||||
isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
|
||||
or column.name not in named_selects
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue