1
0
Fork 0

Merging upstream version 25.5.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:41:14 +01:00
parent 298e7a8147
commit 029b9c2c73
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
136 changed files with 80990 additions and 72541 deletions

View file

@@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
},
exp.DataType.Type.DATETIME: {
exp.CurrentDatetime,
exp.Datetime,
exp.DatetimeAdd,
exp.DatetimeSub,
},
@@ -196,6 +197,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.JSON: {
exp.ParseJSON,
},
exp.DataType.Type.TIME: {
exp.Time,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,

View file

@@ -42,7 +42,7 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
and not node.args.get("zone")
):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
if not node.type:
from sqlglot.optimizer.annotate_types import annotate_types

View file

@@ -47,7 +47,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if scope.expression.args.get("distinct"):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
if isinstance(scope.expression, exp.SetOperation):
left, right = scope.union_scopes
referenced_columns[left] = parent_selections

View file

@@ -60,8 +60,12 @@ def qualify_columns(
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)
if schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
_expand_alias_refs(
scope,
resolver,
expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
)
_convert_columns_to_dots(scope, resolver)
_qualify_columns(scope, resolver)
@@ -148,7 +152,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
for join in joins:
for i, join in enumerate(joins):
using = join.args.get("using")
if not using:
@@ -168,6 +172,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
ordered.append(join_table)
join_columns = resolver.get_source_columns(join_table)
conditions = []
using_identifier_count = len(using)
for identifier in using:
identifier = identifier.name
@@ -178,9 +183,21 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
raise OptimizeError(f"Cannot automatically join: {identifier}")
table = table or source_table
conditions.append(
exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
)
if i == 0 or using_identifier_count == 1:
lhs: exp.Expression = exp.column(identifier, table=table)
else:
coalesce_columns = [
exp.column(identifier, table=t)
for t in ordered[:-1]
if identifier in resolver.get_source_columns(t)
]
if len(coalesce_columns) > 1:
lhs = exp.func("coalesce", *coalesce_columns)
else:
lhs = exp.column(identifier, table=table)
conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
# Set all values in the dict to None, because we only care about the key ordering
tables = column_tables.setdefault(identifier, {})
@@ -196,8 +213,8 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
for column in scope.columns:
if not column.table and column.name in column_tables:
tables = column_tables[column.name]
coalesce = [exp.column(column.name, table=table) for table in tables]
replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
coalesce_args = [exp.column(column.name, table=table) for table in tables]
replacement = exp.func("coalesce", *coalesce_args)
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
@@ -208,7 +225,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
return column_tables
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
expression = scope.expression
if not isinstance(expression, exp.Select):
@@ -219,7 +236,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
def replace_columns(
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
) -> None:
if not node:
if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
return
for column in walk_in_scope(node, prune=lambda node: node.is_star):
@@ -583,14 +600,10 @@ def _expand_stars(
if name in using_column_tables and table in using_column_tables[name]:
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
coalesce_args = [exp.column(name, table=table) for table in tables]
new_selections.append(
alias(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
alias=name,
copy=False,
)
alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
)
else:
alias_ = replace_columns.get(table_id, {}).get(name, name)
@@ -719,6 +732,7 @@ class Resolver:
self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
"""
@@ -771,41 +785,49 @@ class Resolver:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
"""Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
cache_key = (name, only_visible)
if cache_key not in self._get_source_columns_cache:
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
source = self.scope.sources[name]
source = self.scope.sources[name]
if isinstance(source, exp.Table):
columns = self.schema.column_names(source, only_visible)
elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
columns = source.expression.named_selects
if isinstance(source, exp.Table):
columns = self.schema.column_names(source, only_visible)
elif isinstance(source, Scope) and isinstance(
source.expression, (exp.Values, exp.Unnest)
):
columns = source.expression.named_selects
# in bigquery, unnest structs are automatically scoped as tables, so you can
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.schema.dialect == "bigquery":
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
else:
columns = source.expression.named_selects
# in bigquery, unnest structs are automatically scoped as tables, so you can
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.schema.dialect == "bigquery":
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
else:
columns = source.expression.named_selects
node, _ = self.scope.selected_sources.get(name) or (None, None)
if isinstance(node, Scope):
column_aliases = node.expression.alias_column_names
elif isinstance(node, exp.Expression):
column_aliases = node.alias_column_names
else:
column_aliases = []
node, _ = self.scope.selected_sources.get(name) or (None, None)
if isinstance(node, Scope):
column_aliases = node.expression.alias_column_names
elif isinstance(node, exp.Expression):
column_aliases = node.alias_column_names
else:
column_aliases = []
if column_aliases:
# If the source's columns are aliased, their aliases shadow the corresponding column names.
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
return [
alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
return columns
if column_aliases:
# If the source's columns are aliased, their aliases shadow the corresponding column names.
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
columns = [
alias or name
for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
self._get_source_columns_cache[cache_key] = columns
return self._get_source_columns_cache[cache_key]
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:

View file

@@ -29,7 +29,7 @@ class Scope:
Selection scope.
Attributes:
expression (exp.Select|exp.Union): Root expression of this scope
expression (exp.Select|exp.SetOperation): Root expression of this scope
sources (dict[str, exp.Table|Scope]): Mapping of source name to either
a Table expression or another Scope instance. For example:
SELECT * FROM x {"x": Table(this="x")}
@@ -233,7 +233,7 @@ class Scope:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
list[exp.Select | exp.SetOperation]: subqueries
"""
self._ensure_collected()
return self._subqueries
@@ -339,7 +339,7 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
if isinstance(self.expression, exp.Union):
if isinstance(self.expression, exp.SetOperation):
left, right = self.union_scopes
self._external_columns = left.external_columns + right.external_columns
else:
@@ -535,7 +535,7 @@ def _traverse_scope(scope):
if isinstance(expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(expression, exp.Union):
elif isinstance(expression, exp.SetOperation):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
@@ -588,7 +588,7 @@ def _traverse_union(scope):
scope_type=ScopeType.UNION,
)
if isinstance(expression, exp.Union):
if isinstance(expression, exp.SetOperation):
yield from _traverse_ctes(new_scope)
union_scope_stack.append(new_scope)
@@ -620,7 +620,7 @@ def _traverse_ctes(scope):
if with_ and with_.recursive:
union = cte.this
if isinstance(union, exp.Union):
if isinstance(union, exp.SetOperation):
sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
child_scope = None

View file

@@ -6,7 +6,6 @@ import functools
import itertools
import typing as t
from collections import deque, defaultdict
from decimal import Decimal
from functools import reduce
import sqlglot
@@ -347,8 +346,8 @@ def _simplify_comparison(expression, left, right, or_=False):
return expression
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
l = l.to_py()
r = r.to_py()
elif l.is_string and r.is_string:
l = l.name
r = r.name
@@ -626,13 +625,8 @@ def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
if isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
if value[0] == "-":
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
return expression.this.this
if type(expression) in INVERSE_DATE_OPS:
return _simplify_binary(expression, expression.this, expression.interval()) or expression
@@ -650,7 +644,7 @@ def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
this = expr.this
if isinstance(expr, exp.Cast) and this.is_int:
num = int(this.name)
num = this.to_py()
# Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
@@ -690,8 +684,8 @@ def _simplify_binary(expression, a, b):
return exp.null()
if a.is_number and b.is_number:
num_a = int(a.name) if a.is_int else Decimal(a.name)
num_b = int(b.name) if b.is_int else Decimal(b.name)
num_a = a.to_py()
num_b = b.to_py()
if isinstance(expression, exp.Add):
return exp.Literal.number(num_a + num_b)
@@ -1206,7 +1200,7 @@ def _is_date_literal(expression: exp.Expression) -> bool:
def extract_interval(expression):
try:
n = int(expression.name)
n = int(expression.this.to_py())
unit = expression.text("unit").lower()
return interval(unit, n)
except (UnsupportedUnit, ModuleNotFoundError, ValueError):

View file

@@ -48,7 +48,7 @@ def unnest(select, parent_select, next_alias_name):
):
return
if isinstance(select, exp.Union):
if isinstance(select, exp.SetOperation):
select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
alias = next_alias_name()