Merging upstream version 25.26.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
9138e4b92a
commit
829a709061
117 changed files with 49296 additions and 47316 deletions
|
@ -287,15 +287,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def _maybe_coerce(
|
||||
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
|
||||
) -> exp.DataType:
|
||||
) -> exp.DataType.Type:
|
||||
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
|
||||
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
|
||||
|
||||
# We propagate the UNKNOWN type upwards if found
|
||||
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
||||
return exp.DataType.build("unknown")
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
|
||||
return t.cast(
|
||||
exp.DataType.Type,
|
||||
type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value,
|
||||
)
|
||||
|
||||
def _annotate_binary(self, expression: B) -> B:
|
||||
self._annotate_args(expression)
|
||||
|
|
|
@ -1,11 +1,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import typing as t
|
||||
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.helper import find_new_name
|
||||
from sqlglot.optimizer.scope import build_scope
|
||||
from sqlglot.optimizer.scope import Scope, build_scope
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
ExistingCTEsMapping = t.Dict[exp.Expression, str]
|
||||
TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
|
||||
|
||||
|
||||
def eliminate_subqueries(expression):
|
||||
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Rewrite derived tables as CTES, deduplicating if possible.
|
||||
|
||||
|
@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
|
|||
# Map of alias->Scope|Table
|
||||
# These are all aliases that are already used in the expression.
|
||||
# We don't want to create new CTEs that conflict with these names.
|
||||
taken = {}
|
||||
taken: TakenNameMapping = {}
|
||||
|
||||
# All CTE aliases in the root scope are taken
|
||||
for scope in root.cte_scopes:
|
||||
|
@ -56,7 +63,7 @@ def eliminate_subqueries(expression):
|
|||
|
||||
# Map of Expression->alias
|
||||
# Existing CTES in the root expression. We'll use this for deduplication.
|
||||
existing_ctes = {}
|
||||
existing_ctes: ExistingCTEsMapping = {}
|
||||
|
||||
with_ = root.expression.args.get("with")
|
||||
recursive = False
|
||||
|
@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def _eliminate(scope, existing_ctes, taken):
|
||||
def _eliminate(
|
||||
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if scope.is_derived_table:
|
||||
return _eliminate_derived_table(scope, existing_ctes, taken)
|
||||
|
||||
if scope.is_cte:
|
||||
return _eliminate_cte(scope, existing_ctes, taken)
|
||||
|
||||
return None
|
||||
|
||||
def _eliminate_derived_table(scope, existing_ctes, taken):
|
||||
|
||||
def _eliminate_derived_table(
|
||||
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
|
||||
) -> t.Optional[exp.Expression]:
|
||||
# This makes sure that we don't:
|
||||
# - drop the "pivot" arg from a pivoted subquery
|
||||
# - eliminate a lateral correlated subquery
|
||||
|
@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
|
|||
return cte
|
||||
|
||||
|
||||
def _eliminate_cte(scope, existing_ctes, taken):
|
||||
def _eliminate_cte(
|
||||
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
|
||||
) -> t.Optional[exp.Expression]:
|
||||
parent = scope.expression.parent
|
||||
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||
|
||||
|
@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
|
|||
return cte
|
||||
|
||||
|
||||
def _new_cte(scope, existing_ctes, taken):
|
||||
def _new_cte(
|
||||
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
|
||||
) -> t.Tuple[str, t.Optional[exp.Expression]]:
|
||||
"""
|
||||
Returns:
|
||||
tuple of (name, cte)
|
||||
|
|
|
@ -1,11 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.helper import find_new_name
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
||||
def merge_subqueries(expression, leave_tables_isolated=False):
|
||||
FromOrJoin = t.Union[exp.From, exp.Join]
|
||||
|
||||
|
||||
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
|
||||
"""
|
||||
Rewrite sqlglot AST to merge derived tables into the outer query.
|
||||
|
||||
|
@ -58,7 +67,7 @@ SAFE_TO_REPLACE_UNWRAPPED = (
|
|||
)
|
||||
|
||||
|
||||
def merge_ctes(expression, leave_tables_isolated=False):
|
||||
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
|
||||
scopes = traverse_scope(expression)
|
||||
|
||||
# All places where we select from CTEs.
|
||||
|
@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
|||
return expression
|
||||
|
||||
|
||||
def merge_derived_tables(expression, leave_tables_isolated=False):
|
||||
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
|
||||
for outer_scope in traverse_scope(expression):
|
||||
for subquery in outer_scope.derived_tables:
|
||||
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
|
||||
|
@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
return expression
|
||||
|
||||
|
||||
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
||||
def _mergeable(
|
||||
outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if `inner_select` can be merged into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (Scope)
|
||||
inner_scope (Scope)
|
||||
leave_tables_isolated (bool)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
Returns:
|
||||
bool: True if can be merged
|
||||
"""
|
||||
inner_select = inner_scope.expression.unnest()
|
||||
|
||||
|
@ -195,7 +198,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
and not outer_scope.expression.is_star
|
||||
and isinstance(inner_select, exp.Select)
|
||||
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
|
||||
and inner_select.args.get("from")
|
||||
and inner_select.args.get("from") is not None
|
||||
and not outer_scope.pivots
|
||||
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
|
||||
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
|
||||
|
@ -218,19 +221,17 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
)
|
||||
|
||||
|
||||
def _rename_inner_sources(outer_scope, inner_scope, alias):
|
||||
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
|
||||
"""
|
||||
Renames any sources in the inner query that conflict with names in the outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
alias (str)
|
||||
"""
|
||||
taken = set(outer_scope.selected_sources)
|
||||
conflicts = taken.intersection(set(inner_scope.selected_sources))
|
||||
inner_taken = set(inner_scope.selected_sources)
|
||||
outer_taken = set(outer_scope.selected_sources)
|
||||
conflicts = outer_taken.intersection(inner_taken)
|
||||
conflicts -= {alias}
|
||||
|
||||
taken = outer_taken.union(inner_taken)
|
||||
|
||||
for conflict in conflicts:
|
||||
new_name = find_new_name(taken, conflict)
|
||||
|
||||
|
@ -250,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
|
|||
inner_scope.rename_source(conflict, new_name)
|
||||
|
||||
|
||||
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
||||
def _merge_from(
|
||||
outer_scope: Scope,
|
||||
inner_scope: Scope,
|
||||
node_to_replace: t.Union[exp.Subquery, exp.Table],
|
||||
alias: str,
|
||||
) -> None:
|
||||
"""
|
||||
Merge FROM clause of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
node_to_replace (exp.Subquery|exp.Table)
|
||||
alias (str)
|
||||
"""
|
||||
new_subquery = inner_scope.expression.args["from"].this
|
||||
new_subquery.set("joins", node_to_replace.args.get("joins"))
|
||||
|
@ -274,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
|
|||
)
|
||||
|
||||
|
||||
def _merge_joins(outer_scope, inner_scope, from_or_join):
|
||||
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
|
||||
"""
|
||||
Merge JOIN clauses of inner query into outer query.
|
||||
|
||||
Args:
|
||||
outer_scope (sqlglot.optimizer.scope.Scope)
|
||||
inner_scope (sqlglot.optimizer.scope.Scope)
|
||||
from_or_join (exp.From|exp.Join)
|
||||
"""
|
||||
|
||||
new_joins = []
|
||||
|
@ -304,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
|
|||
outer_scope.expression.set("joins", outer_joins)
|
||||
|
||||
|
||||
def _merge_expressions(outer_scope, inner_scope, alias):
|
||||
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
|
||||
"""
|
||||
Merge projections of inner query into outer query.
|
||||
|
||||
|
@ -338,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
|
|||
column.replace(expression.copy())
|
||||
|
||||
|
||||
def _merge_where(outer_scope, inner_scope, from_or_join):
|
||||
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
|
||||
"""
|
||||
Merge WHERE clause of inner query into outer query.
|
||||
|
||||
|
@ -357,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
|
|||
# Merge predicates from an outer join to the ON clause
|
||||
# if it only has columns that are already joined
|
||||
from_ = expression.args.get("from")
|
||||
sources = {from_.alias_or_name} if from_ else {}
|
||||
sources = {from_.alias_or_name} if from_ else set()
|
||||
|
||||
for join in expression.args["joins"]:
|
||||
source = join.alias_or_name
|
||||
|
@ -373,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
|
|||
expression.where(where.this, copy=False)
|
||||
|
||||
|
||||
def _merge_order(outer_scope, inner_scope):
|
||||
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
|
||||
"""
|
||||
Merge ORDER clause of inner query into outer query.
|
||||
|
||||
|
@ -393,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
|
|||
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
|
||||
|
||||
|
||||
def _merge_hints(outer_scope, inner_scope):
|
||||
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
|
||||
inner_scope_hint = inner_scope.expression.args.get("hint")
|
||||
if not inner_scope_hint:
|
||||
return
|
||||
|
@ -405,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
|
|||
outer_scope.expression.set("hint", inner_scope_hint)
|
||||
|
||||
|
||||
def _pop_cte(inner_scope):
|
||||
def _pop_cte(inner_scope: Scope) -> None:
|
||||
"""
|
||||
Remove CTE from the AST.
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ def qualify(
|
|||
infer_schema: t.Optional[bool] = None,
|
||||
isolate_tables: bool = False,
|
||||
qualify_columns: bool = True,
|
||||
allow_partial_qualification: bool = False,
|
||||
validate_qualify_columns: bool = True,
|
||||
quote_identifiers: bool = True,
|
||||
identify: bool = True,
|
||||
|
@ -56,6 +57,7 @@ def qualify(
|
|||
infer_schema: Whether to infer the schema if missing.
|
||||
isolate_tables: Whether to isolate table selects.
|
||||
qualify_columns: Whether to qualify columns.
|
||||
allow_partial_qualification: Whether to allow partial qualification.
|
||||
validate_qualify_columns: Whether to validate columns.
|
||||
quote_identifiers: Whether to run the quote_identifiers step.
|
||||
This step is necessary to ensure correctness for case sensitive queries.
|
||||
|
@ -90,6 +92,7 @@ def qualify(
|
|||
expand_alias_refs=expand_alias_refs,
|
||||
expand_stars=expand_stars,
|
||||
infer_schema=infer_schema,
|
||||
allow_partial_qualification=allow_partial_qualification,
|
||||
)
|
||||
|
||||
if quote_identifiers:
|
||||
|
|
|
@ -22,6 +22,7 @@ def qualify_columns(
|
|||
expand_alias_refs: bool = True,
|
||||
expand_stars: bool = True,
|
||||
infer_schema: t.Optional[bool] = None,
|
||||
allow_partial_qualification: bool = False,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified columns.
|
||||
|
@ -41,6 +42,7 @@ def qualify_columns(
|
|||
for most of the optimizer's rules to work; do not set to False unless you
|
||||
know what you're doing!
|
||||
infer_schema: Whether to infer the schema if missing.
|
||||
allow_partial_qualification: Whether to allow partial qualification.
|
||||
|
||||
Returns:
|
||||
The qualified expression.
|
||||
|
@ -68,7 +70,7 @@ def qualify_columns(
|
|||
)
|
||||
|
||||
_convert_columns_to_dots(scope, resolver)
|
||||
_qualify_columns(scope, resolver)
|
||||
_qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
|
||||
|
||||
if not schema.empty and expand_alias_refs:
|
||||
_expand_alias_refs(scope, resolver)
|
||||
|
@ -240,13 +242,21 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
|
|||
def replace_columns(
|
||||
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
|
||||
) -> None:
|
||||
if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
|
||||
is_group_by = isinstance(node, exp.Group)
|
||||
if not node or (expand_only_groupby and not is_group_by):
|
||||
return
|
||||
|
||||
for column in walk_in_scope(node, prune=lambda node: node.is_star):
|
||||
if not isinstance(column, exp.Column):
|
||||
continue
|
||||
|
||||
# BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
|
||||
# SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
|
||||
# SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
|
||||
# This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
|
||||
if expand_only_groupby and is_group_by and column.parent is not node:
|
||||
continue
|
||||
|
||||
table = resolver.get_table(column.name) if resolve_table and not column.table else None
|
||||
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
|
||||
double_agg = (
|
||||
|
@ -273,9 +283,8 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
|
|||
if simplified is not column:
|
||||
column.replace(simplified)
|
||||
|
||||
for i, projection in enumerate(scope.expression.selects):
|
||||
for i, projection in enumerate(expression.selects):
|
||||
replace_columns(projection)
|
||||
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = (projection.this, i + 1)
|
||||
|
||||
|
@ -434,7 +443,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
|
|||
scope.clear_cache()
|
||||
|
||||
|
||||
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
|
||||
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
|
||||
"""Disambiguate columns, ensuring each column specifies a source"""
|
||||
for column in scope.columns:
|
||||
column_table = column.table
|
||||
|
@ -442,7 +451,12 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
|
|||
|
||||
if column_table and column_table in scope.sources:
|
||||
source_columns = resolver.get_source_columns(column_table)
|
||||
if source_columns and column_name not in source_columns and "*" not in source_columns:
|
||||
if (
|
||||
not allow_partial_qualification
|
||||
and source_columns
|
||||
and column_name not in source_columns
|
||||
and "*" not in source_columns
|
||||
):
|
||||
raise OptimizeError(f"Unknown column: {column_name}")
|
||||
|
||||
if not column_table:
|
||||
|
@ -526,7 +540,7 @@ def _expand_stars(
|
|||
) -> None:
|
||||
"""Expand stars to lists of column selections"""
|
||||
|
||||
new_selections = []
|
||||
new_selections: t.List[exp.Expression] = []
|
||||
except_columns: t.Dict[int, t.Set[str]] = {}
|
||||
replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
|
||||
rename_columns: t.Dict[int, t.Dict[str, str]] = {}
|
||||
|
|
|
@ -562,8 +562,8 @@ def _traverse_scope(scope):
|
|||
elif isinstance(expression, exp.DML):
|
||||
yield from _traverse_ctes(scope)
|
||||
for query in find_all_in_scope(expression, exp.Query):
|
||||
# This check ensures we don't yield the CTE queries twice
|
||||
if not isinstance(query.parent, exp.CTE):
|
||||
# This check ensures we don't yield the CTE/nested queries twice
|
||||
if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
|
||||
yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
|
||||
return
|
||||
else:
|
||||
|
@ -679,6 +679,8 @@ def _traverse_tables(scope):
|
|||
expressions.extend(scope.expression.args.get("laterals") or [])
|
||||
|
||||
for expression in expressions:
|
||||
if isinstance(expression, exp.Final):
|
||||
expression = expression.this
|
||||
if isinstance(expression, exp.Table):
|
||||
table_name = expression.name
|
||||
source_name = expression.alias_or_name
|
||||
|
|
|
@ -206,6 +206,11 @@ COMPLEMENT_COMPARISONS = {
|
|||
exp.NEQ: exp.EQ,
|
||||
}
|
||||
|
||||
COMPLEMENT_SUBQUERY_PREDICATES = {
|
||||
exp.All: exp.Any,
|
||||
exp.Any: exp.All,
|
||||
}
|
||||
|
||||
|
||||
def simplify_not(expression):
|
||||
"""
|
||||
|
@ -218,9 +223,12 @@ def simplify_not(expression):
|
|||
if is_null(this):
|
||||
return exp.null()
|
||||
if this.__class__ in COMPLEMENT_COMPARISONS:
|
||||
return COMPLEMENT_COMPARISONS[this.__class__](
|
||||
this=this.this, expression=this.expression
|
||||
)
|
||||
right = this.expression
|
||||
complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
|
||||
if complement_subquery_predicate:
|
||||
right = complement_subquery_predicate(this=right.this)
|
||||
|
||||
return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
|
||||
if isinstance(this, exp.Paren):
|
||||
condition = this.unnest()
|
||||
if isinstance(condition, exp.And):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue