1
0
Fork 0

Merging upstream version 10.4.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:01:55 +01:00
parent de4e42d4d3
commit 0c79f8b507
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
88 changed files with 1637 additions and 436 deletions

View file

@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
if isinstance(expression, exp.Identifier):
expression.set("quoted", True)
return expression

View file

@ -129,10 +129,23 @@ def join_condition(join):
"""
name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
def extract_condition(condition):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.true())
# find the join keys
# SELECT
# FROM x
@ -141,20 +154,30 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
extract_condition(condition)
elif normalized(on, dnf=True):
conditions = None
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.true())
for condition in on.flatten():
parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
if conditions is None:
conditions = parts
else:
temp = []
for p in parts:
cs = [c for c in conditions if p == c]
if cs:
temp.append(p)
temp.extend(cs)
conditions = temp
for condition in conditions:
extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on

View file

@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {}
with_ = root.expression.args.get("with")
recursive = False
if with_:
recursive = with_.args.get("recursive")
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)
if new_ctes:
expression.set("with", exp.With(expressions=new_ctes))
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression

View file

@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
return x
return [
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
]
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)

View file

@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
@ -34,7 +33,6 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
quote_identities,
)

View file

@ -27,7 +27,14 @@ def pushdown_predicates(expression):
select = scope.expression
where = select.args.get("where")
if where:
pushdown(where.this, scope.selected_sources, scope_ref_count)
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
with_ = source.parent.expression.args.get("with")
if with_ and with_.recursive:
return {}
node = source.expression
if isinstance(node, exp.Join):
if node.side:
if node.side and node.side != "RIGHT":
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:

View file

@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
# SELECTION TO USE IF SELECTION LIST IS EMPTY
# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
return removed_indexes
@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)

View file

@ -311,6 +311,9 @@ def _qualify_outputs(scope):
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
elif isinstance(selection, exp.Subquery):
if not selection.alias:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection)

View file

@ -1,25 +0,0 @@
from sqlglot import exp
def quote_identities(expression):
"""
Rewrite sqlglot AST to ensure all identities are quoted.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
>>> quote_identities(expression).sql()
'SELECT "x"."a" AS "a" FROM "db"."x"'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
def qualify(node):
if isinstance(node, exp.Identifier):
node.set("quoted", True)
return node
return expression.transform(qualify, copy=False)

View file

@ -511,9 +511,20 @@ def _traverse_union(scope):
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE
for derived_table in derived_tables:
top = None
recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. thus the recursive scope is the first
# section of the union.
if is_cte and scope.expression.args["with"].recursive:
union = derived_table.this
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins.
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
alias = derived_table.alias
sources[alias] = child_scope
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
if is_cte:
scope.cte_scopes.append(child_scope)
else:
scope.derived_table_scopes.append(top)
scope.derived_table_scopes.append(child_scope)
scope.sources.update(sources)

View file

@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence)
keys = []
# for all external columns in the where statement,
# split out the relevant data to convert it into a join
# for all external columns in the where statement, find the relevant predicate
# keys to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
is_subquery_projection = any(
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
)
value = select.selects[0]
key_aliases = {}
group_by = []
@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped
# so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by:
select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
select.select(
exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
append=False,
copy=False,
)
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if is_subquery_projection:
key.replace(nested)
continue
if key in group_by:
key.replace(nested)
parent_predicate = _replace(