1
0
Fork 0

Merging upstream version 22.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:29:39 +01:00
parent b13ba670fd
commit 2c28c49d7e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
148 changed files with 68457 additions and 63176 deletions

View file

@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateToDi,
exp.Floor,
exp.Levenshtein,
exp.Sign,
exp.StrPosition,
exp.TsOrDiToDi,
},
@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
}
NESTED_TYPES = {
@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
elif source and col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
elif source:
if col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
last_datatype = None
for expr in expressions:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expr_type = expr.type
# Stop at the first nested data type found - we don't want to _maybe_coerce nested types
if expr_type.args.get("nested"):
last_datatype = expr_type
break
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
self._annotate_args(expression)
self._set_type(expression, None)
this_type = expression.this.type
if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
for e in this_type.expressions:
if e.name == expression.expression.name:
self._set_type(expression, e.kind)
break
return expression
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
self._annotate_args(expression)
self._set_type(expression, seq_get(expression.this.type.expressions, 0))
return expression
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
self._annotate_args(expression)
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
return expression

View file

@ -10,13 +10,11 @@ if t.TYPE_CHECKING:
@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
...
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):

View file

@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
continue
table_alias = derived_table.args.get("alias")
if table_alias:
table_alias.args.pop("columns", None)
@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
(
alias_expr.find(exp.AggFunc)
and (
column.find_ancestor(exp.AggFunc)
and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
)
)
if alias_expr
else False
)
@ -404,7 +412,7 @@ def _expand_stars(
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
elif expression.is_star:
elif expression.is_star and not isinstance(expression, exp.Dot):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
@ -437,7 +445,7 @@ def _expand_stars(
if pivot_columns:
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
alias(exp.column(name, table=pivot.alias), name, copy=False)
for name in pivot_columns
if name not in columns_to_exclude
)
@ -466,7 +474,7 @@ def _expand_stars(
)
# Ensures we don't overwrite the initial selections with an empty list
if new_selections:
if new_selections and isinstance(scope.expression, exp.Select):
scope.expression.set("expressions", new_selections)
@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections.append(selection)
scope.expression.set("expressions", new_selections)
if isinstance(scope.expression, exp.Select):
scope.expression.set("expressions", new_selections)
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
@ -615,7 +624,7 @@ class Resolver:
node, _ = self.scope.selected_sources.get(table_name)
if isinstance(node, exp.Subqueryable):
if isinstance(node, exp.Query):
while node and node.alias != table_name:
node = node.parent

View file

@ -55,8 +55,8 @@ def qualify_tables(
if not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog)
if not isinstance(expression, exp.Subqueryable):
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
if not isinstance(expression, exp.Query):
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
if isinstance(node, exp.Table):
_qualify(node)

View file

@ -138,7 +138,7 @@ class Scope:
and _is_derived_table(node)
):
self._derived_tables.append(node)
elif isinstance(node, exp.Subqueryable):
elif isinstance(node, exp.UNWRAPPED_QUERIES):
self._subqueries.append(node)
self._collected = True
@ -225,7 +225,7 @@ class Scope:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
list[exp.Select | exp.Union]: subqueries
"""
self._ensure_collected()
return self._subqueries
@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
if isinstance(expression, exp.Unionable) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
if isinstance(expression, exp.Query) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
):
return list(_traverse_scope(Scope(expression)))
@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
def _traverse_tables(scope):
@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None):
and _is_derived_table(node)
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True

View file

@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
exp.Anonymous: lambda e: _anonymous(e),
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
@ -1219,6 +1219,20 @@ GEN_MAP = {
}
def _anonymous(e: exp.Anonymous) -> str:
this = e.this
if isinstance(this, str):
name = this.upper()
elif isinstance(this, exp.Identifier):
name = f'"{this.name}"' if this.quoted else this.name.upper()
else:
raise ValueError(
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
)
return f"{name} {','.join(gen(e) for e in e.expressions)}"
def _binary(e: exp.Binary, op: str) -> str:
return f"{gen(e.left)} {op} {gen(e.right)}"

View file

@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name):
else:
_replace(predicate, join_key_not_null)
group = select.args.get("group")
if group:
if {value.this} != set(group.expressions):
select = (
exp.select(exp.column(value.alias, "_q"))
.from_(select.subquery("_q", copy=False), copy=False)
.group_by(exp.column(value.alias, "_q"), copy=False)
)
else:
select = select.group_by(value.this, copy=False)
parent_select.join(
select.group_by(value.this, copy=False),
select,
on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,