1
0
Fork 0

Merging upstream version 23.12.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:32:41 +01:00
parent 1271e5fe1c
commit 740634a4e8
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 55455 additions and 52777 deletions

View file

@ -351,55 +351,57 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def annotate(self, expression: E) -> E:
for scope in traverse_scope(expression):
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
elif source:
if col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
self.annotate_scope(scope)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def annotate_scope(self, scope: Scope) -> None:
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
elif source:
if col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
def _maybe_annotate(self, expression: E) -> E:
if id(expression) in self._visited:
return expression # We've already inferred the expression's type
@ -601,7 +603,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
self._annotate_args(expression)
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
if child and child.is_type(exp.DataType.Type.ARRAY):
expr_type = seq_get(child.type.expressions, 0)
else:
expr_type = None
self._set_type(expression, expr_type)
return expression
def _annotate_struct_value(

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import inspect
import typing as t
import sqlglot
@ -85,7 +86,7 @@ def optimize(
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_params = inspect.getfullargspec(rule).args
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}

View file

@ -7,7 +7,7 @@ from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@ -49,8 +49,10 @@ def qualify_columns(
- Currently only handles a single PIVOT or UNPIVOT operator
"""
schema = ensure_schema(schema)
annotator = TypeAnnotator(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
dialect = Dialect.get_or_raise(schema.dialect)
pseudocolumns = dialect.PSEUDOCOLUMNS
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema, infer_schema=infer_schema)
@ -74,6 +76,9 @@ def qualify_columns(
_expand_group_by(scope)
_expand_order_by(scope, resolver)
if dialect == "bigquery":
annotator.annotate_scope(scope)
return expression
@ -660,11 +665,8 @@ class Resolver:
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.schema.dialect == "bigquery":
expression = source.expression
annotate_types(expression)
if expression.is_type(exp.DataType.Type.STRUCT):
for k in expression.type.expressions: # type: ignore
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
else:
columns = source.expression.named_selects

View file

@ -6,6 +6,7 @@ import itertools
import typing as t
from collections import deque
from decimal import Decimal
from functools import reduce
import sqlglot
from sqlglot import Dialect, exp
@ -658,17 +659,21 @@ def simplify_parens(expression):
parent = expression.parent
parent_is_predicate = isinstance(parent, exp.Predicate)
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
if (
not isinstance(this, exp.Select)
and not isinstance(parent, exp.SubqueryPredicate)
and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
):
return this
return expression
@ -779,6 +784,8 @@ def simplify_concat(expression):
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args
elif isinstance(expression, exp.DPipe):
return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
return concat_type(expressions=new_args, **args)