1
0
Fork 0

Merging upstream version 25.26.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:56:02 +01:00
parent 9138e4b92a
commit 829a709061
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
117 changed files with 49296 additions and 47316 deletions

View file

@ -287,15 +287,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType:
) -> exp.DataType.Type:
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
# We propagate the UNKNOWN type upwards if found
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.build("unknown")
return exp.DataType.Type.UNKNOWN
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
return t.cast(
exp.DataType.Type,
type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value,
)
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)

View file

@ -1,11 +1,18 @@
from __future__ import annotations
import itertools
import typing as t
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import Scope, build_scope
if t.TYPE_CHECKING:
ExistingCTEsMapping = t.Dict[exp.Expression, str]
TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
def eliminate_subqueries(expression):
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
"""
Rewrite derived tables as CTEs, deduplicating if possible.
@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
taken = {}
taken: TakenNameMapping = {}
# All CTE aliases in the root scope are taken
for scope in root.cte_scopes:
@ -56,7 +63,7 @@ def eliminate_subqueries(expression):
# Map of Expression->alias
# Existing CTEs in the root expression. We'll use this for deduplication.
existing_ctes = {}
existing_ctes: ExistingCTEsMapping = {}
with_ = root.expression.args.get("with")
recursive = False
@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
return expression
def _eliminate(scope, existing_ctes, taken):
def _eliminate(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)
if scope.is_cte:
return _eliminate_cte(scope, existing_ctes, taken)
return None
def _eliminate_derived_table(scope, existing_ctes, taken):
def _eliminate_derived_table(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
# - eliminate a lateral correlated subquery
@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
return cte
def _eliminate_cte(scope, existing_ctes, taken):
def _eliminate_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
return cte
def _new_cte(scope, existing_ctes, taken):
def _new_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
"""
Returns:
tuple of (name, cte)

View file

@ -1,11 +1,20 @@
from __future__ import annotations
import typing as t
from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
if t.TYPE_CHECKING:
from sqlglot._typing import E
def merge_subqueries(expression, leave_tables_isolated=False):
FromOrJoin = t.Union[exp.From, exp.Join]
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
@ -58,7 +67,7 @@ SAFE_TO_REPLACE_UNWRAPPED = (
)
def merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
scopes = traverse_scope(expression)
# All places where we select from CTEs.
@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
return expression
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
def _mergeable(
outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
"""
Return True if `inner_select` (the unnested expression of `inner_scope`) can be merged into the outer query.
Args:
outer_scope (Scope)
inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
inner_select = inner_scope.expression.unnest()
@ -195,7 +198,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not outer_scope.expression.is_star
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and inner_select.args.get("from") is not None
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
@ -218,19 +221,17 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
)
def _rename_inner_sources(outer_scope, inner_scope, alias):
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Renames any sources in the inner query that conflict with names in the outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
inner_taken = set(inner_scope.selected_sources)
outer_taken = set(outer_scope.selected_sources)
conflicts = outer_taken.intersection(inner_taken)
conflicts -= {alias}
taken = outer_taken.union(inner_taken)
for conflict in conflicts:
new_name = find_new_name(taken, conflict)
@ -250,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
def _merge_from(
outer_scope: Scope,
inner_scope: Scope,
node_to_replace: t.Union[exp.Subquery, exp.Table],
alias: str,
) -> None:
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args["from"].this
new_subquery.set("joins", node_to_replace.args.get("joins"))
@ -274,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
)
def _merge_joins(outer_scope, inner_scope, from_or_join):
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge JOIN clauses of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
from_or_join (exp.From|exp.Join)
"""
new_joins = []
@ -304,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
outer_scope.expression.set("joins", outer_joins)
def _merge_expressions(outer_scope, inner_scope, alias):
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Merge projections of inner query into outer query.
@ -338,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
column.replace(expression.copy())
def _merge_where(outer_scope, inner_scope, from_or_join):
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge WHERE clause of inner query into outer query.
@ -357,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
sources = {from_.alias_or_name} if from_ else {}
sources = {from_.alias_or_name} if from_ else set()
for join in expression.args["joins"]:
source = join.alias_or_name
@ -373,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
expression.where(where.this, copy=False)
def _merge_order(outer_scope, inner_scope):
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
"""
Merge ORDER clause of inner query into outer query.
@ -393,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
def _merge_hints(outer_scope, inner_scope):
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
inner_scope_hint = inner_scope.expression.args.get("hint")
if not inner_scope_hint:
return
@ -405,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
outer_scope.expression.set("hint", inner_scope_hint)
def _pop_cte(inner_scope):
def _pop_cte(inner_scope: Scope) -> None:
"""
Remove CTE from the AST.

View file

@ -27,6 +27,7 @@ def qualify(
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
allow_partial_qualification: bool = False,
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
@ -56,6 +57,7 @@ def qualify(
infer_schema: Whether to infer the schema if missing.
isolate_tables: Whether to isolate table selects.
qualify_columns: Whether to qualify columns.
allow_partial_qualification: Whether to allow partial qualification.
validate_qualify_columns: Whether to validate columns.
quote_identifiers: Whether to run the quote_identifiers step.
This step is necessary to ensure correctness for case sensitive queries.
@ -90,6 +92,7 @@ def qualify(
expand_alias_refs=expand_alias_refs,
expand_stars=expand_stars,
infer_schema=infer_schema,
allow_partial_qualification=allow_partial_qualification,
)
if quote_identifiers:

View file

@ -22,6 +22,7 @@ def qualify_columns(
expand_alias_refs: bool = True,
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
allow_partial_qualification: bool = False,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@ -41,6 +42,7 @@ def qualify_columns(
for most of the optimizer's rules to work; do not set to False unless you
know what you're doing!
infer_schema: Whether to infer the schema if missing.
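allow_partial_qualification: Whether to allow partial qualification. If True, no "Unknown column"
error is raised when a column references a known source but its name is not among that source's columns.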
Returns:
The qualified expression.
@ -68,7 +70,7 @@ def qualify_columns(
)
_convert_columns_to_dots(scope, resolver)
_qualify_columns(scope, resolver)
_qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
if not schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
@ -240,13 +242,21 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
def replace_columns(
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
) -> None:
if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
is_group_by = isinstance(node, exp.Group)
if not node or (expand_only_groupby and not is_group_by):
return
for column in walk_in_scope(node, prune=lambda node: node.is_star):
if not isinstance(column, exp.Column):
continue
# BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
# SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
# SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, as it would result in FUNC(FUNC(col))
# This is not required for the HAVING clause, as it can evaluate expressions using both the alias & the table columns
if expand_only_groupby and is_group_by and column.parent is not node:
continue
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
@ -273,9 +283,8 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
if simplified is not column:
column.replace(simplified)
for i, projection in enumerate(scope.expression.selects):
for i, projection in enumerate(expression.selects):
replace_columns(projection)
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = (projection.this, i + 1)
@ -434,7 +443,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
scope.clear_cache()
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
column_table = column.table
@ -442,7 +451,12 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
if column_table and column_table in scope.sources:
source_columns = resolver.get_source_columns(column_table)
if source_columns and column_name not in source_columns and "*" not in source_columns:
if (
not allow_partial_qualification
and source_columns
and column_name not in source_columns
and "*" not in source_columns
):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
@ -526,7 +540,7 @@ def _expand_stars(
) -> None:
"""Expand stars to lists of column selections"""
new_selections = []
new_selections: t.List[exp.Expression] = []
except_columns: t.Dict[int, t.Set[str]] = {}
replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
rename_columns: t.Dict[int, t.Dict[str, str]] = {}

View file

@ -562,8 +562,8 @@ def _traverse_scope(scope):
elif isinstance(expression, exp.DML):
yield from _traverse_ctes(scope)
for query in find_all_in_scope(expression, exp.Query):
# This check ensures we don't yield the CTE queries twice
if not isinstance(query.parent, exp.CTE):
# This check ensures we don't yield the CTE/nested queries twice
if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
return
else:
@ -679,6 +679,8 @@ def _traverse_tables(scope):
expressions.extend(scope.expression.args.get("laterals") or [])
for expression in expressions:
if isinstance(expression, exp.Final):
expression = expression.this
if isinstance(expression, exp.Table):
table_name = expression.name
source_name = expression.alias_or_name

View file

@ -206,6 +206,11 @@ COMPLEMENT_COMPARISONS = {
exp.NEQ: exp.EQ,
}
COMPLEMENT_SUBQUERY_PREDICATES = {
exp.All: exp.Any,
exp.Any: exp.All,
}
def simplify_not(expression):
"""
@ -218,9 +223,12 @@ def simplify_not(expression):
if is_null(this):
return exp.null()
if this.__class__ in COMPLEMENT_COMPARISONS:
return COMPLEMENT_COMPARISONS[this.__class__](
this=this.this, expression=this.expression
)
right = this.expression
complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
if complement_subquery_predicate:
right = complement_subquery_predicate(this=right.this)
return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
if isinstance(this, exp.Paren):
condition = this.unnest()
if isinstance(condition, exp.And):