1
0
Fork 0

Merging upstream version 6.2.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:43:32 +01:00
parent 87ba722f7f
commit a62bbc24c3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 361 additions and 98 deletions

View file

@ -1,5 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
from sqlglot import exp
@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
UNNEST = auto()
UDTF = auto()
class Scope:
@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
def branch(self, expression, scope_type, add_sources=None, **kwargs):
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
sources = copy(self.sources)
if add_sources:
sources.update(add_sources)
return Scope(
expression=expression.unnest(),
sources=sources,
sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
prune = False
for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
elif isinstance(node, (exp.Unnest, exp.Lateral)):
elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
prune = True
self._collected = True
@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
def walk(self, bfs=True):
    """
    Walk the nodes of this scope's expression via `walk_in_scope`.

    Args:
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        The generator returned by `walk_in_scope` for `self.expression`.
    """
    scope_root = self.expression
    return walk_in_scope(scope_root, bfs=bfs)
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that matches at least one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node matches.
    """
    matches = self.find_all(*expression_types, bfs=bfs)
    # Only the first hit is needed, so stop consuming the iterator right away.
    for first_match in matches:
        return first_match
    return None
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that matches at least one of the given types.

    Nodes come from `self.walk`, in the order that walk produces them.
    This does NOT traverse into subscopes.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    yield from (
        node
        for node, _parent, _key in self.walk(bfs=bfs)
        if isinstance(node, expression_types)
    )
def replace(self, old, new):
"""
Replace `old` with `new`.
@ -246,6 +270,16 @@ class Scope:
self._selected_sources = result
return self._selected_sources
@property
def cte_sources(self):
    """
    Sources that are CTEs.

    Only entries of `self.sources` whose value is a `Scope` with a truthy
    `is_cte` are kept; every other source is left out.

    Returns:
        dict[str, Scope]: Mapping of source alias to Scope
    """
    cte_only = {}
    for alias, source in self.sources.items():
        # Skip sources that are not scopes before touching `is_cte`.
        if not isinstance(source, Scope):
            continue
        if source.is_cte:
            cte_only[alias] = source
    return cte_only
@property
def selects(self):
"""
@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
return self.scope_type == ScopeType.UNNEST
def is_udtf(self):
"""Determine if this scope is a UDTF (User Defined Table Function)"""
return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): root of the tree to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the prune callback handed to `expression.walk`.
    # Setting it to True after yielding a node asks the walker to exclude
    # that node's subtree from traversal.
    skip_subtree = False

    def should_prune(*_):
        return skip_subtree

    for node, parent, key in expression.walk(bfs=bfs, prune=should_prune):
        skip_subtree = False

        yield node, parent, key

        # The root is always walked into, even if it is itself a scope-starting node.
        if node is expression:
            continue

        starts_child_scope = (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.Subqueryable)
        )
        if starts_child_scope:
            skip_subtree = True