Merging upstream version 25.5.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 298e7a8147
commit 029b9c2c73
136 changed files with 80990 additions and 72541 deletions
@@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         },
         exp.DataType.Type.DATETIME: {
             exp.CurrentDatetime,
+            exp.Datetime,
             exp.DatetimeAdd,
             exp.DatetimeSub,
         },
@@ -196,6 +197,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.DataType.Type.JSON: {
             exp.ParseJSON,
         },
+        exp.DataType.Type.TIME: {
+            exp.Time,
+        },
         exp.DataType.Type.TIMESTAMP: {
             exp.CurrentTime,
             exp.CurrentTimestamp,
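The two hunks above extend the annotator's expression-to-type map, so exp.Datetime and exp.Time nodes now pick up DATETIME and TIME types. A minimal sketch of how the annotator is driven (sqlglot public API; the query is only illustrative):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    # CURRENT_TIMESTAMP was already mapped to TIMESTAMP; exp.Datetime / exp.Time
    # nodes are handled the same way after this change.
    expression = annotate_types(sqlglot.parse_one("SELECT CURRENT_TIMESTAMP"))
    print(expression.selects[0].type)  # DataType of the annotated projection (TIMESTAMP)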
@@ -42,7 +42,7 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
         and not node.args.get("zone")
     ):
         return exp.cast(node.this, to=exp.DataType.Type.DATE)
 
-    if isinstance(node, exp.Timestamp) and not node.expression:
+    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
         if not node.type:
             from sqlglot.optimizer.annotate_types import annotate_types
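For reference, the rewrite above turns zone-less date/timestamp calls into plain casts. A small sketch of the target shape (nodes built by hand, not taken from the change itself):

    from sqlglot import exp

    # What replace_date_funcs produces for DATE(<string>) with no zone argument:
    node = exp.Date(this=exp.Literal.string("2024-01-01"))
    print(exp.cast(node.this, to=exp.DataType.Type.DATE).sql())  # CAST('2024-01-01' AS DATE)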
@@ -47,7 +47,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
         if scope.expression.args.get("distinct"):
             parent_selections = {SELECT_ALL}
 
-        if isinstance(scope.expression, exp.Union):
+        if isinstance(scope.expression, exp.SetOperation):
             left, right = scope.union_scopes
             referenced_columns[left] = parent_selections
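Several hunks in this commit swap isinstance checks from exp.Union to exp.SetOperation. Assuming the 25.x class hierarchy where Union, Except and Intersect all derive from SetOperation, the widened check also covers INTERSECT and EXCEPT queries:

    from sqlglot import exp, parse_one

    query = parse_one("SELECT a FROM x EXCEPT SELECT a FROM y")
    # Except is a SetOperation, so the widened isinstance check matches it.
    print(isinstance(query, exp.SetOperation))  # True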
@@ -60,8 +60,12 @@ def qualify_columns(
         _pop_table_column_aliases(scope.derived_tables)
         using_column_tables = _expand_using(scope, resolver)
 
-        if schema.empty and expand_alias_refs:
-            _expand_alias_refs(scope, resolver)
+        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
+            _expand_alias_refs(
+                scope,
+                resolver,
+                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
+            )
 
         _convert_columns_to_dots(scope, resolver)
         _qualify_columns(scope, resolver)
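Early alias-ref expansion is now gated on two dialect flags rather than only on an empty schema. A rough illustration of what expanding alias references means (a sketch of the observable behavior, not the exact code path):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    # "b" in the second projection refers to a select alias; expansion replaces it
    # with the aliased expression so later passes only see real columns.
    sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
    print(qualify(sqlglot.parse_one(sql)).sql())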
@@ -148,7 +152,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
     # Mapping of automatically joined column names to an ordered set of source names (dict).
     column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
 
-    for join in joins:
+    for i, join in enumerate(joins):
         using = join.args.get("using")
 
         if not using:
@@ -168,6 +172,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
         ordered.append(join_table)
         join_columns = resolver.get_source_columns(join_table)
         conditions = []
+        using_identifier_count = len(using)
 
         for identifier in using:
             identifier = identifier.name
@@ -178,9 +183,21 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
                     raise OptimizeError(f"Cannot automatically join: {identifier}")
 
             table = table or source_table
-            conditions.append(
-                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
-            )
+
+            if i == 0 or using_identifier_count == 1:
+                lhs: exp.Expression = exp.column(identifier, table=table)
+            else:
+                coalesce_columns = [
+                    exp.column(identifier, table=t)
+                    for t in ordered[:-1]
+                    if identifier in resolver.get_source_columns(t)
+                ]
+                if len(coalesce_columns) > 1:
+                    lhs = exp.func("coalesce", *coalesce_columns)
+                else:
+                    lhs = exp.column(identifier, table=table)
+
+            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
 
             # Set all values in the dict to None, because we only care about the key ordering
             tables = column_tables.setdefault(identifier, {})
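With the join index and using_identifier_count available, later USING joins build their equality condition against a COALESCE over the earlier sources that expose the column, instead of a single table. A sketch of the observable effect (hypothetical schema and table names):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    schema = {"t1": {"id": "int"}, "t2": {"id": "int"}, "t3": {"id": "int"}}
    sql = "SELECT id FROM t1 JOIN t2 USING (id) JOIN t3 USING (id)"
    # The second USING join should now compare t3.id against COALESCE(t1.id, t2.id).
    print(qualify(sqlglot.parse_one(sql), schema=schema).sql())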
@@ -196,8 +213,8 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
     for column in scope.columns:
         if not column.table and column.name in column_tables:
             tables = column_tables[column.name]
-            coalesce = [exp.column(column.name, table=table) for table in tables]
-            replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
+            coalesce_args = [exp.column(column.name, table=table) for table in tables]
+            replacement = exp.func("coalesce", *coalesce_args)
 
             # Ensure selects keep their output name
             if isinstance(column.parent, exp.Select):
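The replacement node is now built with exp.func, which yields the same Coalesce expression without hand-assembling the this/expressions args. A quick standalone check of the helper:

    from sqlglot import exp

    cols = [exp.column("id", table=t) for t in ("t1", "t2", "t3")]
    print(exp.func("coalesce", *cols).sql())  # COALESCE(t1.id, t2.id, t3.id)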
@@ -208,7 +225,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
     return column_tables
 
 
-def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
+def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
     expression = scope.expression
 
     if not isinstance(expression, exp.Select):
@@ -219,7 +236,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
     def replace_columns(
         node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
     ) -> None:
-        if not node:
+        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
             return
 
         for column in walk_in_scope(node, prune=lambda node: node.is_star):
@@ -583,14 +600,10 @@ def _expand_stars(
             if name in using_column_tables and table in using_column_tables[name]:
                 coalesced_columns.add(name)
                 tables = using_column_tables[name]
-                coalesce = [exp.column(name, table=table) for table in tables]
+                coalesce_args = [exp.column(name, table=table) for table in tables]
 
                 new_selections.append(
-                    alias(
-                        exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
-                        alias=name,
-                        copy=False,
-                    )
+                    alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                 )
             else:
                 alias_ = replace_columns.get(table_id, {}).get(name, name)
@@ -719,6 +732,7 @@ class Resolver:
         self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
         self._all_columns: t.Optional[t.Set[str]] = None
         self._infer_schema = infer_schema
+        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 
     def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
         """
@@ -771,41 +785,49 @@ class Resolver:
     def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
         """Resolve the source columns for a given source `name`."""
-        if name not in self.scope.sources:
-            raise OptimizeError(f"Unknown table: {name}")
+        cache_key = (name, only_visible)
+        if cache_key not in self._get_source_columns_cache:
+            if name not in self.scope.sources:
+                raise OptimizeError(f"Unknown table: {name}")
 
-        source = self.scope.sources[name]
+            source = self.scope.sources[name]
 
-        if isinstance(source, exp.Table):
-            columns = self.schema.column_names(source, only_visible)
-        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
-            columns = source.expression.named_selects
+            if isinstance(source, exp.Table):
+                columns = self.schema.column_names(source, only_visible)
+            elif isinstance(source, Scope) and isinstance(
+                source.expression, (exp.Values, exp.Unnest)
+            ):
+                columns = source.expression.named_selects
 
-            # in bigquery, unnest structs are automatically scoped as tables, so you can
-            # directly select a struct field in a query.
-            # this handles the case where the unnest is statically defined.
-            if self.schema.dialect == "bigquery":
-                if source.expression.is_type(exp.DataType.Type.STRUCT):
-                    for k in source.expression.type.expressions:  # type: ignore
-                        columns.append(k.name)
-        else:
-            columns = source.expression.named_selects
+                # in bigquery, unnest structs are automatically scoped as tables, so you can
+                # directly select a struct field in a query.
+                # this handles the case where the unnest is statically defined.
+                if self.schema.dialect == "bigquery":
+                    if source.expression.is_type(exp.DataType.Type.STRUCT):
+                        for k in source.expression.type.expressions:  # type: ignore
+                            columns.append(k.name)
+            else:
+                columns = source.expression.named_selects
 
-        node, _ = self.scope.selected_sources.get(name) or (None, None)
-        if isinstance(node, Scope):
-            column_aliases = node.expression.alias_column_names
-        elif isinstance(node, exp.Expression):
-            column_aliases = node.alias_column_names
-        else:
-            column_aliases = []
+            node, _ = self.scope.selected_sources.get(name) or (None, None)
+            if isinstance(node, Scope):
+                column_aliases = node.expression.alias_column_names
+            elif isinstance(node, exp.Expression):
+                column_aliases = node.alias_column_names
+            else:
+                column_aliases = []
 
-        if column_aliases:
-            # If the source's columns are aliased, their aliases shadow the corresponding column names.
-            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
-            return [
-                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
-            ]
-        return columns
+            if column_aliases:
+                # If the source's columns are aliased, their aliases shadow the corresponding column names.
+                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
+                columns = [
+                    alias or name
+                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
+                ]
+
+            self._get_source_columns_cache[cache_key] = columns
+
+        return self._get_source_columns_cache[cache_key]
 
     def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
         if self._source_columns is None:
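get_source_columns is now memoized per (name, only_visible) pair, while aliased sources still shadow column names the same way as before. The shadowing rule relies on zip_longest padding (standalone sketch):

    import itertools

    columns = ["a", "b", "c"]
    column_aliases = ["x", "y"]
    # Aliases shadow the leading columns; columns without an alias keep their names.
    print([alias or name for name, alias in itertools.zip_longest(columns, column_aliases)])
    # ['x', 'y', 'c']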
@@ -29,7 +29,7 @@ class Scope:
     Selection scope.
 
     Attributes:
-        expression (exp.Select|exp.Union): Root expression of this scope
+        expression (exp.Select|exp.SetOperation): Root expression of this scope
         sources (dict[str, exp.Table|Scope]): Mapping of source name to either
             a Table expression or another Scope instance. For example:
                 SELECT * FROM x          {"x": Table(this="x")}
@@ -233,7 +233,7 @@ class Scope:
             SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
 
         Returns:
-            list[exp.Select | exp.Union]: subqueries
+            list[exp.Select | exp.SetOperation]: subqueries
         """
         self._ensure_collected()
         return self._subqueries
@@ -339,7 +339,7 @@ class Scope:
         sources in the current scope.
         """
         if self._external_columns is None:
-            if isinstance(self.expression, exp.Union):
+            if isinstance(self.expression, exp.SetOperation):
                 left, right = self.union_scopes
                 self._external_columns = left.external_columns + right.external_columns
             else:
@@ -535,7 +535,7 @@ def _traverse_scope(scope):
 
     if isinstance(expression, exp.Select):
         yield from _traverse_select(scope)
-    elif isinstance(expression, exp.Union):
+    elif isinstance(expression, exp.SetOperation):
         yield from _traverse_ctes(scope)
         yield from _traverse_union(scope)
         return
@@ -588,7 +588,7 @@ def _traverse_union(scope):
             scope_type=ScopeType.UNION,
         )
 
-        if isinstance(expression, exp.Union):
+        if isinstance(expression, exp.SetOperation):
            yield from _traverse_ctes(new_scope)
 
         union_scope_stack.append(new_scope)
@@ -620,7 +620,7 @@ def _traverse_ctes(scope):
         if with_ and with_.recursive:
             union = cte.this
 
-            if isinstance(union, exp.Union):
+            if isinstance(union, exp.SetOperation):
                 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
 
         child_scope = None
@@ -6,7 +6,6 @@ import functools
 import itertools
 import typing as t
 from collections import deque, defaultdict
-from decimal import Decimal
 from functools import reduce
 
 import sqlglot
@@ -347,8 +346,8 @@ def _simplify_comparison(expression, left, right, or_=False):
             return expression
 
         if l.is_number and r.is_number:
-            l = float(l.name)
-            r = float(r.name)
+            l = l.to_py()
+            r = r.to_py()
         elif l.is_string and r.is_string:
             l = l.name
             r = r.name
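Literal conversion now goes through Literal.to_py instead of float(name), which keeps integer comparisons exact. A tiny sketch of the helper (the concrete Python types depend on the literal):

    from sqlglot import exp

    # to_py converts a numeric literal to a Python number instead of forcing float()
    print(exp.Literal.number("2").to_py(), exp.Literal.number("1.5").to_py())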
@@ -626,13 +625,8 @@ def simplify_literals(expression, root=True):
     if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
         return _flat_simplify(expression, _simplify_binary, root)
 
-    if isinstance(expression, exp.Neg):
-        this = expression.this
-        if this.is_number:
-            value = this.name
-            if value[0] == "-":
-                return exp.Literal.number(value[1:])
-            return exp.Literal.number(f"-{value}")
+    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
+        return expression.this.this
 
     if type(expression) in INVERSE_DATE_OPS:
         return _simplify_binary(expression, expression.this, expression.interval()) or expression
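simplify_literals now only folds a double negation at this point. A sketch of the remaining rule in action (assuming the standard simplify entry point):

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # - -x should fold back to x under the new Neg rule.
    print(simplify(sqlglot.parse_one("SELECT - -x FROM t")).sql())  # SELECT x FROM t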
@@ -650,7 +644,7 @@ def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
     this = expr.this
 
     if isinstance(expr, exp.Cast) and this.is_int:
-        num = int(this.name)
+        num = this.to_py()
 
         # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
         # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
@@ -690,8 +684,8 @@ def _simplify_binary(expression, a, b):
         return exp.null()
 
     if a.is_number and b.is_number:
-        num_a = int(a.name) if a.is_int else Decimal(a.name)
-        num_b = int(b.name) if b.is_int else Decimal(b.name)
+        num_a = a.to_py()
+        num_b = b.to_py()
 
         if isinstance(expression, exp.Add):
             return exp.Literal.number(num_a + num_b)
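_simplify_binary now folds numeric literals through to_py rather than int/Decimal on the literal name; constant folding itself is unchanged. Sketch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # 1 + 2 folds to a single literal; the column operand is left alone.
    print(simplify(sqlglot.parse_one("SELECT 1 + 2 + x FROM t")).sql())  # SELECT 3 + x FROM t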
@@ -1206,7 +1200,7 @@ def _is_date_literal(expression: exp.Expression) -> bool:
 
 def extract_interval(expression):
     try:
-        n = int(expression.name)
+        n = int(expression.this.to_py())
         unit = expression.text("unit").lower()
         return interval(unit, n)
     except (UnsupportedUnit, ModuleNotFoundError, ValueError):
@@ -48,7 +48,7 @@ def unnest(select, parent_select, next_alias_name):
     ):
         return
 
-    if isinstance(select, exp.Union):
+    if isinstance(select, exp.SetOperation):
         select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
 
     alias = next_alias_name()
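When the correlated subquery is any set operation (not just a UNION), it is first wrapped so a single SELECT can be decorrelated. The wrapping shape, with a hypothetical "_u_0" alias standing in for next_alias_name():

    from sqlglot import exp, parse_one

    select = parse_one("SELECT a FROM x UNION SELECT a FROM y")
    wrapped = exp.select(*select.selects).from_(select.subquery("_u_0"))
    print(wrapped.sql())  # SELECT a FROM (SELECT a FROM x UNION SELECT a FROM y) AS _u_0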