Merging upstream version 6.2.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
87ba722f7f
commit
a62bbc24c3
22 changed files with 361 additions and 98 deletions
|
@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
|
|||
Returns:
|
||||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
merge_ctes(expression, leave_tables_isolated)
|
||||
merge_derived_tables(expression, leave_tables_isolated)
|
||||
expression = merge_ctes(expression, leave_tables_isolated)
|
||||
expression = merge_derived_tables(expression, leave_tables_isolated)
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
|
|||
alias = node_to_replace.alias
|
||||
else:
|
||||
alias = table.name
|
||||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
_pop_cte(inner_scope)
|
||||
return expression
|
||||
|
||||
|
||||
def merge_derived_tables(expression, leave_tables_isolated=False):
|
||||
|
@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
|
||||
_rename_inner_sources(outer_scope, inner_scope, alias)
|
||||
_merge_from(outer_scope, inner_scope, subquery, alias)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_expressions(outer_scope, inner_scope, alias)
|
||||
_merge_joins(outer_scope, inner_scope, from_or_join)
|
||||
_merge_where(outer_scope, inner_scope, from_or_join)
|
||||
_merge_order(outer_scope, inner_scope)
|
||||
return expression
|
||||
|
||||
|
||||
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
|
||||
|
@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
|
|||
continue
|
||||
columns_to_replace = outer_columns.get(projection_name, [])
|
||||
for column in columns_to_replace:
|
||||
column.replace(expression.unalias())
|
||||
column.replace(expression.unalias().copy())
|
||||
|
||||
|
||||
def _merge_where(outer_scope, inner_scope, from_or_join):
|
||||
|
|
|
@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
|
|||
from sqlglot.optimizer.schema import ensure_schema
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
|
||||
|
||||
|
||||
def qualify_columns(expression, schema):
|
||||
"""
|
||||
|
@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
|
|||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
_qualify_columns(scope, resolver)
|
||||
if not isinstance(scope.expression, SKIP_QUALIFY):
|
||||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver)
|
||||
_qualify_outputs(scope)
|
||||
_check_unknown_tables(scope)
|
||||
|
@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
|
|||
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
||||
"""
|
||||
for derived_table in derived_tables:
|
||||
if isinstance(derived_table, SKIP_QUALIFY):
|
||||
if isinstance(derived_table, exp.UDTF):
|
||||
continue
|
||||
table_alias = derived_table.args.get("alias")
|
||||
if table_alias:
|
||||
|
@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
|
|||
if not column_table:
|
||||
column_table = resolver.get_table(column_name)
|
||||
|
||||
if not scope.is_subquery and not scope.is_unnest:
|
||||
if not scope.is_subquery and not scope.is_udtf:
|
||||
if column_name not in resolver.all_columns:
|
||||
raise OptimizeError(f"Unknown column: {column_name}")
|
||||
|
||||
|
@ -296,7 +294,7 @@ def _qualify_outputs(scope):
|
|||
|
||||
|
||||
def _check_unknown_tables(scope):
|
||||
if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
|
||||
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
|
||||
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import itertools
|
||||
from copy import copy
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlglot import exp
|
||||
|
@ -12,7 +11,7 @@ class ScopeType(Enum):
|
|||
DERIVED_TABLE = auto()
|
||||
CTE = auto()
|
||||
UNION = auto()
|
||||
UNNEST = auto()
|
||||
UDTF = auto()
|
||||
|
||||
|
||||
class Scope:
|
||||
|
@ -70,14 +69,11 @@ class Scope:
|
|||
self._columns = None
|
||||
self._external_columns = None
|
||||
|
||||
def branch(self, expression, scope_type, add_sources=None, **kwargs):
|
||||
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
sources = copy(self.sources)
|
||||
if add_sources:
|
||||
sources.update(add_sources)
|
||||
return Scope(
|
||||
expression=expression.unnest(),
|
||||
sources=sources,
|
||||
sources={**self.cte_sources, **(chain_sources or {})},
|
||||
parent=self,
|
||||
scope_type=scope_type,
|
||||
**kwargs,
|
||||
|
@ -90,30 +86,21 @@ class Scope:
|
|||
self._derived_tables = []
|
||||
self._raw_columns = []
|
||||
|
||||
# We'll use this variable to pass state into the dfs generator.
|
||||
# Whenever we set it to True, we exclude a subtree from traversal.
|
||||
prune = False
|
||||
|
||||
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
|
||||
prune = False
|
||||
|
||||
for node, parent, _ in self.walk(bfs=False):
|
||||
if node is self.expression:
|
||||
continue
|
||||
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table):
|
||||
self._tables.append(node)
|
||||
elif isinstance(node, (exp.Unnest, exp.Lateral)):
|
||||
elif isinstance(node, exp.UDTF):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
self._derived_tables.append(node)
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
self._subqueries.append(node)
|
||||
prune = True
|
||||
|
||||
self._collected = True
|
||||
|
||||
|
@ -121,6 +108,43 @@ class Scope:
|
|||
if not self._collected:
|
||||
self._collect()
|
||||
|
||||
def walk(self, bfs=True):
|
||||
return walk_in_scope(self.expression, bfs=bfs)
|
||||
|
||||
def find(self, *expression_types, bfs=True):
|
||||
"""
|
||||
Returns the first node in this scope which matches at least one of the specified types.
|
||||
|
||||
This does NOT traverse into subscopes.
|
||||
|
||||
Args:
|
||||
expression_types (type): the expression type(s) to match.
|
||||
bfs (bool): True to use breadth-first search, False to use depth-first.
|
||||
|
||||
Returns:
|
||||
exp.Expression: the node which matches the criteria or None if no node matching
|
||||
the criteria was found.
|
||||
"""
|
||||
return next(self.find_all(*expression_types, bfs=bfs), None)
|
||||
|
||||
def find_all(self, *expression_types, bfs=True):
|
||||
"""
|
||||
Returns a generator object which visits all nodes in this scope and only yields those that
|
||||
match at least one of the specified expression types.
|
||||
|
||||
This does NOT traverse into subscopes.
|
||||
|
||||
Args:
|
||||
expression_types (type): the expression type(s) to match.
|
||||
bfs (bool): True to use breadth-first search, False to use depth-first.
|
||||
|
||||
Yields:
|
||||
exp.Expression: nodes
|
||||
"""
|
||||
for expression, _, _ in self.walk(bfs=bfs):
|
||||
if isinstance(expression, expression_types):
|
||||
yield expression
|
||||
|
||||
def replace(self, old, new):
|
||||
"""
|
||||
Replace `old` with `new`.
|
||||
|
@ -246,6 +270,16 @@ class Scope:
|
|||
self._selected_sources = result
|
||||
return self._selected_sources
|
||||
|
||||
@property
|
||||
def cte_sources(self):
|
||||
"""
|
||||
Sources that are CTEs.
|
||||
|
||||
Returns:
|
||||
dict[str, Scope]: Mapping of source alias to Scope
|
||||
"""
|
||||
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
"""
|
||||
|
@ -313,9 +347,9 @@ class Scope:
|
|||
return self.scope_type == ScopeType.ROOT
|
||||
|
||||
@property
|
||||
def is_unnest(self):
|
||||
"""Determine if this scope is an unnest"""
|
||||
return self.scope_type == ScopeType.UNNEST
|
||||
def is_udtf(self):
|
||||
"""Determine if this scope is a UDTF (User Defined Table Function)"""
|
||||
return self.scope_type == ScopeType.UDTF
|
||||
|
||||
@property
|
||||
def is_correlated_subquery(self):
|
||||
|
@ -348,7 +382,7 @@ class Scope:
|
|||
Scope: scope instances in depth-first-search post-order
|
||||
"""
|
||||
for child_scope in itertools.chain(
|
||||
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
|
||||
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
|
||||
):
|
||||
yield from child_scope.traverse()
|
||||
yield self
|
||||
|
@ -399,7 +433,7 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_select(scope)
|
||||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
|
@ -410,8 +444,8 @@ def _traverse_scope(scope):
|
|||
|
||||
def _traverse_select(scope):
|
||||
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
|
||||
yield from _traverse_subqueries(scope)
|
||||
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
|
||||
yield from _traverse_subqueries(scope)
|
||||
_add_table_sources(scope)
|
||||
|
||||
|
||||
|
@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
|
|||
top = None
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
|
||||
add_sources=sources if scope_type == ScopeType.CTE else None,
|
||||
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
|
||||
chain_sources=sources if scope_type == ScopeType.CTE else None,
|
||||
outer_column_list=derived_table.alias_column_names,
|
||||
scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
|
||||
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
|
@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
|
|||
yield child_scope
|
||||
top = child_scope
|
||||
scope.subquery_scopes.append(top)
|
||||
|
||||
|
||||
def walk_in_scope(expression, bfs=True):
|
||||
"""
|
||||
Returns a generator object which visits all nodes in the syntax tree, stopping at
|
||||
nodes that start child scopes.
|
||||
|
||||
Args:
|
||||
expression (exp.Expression):
|
||||
bfs (bool): if set to True the BFS traversal order will be applied,
|
||||
otherwise the DFS traversal will be used instead.
|
||||
|
||||
Yields:
|
||||
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
|
||||
"""
|
||||
# We'll use this variable to pass state into the dfs generator.
|
||||
# Whenever we set it to True, we exclude a subtree from traversal.
|
||||
prune = False
|
||||
|
||||
for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
|
||||
prune = False
|
||||
|
||||
yield node, parent, key
|
||||
|
||||
if node is expression:
|
||||
continue
|
||||
elif isinstance(node, exp.CTE):
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
prune = True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue