1
0
Fork 0

Merging upstream version 10.1.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:56:25 +01:00
parent 582b160275
commit a5128ea109
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
57 changed files with 1542 additions and 529 deletions

View file

@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
if scope is cte_scope:
# Don't try to eliminate this CTE itself
continue
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
if scope.is_cte:
return _eliminate_cte(scope, existing_ctes, taken)
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
parent.replace(table)
return cte
def _eliminate_cte(scope, existing_ctes, taken):
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
with_ = parent.parent
parent.pop()
if not with_.expressions:
with_.pop()
# Rename references to this CTE
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
table.replace(new_table)
return cte
def _new_cte(scope, existing_ctes, taken):
"""
Returns:
tuple of (name, cte)
where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
If this CTE duplicates an existing CTE, `cte` will be None.
"""
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
name = alias = parent.alias
name = parent.alias
if not alias:
name = alias = find_new_name(taken=taken, base="cte")
if not name:
name = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
elif taken.get(alias):
name = find_new_name(taken=taken, base=alias)
elif taken.get(name):
name = find_new_name(taken=taken, base=name)
taken[name] = scope
table = exp.alias_(exp.table_(name), alias=alias)
parent.replace(table)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
return exp.CTE(
cte = exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)
else:
cte = None
return name, cte

View file

@ -0,0 +1,92 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection
def lower_identities(expression):
"""
Convert all unquoted identifiers to lower case.
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> lower_identities(expression).sql()
'SELECT bar.a AS A FROM "Foo".bar'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
# We need to leave the output aliases unchanged, so the selects need special handling
_lower_selects(expression)
# These clauses can reference output aliases and also need special handling
_lower_order(expression)
_lower_having(expression)
# We've already handled these args, so don't traverse into them
traversed = {"expressions", "order", "having"}
if isinstance(expression, exp.Subquery):
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
lower_identities(expression.this)
traversed |= {"this"}
if isinstance(expression, exp.Union):
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
lower_identities(expression.left)
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.args.items():
if k in traversed:
continue
for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)
return expression
def _lower_selects(expression):
for e in expression.expressions:
# Leave output aliases as-is
e.unalias().transform(_lower, copy=False)
def _lower_order(expression):
order = expression.args.get("order")
if not order:
return
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
for ordered in order.expressions:
# Don't lower references to output aliases
if not (
isinstance(ordered.this, exp.Column)
and not ordered.this.table
and ordered.this.name in output_aliases
):
ordered.transform(_lower, copy=False)
def _lower_having(expression):
having = expression.args.get("having")
if not having:
return
# Don't lower references to output aliases
for agg in having.find_all(exp.AggFunc):
agg.transform(_lower, copy=False)
def _lower(node):
if isinstance(node, exp.Identifier) and not node.quoted:
node.set("this", node.this.lower())
return node

View file

@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,

View file

@ -1,16 +1,15 @@
import itertools
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import ScopeType, traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert the subquery into a group by so it is not a many to many left join.
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
Convert scalar subqueries into cross joins.
Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
if not parent:
continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
else:
elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
predicate = select.find_ancestor(exp.In, exp.Any)
if len(select.selects) > 1:
return
predicate = select.find_ancestor(exp.Condition)
alias = _alias(sequence)
if not predicate or parent_select is not predicate.parent_select:
return
if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
# this subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
if having and having.parent_select is parent_select:
column = exp.Max(this=column)
_replace(select.parent, column)
parent_select.join(
select,
join_type="CROSS",
join_alias=alias,
copy=False,
)
return
if select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
column = _other_operand(predicate)
value = select.selects[0]
alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")