Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
b13ba670fd
commit
2c28c49d7e
148 changed files with 68457 additions and 63176 deletions
|
@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DateToDi,
|
||||
exp.Floor,
|
||||
exp.Levenshtein,
|
||||
exp.Sign,
|
||||
exp.StrPosition,
|
||||
exp.TsOrDiToDi,
|
||||
},
|
||||
|
@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
|
||||
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Div: lambda self, e: self._annotate_div(e),
|
||||
exp.Dot: lambda self, e: self._annotate_dot(e),
|
||||
exp.Explode: lambda self, e: self._annotate_explode(e),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
|
@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
|
||||
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
|
||||
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.Timestamp: lambda self, e: self._annotate_with_type(
|
||||
e,
|
||||
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
|
||||
),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Unnest: lambda self, e: self._annotate_unnest(e),
|
||||
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
|
||||
}
|
||||
|
||||
NESTED_TYPES = {
|
||||
|
@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source and col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
last_datatype = None
|
||||
for expr in expressions:
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
|
||||
expr_type = expr.type
|
||||
|
||||
# Stop at the first nested data type found - we don't want to _maybe_coerce nested types
|
||||
if expr_type.args.get("nested"):
|
||||
last_datatype = expr_type
|
||||
break
|
||||
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
|
||||
|
||||
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
|
||||
|
||||
|
@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
|
||||
self._annotate_args(expression)
|
||||
self._set_type(expression, None)
|
||||
this_type = expression.this.type
|
||||
|
||||
if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
|
||||
for e in this_type.expressions:
|
||||
if e.name == expression.expression.name:
|
||||
self._set_type(expression, e.kind)
|
||||
break
|
||||
|
||||
return expression
|
||||
|
||||
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
|
||||
self._annotate_args(expression)
|
||||
self._set_type(expression, seq_get(expression.this.type.expressions, 0))
|
||||
return expression
|
||||
|
||||
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
|
||||
self._annotate_args(expression)
|
||||
child = seq_get(expression.expressions, 0)
|
||||
self._set_type(expression, child and seq_get(child.type.expressions, 0))
|
||||
return expression
|
||||
|
|
|
@ -10,13 +10,11 @@ if t.TYPE_CHECKING:
|
|||
|
||||
|
||||
@t.overload
|
||||
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
|
||||
...
|
||||
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
|
||||
...
|
||||
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
|
||||
|
||||
|
||||
def normalize_identifiers(expression, dialect=None):
|
||||
|
|
|
@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
|
|||
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
||||
"""
|
||||
for derived_table in derived_tables:
|
||||
if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
|
||||
continue
|
||||
table_alias = derived_table.args.get("alias")
|
||||
if table_alias:
|
||||
table_alias.args.pop("columns", None)
|
||||
|
@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
table = resolver.get_table(column.name) if resolve_table and not column.table else None
|
||||
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
|
||||
double_agg = (
|
||||
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
|
||||
(
|
||||
alias_expr.find(exp.AggFunc)
|
||||
and (
|
||||
column.find_ancestor(exp.AggFunc)
|
||||
and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
|
||||
)
|
||||
)
|
||||
if alias_expr
|
||||
else False
|
||||
)
|
||||
|
@ -404,7 +412,7 @@ def _expand_stars(
|
|||
tables = list(scope.selected_sources)
|
||||
_add_except_columns(expression, tables, except_columns)
|
||||
_add_replace_columns(expression, tables, replace_columns)
|
||||
elif expression.is_star:
|
||||
elif expression.is_star and not isinstance(expression, exp.Dot):
|
||||
tables = [expression.table]
|
||||
_add_except_columns(expression.this, tables, except_columns)
|
||||
_add_replace_columns(expression.this, tables, replace_columns)
|
||||
|
@ -437,7 +445,7 @@ def _expand_stars(
|
|||
|
||||
if pivot_columns:
|
||||
new_selections.extend(
|
||||
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
|
||||
alias(exp.column(name, table=pivot.alias), name, copy=False)
|
||||
for name in pivot_columns
|
||||
if name not in columns_to_exclude
|
||||
)
|
||||
|
@ -466,7 +474,7 @@ def _expand_stars(
|
|||
)
|
||||
|
||||
# Ensures we don't overwrite the initial selections with an empty list
|
||||
if new_selections:
|
||||
if new_selections and isinstance(scope.expression, exp.Select):
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
|
@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
|
|||
|
||||
new_selections.append(selection)
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
|
||||
|
@ -615,7 +624,7 @@ class Resolver:
|
|||
|
||||
node, _ = self.scope.selected_sources.get(table_name)
|
||||
|
||||
if isinstance(node, exp.Subqueryable):
|
||||
if isinstance(node, exp.Query):
|
||||
while node and node.alias != table_name:
|
||||
node = node.parent
|
||||
|
||||
|
|
|
@ -55,8 +55,8 @@ def qualify_tables(
|
|||
if not table.args.get("catalog") and table.args.get("db"):
|
||||
table.set("catalog", catalog)
|
||||
|
||||
if not isinstance(expression, exp.Subqueryable):
|
||||
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
|
||||
if not isinstance(expression, exp.Query):
|
||||
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
|
||||
if isinstance(node, exp.Table):
|
||||
_qualify(node)
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ class Scope:
|
|||
and _is_derived_table(node)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
elif isinstance(node, exp.UNWRAPPED_QUERIES):
|
||||
self._subqueries.append(node)
|
||||
|
||||
self._collected = True
|
||||
|
@ -225,7 +225,7 @@ class Scope:
|
|||
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
|
||||
|
||||
Returns:
|
||||
list[exp.Subqueryable]: subqueries
|
||||
list[exp.Select | exp.Union]: subqueries
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._subqueries
|
||||
|
@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
|
|||
Returns:
|
||||
list[Scope]: scope instances
|
||||
"""
|
||||
if isinstance(expression, exp.Unionable) or (
|
||||
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
|
||||
if isinstance(expression, exp.Query) or (
|
||||
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
|
||||
):
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
|
@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
|
|||
as it doesn't introduce a new scope. If an alias is present, it shadows all names
|
||||
under the Subquery, so that's one exception to this rule.
|
||||
"""
|
||||
return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
|
||||
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
|
@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None):
|
|||
and _is_derived_table(node)
|
||||
)
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.Subqueryable)
|
||||
or isinstance(node, exp.UNWRAPPED_QUERIES)
|
||||
):
|
||||
crossed_scope_boundary = True
|
||||
|
||||
|
|
|
@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str:
|
|||
GEN_MAP = {
|
||||
exp.Add: lambda e: _binary(e, "+"),
|
||||
exp.And: lambda e: _binary(e, "AND"),
|
||||
exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
|
||||
exp.Anonymous: lambda e: _anonymous(e),
|
||||
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
|
||||
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
|
||||
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
|
||||
|
@ -1219,6 +1219,20 @@ GEN_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def _anonymous(e: exp.Anonymous) -> str:
|
||||
this = e.this
|
||||
if isinstance(this, str):
|
||||
name = this.upper()
|
||||
elif isinstance(this, exp.Identifier):
|
||||
name = f'"{this.name}"' if this.quoted else this.name.upper()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
|
||||
)
|
||||
|
||||
return f"{name} {','.join(gen(e) for e in e.expressions)}"
|
||||
|
||||
|
||||
def _binary(e: exp.Binary, op: str) -> str:
|
||||
return f"{gen(e.left)} {op} {gen(e.right)}"
|
||||
|
||||
|
|
|
@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name):
|
|||
else:
|
||||
_replace(predicate, join_key_not_null)
|
||||
|
||||
group = select.args.get("group")
|
||||
|
||||
if group:
|
||||
if {value.this} != set(group.expressions):
|
||||
select = (
|
||||
exp.select(exp.column(value.alias, "_q"))
|
||||
.from_(select.subquery("_q", copy=False), copy=False)
|
||||
.group_by(exp.column(value.alias, "_q"), copy=False)
|
||||
)
|
||||
else:
|
||||
select = select.group_by(value.this, copy=False)
|
||||
|
||||
parent_select.join(
|
||||
select.group_by(value.this, copy=False),
|
||||
select,
|
||||
on=column.eq(join_key),
|
||||
join_type="LEFT",
|
||||
join_alias=alias,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue