Merging upstream version 10.4.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
de4e42d4d3
commit
0c79f8b507
88 changed files with 1637 additions and 436 deletions
|
@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
|
|||
expression = coerce_type(expression)
|
||||
expression = remove_redundant_casts(expression)
|
||||
|
||||
if isinstance(expression, exp.Identifier):
|
||||
expression.set("quoted", True)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
|
|
@ -129,10 +129,23 @@ def join_condition(join):
|
|||
"""
|
||||
name = join.this.alias_or_name
|
||||
on = (join.args.get("on") or exp.true()).copy()
|
||||
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
|
||||
source_key = []
|
||||
join_key = []
|
||||
|
||||
def extract_condition(condition):
|
||||
left, right = condition.unnest_operands()
|
||||
left_tables = exp.column_table_names(left)
|
||||
right_tables = exp.column_table_names(right)
|
||||
|
||||
if name in left_tables and name not in right_tables:
|
||||
join_key.append(left)
|
||||
source_key.append(right)
|
||||
condition.replace(exp.true())
|
||||
elif name in right_tables and name not in left_tables:
|
||||
join_key.append(right)
|
||||
source_key.append(left)
|
||||
condition.replace(exp.true())
|
||||
|
||||
# find the join keys
|
||||
# SELECT
|
||||
# FROM x
|
||||
|
@ -141,20 +154,30 @@ def join_condition(join):
|
|||
#
|
||||
# should pull y.b as the join key and x.a as the source key
|
||||
if normalized(on):
|
||||
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
|
||||
|
||||
for condition in on.flatten():
|
||||
if isinstance(condition, exp.EQ):
|
||||
left, right = condition.unnest_operands()
|
||||
left_tables = exp.column_table_names(left)
|
||||
right_tables = exp.column_table_names(right)
|
||||
extract_condition(condition)
|
||||
elif normalized(on, dnf=True):
|
||||
conditions = None
|
||||
|
||||
if name in left_tables and name not in right_tables:
|
||||
join_key.append(left)
|
||||
source_key.append(right)
|
||||
condition.replace(exp.true())
|
||||
elif name in right_tables and name not in left_tables:
|
||||
join_key.append(right)
|
||||
source_key.append(left)
|
||||
condition.replace(exp.true())
|
||||
for condition in on.flatten():
|
||||
parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
|
||||
if conditions is None:
|
||||
conditions = parts
|
||||
else:
|
||||
temp = []
|
||||
for p in parts:
|
||||
cs = [c for c in conditions if p == c]
|
||||
|
||||
if cs:
|
||||
temp.append(p)
|
||||
temp.extend(cs)
|
||||
conditions = temp
|
||||
|
||||
for condition in conditions:
|
||||
extract_condition(condition)
|
||||
|
||||
on = simplify(on)
|
||||
remaining_condition = None if on == exp.true() else on
|
||||
|
|
|
@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
|
|||
existing_ctes = {}
|
||||
|
||||
with_ = root.expression.args.get("with")
|
||||
recursive = False
|
||||
if with_:
|
||||
recursive = with_.args.get("recursive")
|
||||
for cte in with_.expressions:
|
||||
existing_ctes[cte.this] = cte.alias
|
||||
new_ctes = []
|
||||
|
@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
|
|||
new_ctes.append(new_cte)
|
||||
|
||||
if new_ctes:
|
||||
expression.set("with", exp.With(expressions=new_ctes))
|
||||
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
|
||||
|
||||
return expression
|
||||
|
||||
|
|
|
@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
|
|||
left, right = expression.args.values()
|
||||
|
||||
if isinstance(expression, exp.And if dnf else exp.Or):
|
||||
x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
|
||||
return x
|
||||
return [
|
||||
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
|
||||
]
|
||||
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
|||
from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
from sqlglot.optimizer.quote_identities import quote_identities
|
||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
|
||||
RULES = (
|
||||
|
@ -34,7 +33,6 @@ RULES = (
|
|||
eliminate_ctes,
|
||||
annotate_types,
|
||||
canonicalize,
|
||||
quote_identities,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -27,7 +27,14 @@ def pushdown_predicates(expression):
|
|||
select = scope.expression
|
||||
where = select.args.get("where")
|
||||
if where:
|
||||
pushdown(where.this, scope.selected_sources, scope_ref_count)
|
||||
selected_sources = scope.selected_sources
|
||||
# a right join can only push down to itself and not the source FROM table
|
||||
for k, (node, source) in selected_sources.items():
|
||||
parent = node.find_ancestor(exp.Join, exp.From)
|
||||
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
|
||||
selected_sources = {k: (node, source)}
|
||||
break
|
||||
pushdown(where.this, selected_sources, scope_ref_count)
|
||||
|
||||
# joins should only pushdown into itself, not to other joins
|
||||
# so we limit the selected sources to only itself
|
||||
|
@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
|
|||
|
||||
# a node can reference a CTE which should be pushed down
|
||||
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
|
||||
with_ = source.parent.expression.args.get("with")
|
||||
if with_ and with_.recursive:
|
||||
return {}
|
||||
node = source.expression
|
||||
|
||||
if isinstance(node, exp.Join):
|
||||
if node.side:
|
||||
if node.side and node.side != "RIGHT":
|
||||
return {}
|
||||
nodes[table] = node
|
||||
elif isinstance(node, exp.Select) and len(tables) == 1:
|
||||
|
|
|
@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
|
|||
# Sentinel value that means an outer query selecting ALL columns
|
||||
SELECT_ALL = object()
|
||||
|
||||
# SELECTION TO USE IF SELECTION LIST IS EMPTY
|
||||
# Selection to use if selection list is empty
|
||||
DEFAULT_SELECTION = alias("1", "_")
|
||||
|
||||
|
||||
|
@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
|
|||
|
||||
# If there are no remaining selections, just select a single constant
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION)
|
||||
new_selections.append(DEFAULT_SELECTION.copy())
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
return removed_indexes
|
||||
|
@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
|
|||
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
|
||||
]
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION)
|
||||
new_selections.append(DEFAULT_SELECTION.copy())
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
|
|
@ -311,6 +311,9 @@ def _qualify_outputs(scope):
|
|||
alias_ = alias(exp.column(""), alias=selection.name)
|
||||
alias_.set("this", selection)
|
||||
selection = alias_
|
||||
elif isinstance(selection, exp.Subquery):
|
||||
if not selection.alias:
|
||||
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
|
||||
elif not isinstance(selection, exp.Alias):
|
||||
alias_ = alias(exp.column(""), f"_col_{i}")
|
||||
alias_.set("this", selection)
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
from sqlglot import exp
|
||||
|
||||
|
||||
def quote_identities(expression):
|
||||
"""
|
||||
Rewrite sqlglot AST to ensure all identities are quoted.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
|
||||
>>> quote_identities(expression).sql()
|
||||
'SELECT "x"."a" AS "a" FROM "db"."x"'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to quote
|
||||
Returns:
|
||||
sqlglot.Expression: quoted expression
|
||||
"""
|
||||
|
||||
def qualify(node):
|
||||
if isinstance(node, exp.Identifier):
|
||||
node.set("quoted", True)
|
||||
return node
|
||||
|
||||
return expression.transform(qualify, copy=False)
|
|
@ -511,9 +511,20 @@ def _traverse_union(scope):
|
|||
|
||||
def _traverse_derived_tables(derived_tables, scope, scope_type):
|
||||
sources = {}
|
||||
is_cte = scope_type == ScopeType.CTE
|
||||
|
||||
for derived_table in derived_tables:
|
||||
top = None
|
||||
recursive_scope = None
|
||||
|
||||
# if the scope is a recursive cte, it must be in the form of
|
||||
# base_case UNION recursive. thus the recursive scope is the first
|
||||
# section of the union.
|
||||
if is_cte and scope.expression.args["with"].recursive:
|
||||
union = derived_table.this
|
||||
|
||||
if isinstance(union, exp.Union):
|
||||
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
|
||||
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
|
||||
|
@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
|
|||
)
|
||||
):
|
||||
yield child_scope
|
||||
top = child_scope
|
||||
|
||||
# Tables without aliases will be set as ""
|
||||
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
|
||||
# Until then, this means that only a single, unaliased derived table is allowed (rather,
|
||||
# the latest one wins.
|
||||
sources[derived_table.alias] = child_scope
|
||||
if scope_type == ScopeType.CTE:
|
||||
scope.cte_scopes.append(top)
|
||||
alias = derived_table.alias
|
||||
sources[alias] = child_scope
|
||||
|
||||
if recursive_scope:
|
||||
child_scope.add_source(alias, recursive_scope)
|
||||
|
||||
# append the final child_scope yielded
|
||||
if is_cte:
|
||||
scope.cte_scopes.append(child_scope)
|
||||
else:
|
||||
scope.derived_table_scopes.append(top)
|
||||
scope.derived_table_scopes.append(child_scope)
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ def unnest_subqueries(expression):
|
|||
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
|
||||
>>> unnest_subqueries(expression).sql()
|
||||
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
|
||||
AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
|
||||
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to unnest
|
||||
|
@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
table_alias = _alias(sequence)
|
||||
keys = []
|
||||
|
||||
# for all external columns in the where statement,
|
||||
# split out the relevant data to convert it into a join
|
||||
# for all external columns in the where statement, find the relevant predicate
|
||||
# keys to convert it into a join
|
||||
for column in external_columns:
|
||||
if column.find_ancestor(exp.Where) is not where:
|
||||
return
|
||||
|
@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
|
||||
return
|
||||
|
||||
is_subquery_projection = any(
|
||||
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
|
||||
)
|
||||
|
||||
value = select.selects[0]
|
||||
key_aliases = {}
|
||||
group_by = []
|
||||
|
@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
parent_predicate = select.find_ancestor(exp.Predicate)
|
||||
|
||||
# if the value of the subquery is not an agg or a key, we need to collect it into an array
|
||||
# so that it can be grouped
|
||||
# so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
|
||||
agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
|
||||
if not value.find(exp.AggFunc) and value.this not in group_by:
|
||||
select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
|
||||
select.select(
|
||||
exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
|
||||
append=False,
|
||||
copy=False,
|
||||
)
|
||||
|
||||
# exists queries should not have any selects as it only checks if there are any rows
|
||||
# all selects will be added by the optimizer and only used for join keys
|
||||
|
@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
if isinstance(parent_predicate, exp.Exists) or key != value.this:
|
||||
select.select(f"{key} AS {alias}", copy=False)
|
||||
else:
|
||||
select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
|
||||
select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
|
||||
|
||||
alias = exp.column(value.alias, table_alias)
|
||||
other = _other_operand(parent_predicate)
|
||||
|
@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
|
||||
)
|
||||
else:
|
||||
if is_subquery_projection:
|
||||
alias = exp.alias_(alias, select.parent.alias)
|
||||
select.parent.replace(alias)
|
||||
|
||||
for key, column, predicate in keys:
|
||||
predicate.replace(exp.true())
|
||||
nested = exp.column(key_aliases[key], table_alias)
|
||||
|
||||
if is_subquery_projection:
|
||||
key.replace(nested)
|
||||
continue
|
||||
|
||||
if key in group_by:
|
||||
key.replace(nested)
|
||||
parent_predicate = _replace(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue