Merging upstream version 10.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
582b160275
commit
a5128ea109
57 changed files with 1542 additions and 529 deletions
|
@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
|
|||
for cte_scope in root.cte_scopes:
|
||||
# Append all the new CTEs from this existing CTE
|
||||
for scope in cte_scope.traverse():
|
||||
if scope is cte_scope:
|
||||
# Don't try to eliminate this CTE itself
|
||||
continue
|
||||
new_cte = _eliminate(scope, existing_ctes, taken)
|
||||
if new_cte:
|
||||
new_ctes.append(new_cte)
|
||||
|
@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
|
|||
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
|
||||
return _eliminate_derived_table(scope, existing_ctes, taken)
|
||||
|
||||
if scope.is_cte:
|
||||
return _eliminate_cte(scope, existing_ctes, taken)
|
||||
|
||||
|
||||
def _eliminate_union(scope, existing_ctes, taken):
|
||||
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
||||
|
@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
|
|||
|
||||
|
||||
def _eliminate_derived_table(scope, existing_ctes, taken):
|
||||
parent = scope.expression.parent
|
||||
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||
|
||||
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
|
||||
parent.replace(table)
|
||||
|
||||
return cte
|
||||
|
||||
|
||||
def _eliminate_cte(scope, existing_ctes, taken):
|
||||
parent = scope.expression.parent
|
||||
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||
|
||||
with_ = parent.parent
|
||||
parent.pop()
|
||||
if not with_.expressions:
|
||||
with_.pop()
|
||||
|
||||
# Rename references to this CTE
|
||||
for child_scope in scope.parent.traverse():
|
||||
for table, source in child_scope.selected_sources.values():
|
||||
if source is scope:
|
||||
new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
|
||||
table.replace(new_table)
|
||||
|
||||
return cte
|
||||
|
||||
|
||||
def _new_cte(scope, existing_ctes, taken):
|
||||
"""
|
||||
Returns:
|
||||
tuple of (name, cte)
|
||||
where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
|
||||
If this CTE duplicates an existing CTE, `cte` will be None.
|
||||
"""
|
||||
duplicate_cte_alias = existing_ctes.get(scope.expression)
|
||||
parent = scope.expression.parent
|
||||
name = alias = parent.alias
|
||||
name = parent.alias
|
||||
|
||||
if not alias:
|
||||
name = alias = find_new_name(taken=taken, base="cte")
|
||||
if not name:
|
||||
name = find_new_name(taken=taken, base="cte")
|
||||
|
||||
if duplicate_cte_alias:
|
||||
name = duplicate_cte_alias
|
||||
elif taken.get(alias):
|
||||
name = find_new_name(taken=taken, base=alias)
|
||||
elif taken.get(name):
|
||||
name = find_new_name(taken=taken, base=name)
|
||||
|
||||
taken[name] = scope
|
||||
|
||||
table = exp.alias_(exp.table_(name), alias=alias)
|
||||
parent.replace(table)
|
||||
|
||||
if not duplicate_cte_alias:
|
||||
existing_ctes[scope.expression] = name
|
||||
return exp.CTE(
|
||||
cte = exp.CTE(
|
||||
this=scope.expression,
|
||||
alias=exp.TableAlias(this=exp.to_identifier(name)),
|
||||
)
|
||||
else:
|
||||
cte = None
|
||||
return name, cte
|
||||
|
|
92
sqlglot/optimizer/lower_identities.py
Normal file
92
sqlglot/optimizer/lower_identities.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.helper import ensure_collection
|
||||
|
||||
|
||||
def lower_identities(expression):
|
||||
"""
|
||||
Convert all unquoted identifiers to lower case.
|
||||
|
||||
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
|
||||
>>> lower_identities(expression).sql()
|
||||
'SELECT bar.a AS A FROM "Foo".bar'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to quote
|
||||
Returns:
|
||||
sqlglot.Expression: quoted expression
|
||||
"""
|
||||
# We need to leave the output aliases unchanged, so the selects need special handling
|
||||
_lower_selects(expression)
|
||||
|
||||
# These clauses can reference output aliases and also need special handling
|
||||
_lower_order(expression)
|
||||
_lower_having(expression)
|
||||
|
||||
# We've already handled these args, so don't traverse into them
|
||||
traversed = {"expressions", "order", "having"}
|
||||
|
||||
if isinstance(expression, exp.Subquery):
|
||||
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
|
||||
lower_identities(expression.this)
|
||||
traversed |= {"this"}
|
||||
|
||||
if isinstance(expression, exp.Union):
|
||||
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
|
||||
lower_identities(expression.left)
|
||||
lower_identities(expression.right)
|
||||
traversed |= {"this", "expression"}
|
||||
|
||||
for k, v in expression.args.items():
|
||||
if k in traversed:
|
||||
continue
|
||||
|
||||
for child in ensure_collection(v):
|
||||
if isinstance(child, exp.Expression):
|
||||
child.transform(_lower, copy=False)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _lower_selects(expression):
|
||||
for e in expression.expressions:
|
||||
# Leave output aliases as-is
|
||||
e.unalias().transform(_lower, copy=False)
|
||||
|
||||
|
||||
def _lower_order(expression):
|
||||
order = expression.args.get("order")
|
||||
|
||||
if not order:
|
||||
return
|
||||
|
||||
output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
|
||||
|
||||
for ordered in order.expressions:
|
||||
# Don't lower references to output aliases
|
||||
if not (
|
||||
isinstance(ordered.this, exp.Column)
|
||||
and not ordered.this.table
|
||||
and ordered.this.name in output_aliases
|
||||
):
|
||||
ordered.transform(_lower, copy=False)
|
||||
|
||||
|
||||
def _lower_having(expression):
|
||||
having = expression.args.get("having")
|
||||
|
||||
if not having:
|
||||
return
|
||||
|
||||
# Don't lower references to output aliases
|
||||
for agg in having.find_all(exp.AggFunc):
|
||||
agg.transform(_lower, copy=False)
|
||||
|
||||
|
||||
def _lower(node):
|
||||
if isinstance(node, exp.Identifier) and not node.quoted:
|
||||
node.set("this", node.this.lower())
|
||||
return node
|
|
@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
|||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.lower_identities import lower_identities
|
||||
from sqlglot.optimizer.merge_subqueries import merge_subqueries
|
||||
from sqlglot.optimizer.normalize import normalize
|
||||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||
|
@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
|
|||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
|
||||
RULES = (
|
||||
lower_identities,
|
||||
qualify_tables,
|
||||
isolate_table_selects,
|
||||
qualify_columns,
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
import itertools
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.optimizer.scope import ScopeType, traverse_scope
|
||||
|
||||
|
||||
def unnest_subqueries(expression):
|
||||
"""
|
||||
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
|
||||
|
||||
Convert the subquery into a group by so it is not a many to many left join.
|
||||
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
|
||||
Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
|
||||
Convert scalar subqueries into cross joins.
|
||||
Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
|
@ -29,21 +28,43 @@ def unnest_subqueries(expression):
|
|||
for scope in traverse_scope(expression):
|
||||
select = scope.expression
|
||||
parent = select.parent_select
|
||||
if not parent:
|
||||
continue
|
||||
if scope.external_columns:
|
||||
decorrelate(select, parent, scope.external_columns, sequence)
|
||||
else:
|
||||
elif scope.scope_type == ScopeType.SUBQUERY:
|
||||
unnest(select, parent, sequence)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def unnest(select, parent_select, sequence):
|
||||
predicate = select.find_ancestor(exp.In, exp.Any)
|
||||
if len(select.selects) > 1:
|
||||
return
|
||||
|
||||
predicate = select.find_ancestor(exp.Condition)
|
||||
alias = _alias(sequence)
|
||||
|
||||
if not predicate or parent_select is not predicate.parent_select:
|
||||
return
|
||||
|
||||
if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
|
||||
# this subquery returns a scalar and can just be converted to a cross join
|
||||
if not isinstance(predicate, (exp.In, exp.Any)):
|
||||
having = predicate.find_ancestor(exp.Having)
|
||||
column = exp.column(select.selects[0].alias_or_name, alias)
|
||||
if having and having.parent_select is parent_select:
|
||||
column = exp.Max(this=column)
|
||||
_replace(select.parent, column)
|
||||
|
||||
parent_select.join(
|
||||
select,
|
||||
join_type="CROSS",
|
||||
join_alias=alias,
|
||||
copy=False,
|
||||
)
|
||||
return
|
||||
|
||||
if select.find(exp.Limit, exp.Offset):
|
||||
return
|
||||
|
||||
if isinstance(predicate, exp.Any):
|
||||
|
@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
|
|||
|
||||
column = _other_operand(predicate)
|
||||
value = select.selects[0]
|
||||
alias = _alias(sequence)
|
||||
|
||||
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
|
||||
_replace(predicate, f"NOT {on.right} IS NULL")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue