Merging upstream version 23.12.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
1271e5fe1c
commit
740634a4e8
93 changed files with 55455 additions and 52777 deletions
|
@ -351,55 +351,57 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def annotate(self, expression: E) -> E:
|
||||
for scope in traverse_scope(expression):
|
||||
selects = {}
|
||||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
values,
|
||||
)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
if not col.table:
|
||||
continue
|
||||
|
||||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
self.annotate_scope(scope)
|
||||
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
|
||||
|
||||
def annotate_scope(self, scope: Scope) -> None:
|
||||
selects = {}
|
||||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
values,
|
||||
)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
if not col.table:
|
||||
continue
|
||||
|
||||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
def _maybe_annotate(self, expression: E) -> E:
|
||||
if id(expression) in self._visited:
|
||||
return expression # We've already inferred the expression's type
|
||||
|
@ -601,7 +603,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
|
||||
self._annotate_args(expression)
|
||||
child = seq_get(expression.expressions, 0)
|
||||
self._set_type(expression, child and seq_get(child.type.expressions, 0))
|
||||
|
||||
if child and child.is_type(exp.DataType.Type.ARRAY):
|
||||
expr_type = seq_get(child.type.expressions, 0)
|
||||
else:
|
||||
expr_type = None
|
||||
|
||||
self._set_type(expression, expr_type)
|
||||
return expression
|
||||
|
||||
def _annotate_struct_value(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import typing as t
|
||||
|
||||
import sqlglot
|
||||
|
@ -85,7 +86,7 @@ def optimize(
|
|||
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
|
||||
for rule in rules:
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
rule_params = rule.__code__.co_varnames
|
||||
rule_params = inspect.getfullargspec(rule).args
|
||||
rule_kwargs = {
|
||||
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ from sqlglot import alias, exp
|
|||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get, SingleValuedMapping
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.annotate_types import TypeAnnotator
|
||||
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.optimizer.simplify import simplify_parens
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
@ -49,8 +49,10 @@ def qualify_columns(
|
|||
- Currently only handles a single PIVOT or UNPIVOT operator
|
||||
"""
|
||||
schema = ensure_schema(schema)
|
||||
annotator = TypeAnnotator(schema)
|
||||
infer_schema = schema.empty if infer_schema is None else infer_schema
|
||||
pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
|
||||
dialect = Dialect.get_or_raise(schema.dialect)
|
||||
pseudocolumns = dialect.PSEUDOCOLUMNS
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
resolver = Resolver(scope, schema, infer_schema=infer_schema)
|
||||
|
@ -74,6 +76,9 @@ def qualify_columns(
|
|||
_expand_group_by(scope)
|
||||
_expand_order_by(scope, resolver)
|
||||
|
||||
if dialect == "bigquery":
|
||||
annotator.annotate_scope(scope)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -660,11 +665,8 @@ class Resolver:
|
|||
# directly select a struct field in a query.
|
||||
# this handles the case where the unnest is statically defined.
|
||||
if self.schema.dialect == "bigquery":
|
||||
expression = source.expression
|
||||
annotate_types(expression)
|
||||
|
||||
if expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in expression.type.expressions: # type: ignore
|
||||
if source.expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in source.expression.type.expressions: # type: ignore
|
||||
columns.append(k.name)
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
|
|
|
@ -6,6 +6,7 @@ import itertools
|
|||
import typing as t
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import Dialect, exp
|
||||
|
@ -658,17 +659,21 @@ def simplify_parens(expression):
|
|||
parent = expression.parent
|
||||
parent_is_predicate = isinstance(parent, exp.Predicate)
|
||||
|
||||
if not isinstance(this, exp.Select) and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(parent, exp.Paren)
|
||||
or (
|
||||
not isinstance(this, exp.Binary)
|
||||
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
||||
if (
|
||||
not isinstance(this, exp.Select)
|
||||
and not isinstance(parent, exp.SubqueryPredicate)
|
||||
and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(parent, exp.Paren)
|
||||
or (
|
||||
not isinstance(this, exp.Binary)
|
||||
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
||||
)
|
||||
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
)
|
||||
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
):
|
||||
return this
|
||||
return expression
|
||||
|
@ -779,6 +784,8 @@ def simplify_concat(expression):
|
|||
|
||||
if concat_type is exp.ConcatWs:
|
||||
new_args = [sep_expr] + new_args
|
||||
elif isinstance(expression, exp.DPipe):
|
||||
return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
|
||||
|
||||
return concat_type(expressions=new_args, **args)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue