1
0
Fork 0

Adding upstream version 6.2.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:37:25 +01:00
parent 8425a9678d
commit d62bab68ae
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
49 changed files with 1741 additions and 566 deletions

View file

@ -1,3 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
@ -32,10 +33,11 @@ class Scope:
The inner query would have `["col1", "col2"]` for its `outer_column_list`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries.
This does not include derived tables or CTEs.
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
a tuple of the left and right child scopes.
subquery_scopes (list[Scope]): List of all child scopes for subqueries
cte_scopes (list[Scope]): List of all child scopes for CTEs
derived_table_scopes (list[Scope]): List of all child scopes for derived tables
union_scopes (list[Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
"""
def __init__(
@ -52,7 +54,9 @@ class Scope:
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.union = None
self.derived_table_scopes = []
self.cte_scopes = []
self.union_scopes = []
self.clear_cache()
def clear_cache(self):
@ -197,11 +201,16 @@ class Scope:
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
]
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
if (
not ancestor
or column.table
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
return self._columns
@property
@ -283,6 +292,26 @@ class Scope:
"""Determine if this scope is a subquery"""
return self.scope_type == ScopeType.SUBQUERY
@property
def is_derived_table(self):
    """Whether this scope's type is ``ScopeType.DERIVED_TABLE``."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
@property
def is_union(self):
    """Whether this scope's type is ``ScopeType.UNION``."""
    kind = self.scope_type
    return kind == ScopeType.UNION
@property
def is_cte(self):
    """Whether this scope's type is ``ScopeType.CTE`` (common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
@property
def is_root(self):
    """Whether this scope's type is ``ScopeType.ROOT``."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
@ -308,6 +337,22 @@ class Scope:
self.sources.pop(name, None)
self.clear_cache()
def __repr__(self):
    """Return a debug representation embedding the SQL of this scope's expression."""
    rendered_sql = self.expression.sql()
    return f"Scope<{rendered_sql}>"
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Child scopes are visited in this order: CTE scopes, union scopes,
    subquery scopes, then derived table scopes. Each child's own subtree
    is yielded in full before this scope itself is yielded.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.subquery_scopes,
        self.derived_table_scopes,
    )
    for descendant_root in itertools.chain.from_iterable(child_groups):
        yield from descendant_root.traverse()
    # Post-order: this scope comes after everything beneath it.
    yield self
def traverse_scope(expression):
"""
@ -337,6 +382,18 @@ def traverse_scope(expression):
return list(_traverse_scope(Scope(expression)))
def build_scope(expression):
    """
    Build a scope tree and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, taken as the last scope returned by `traverse_scope`
    """
    all_scopes = traverse_scope(expression)
    return all_scopes[-1]
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
@ -370,13 +427,14 @@ def _traverse_union(scope):
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
scope.union = (left, right)
scope.union_scopes = [left, right]
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
for derived_table in derived_tables:
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
else:
scope.derived_table_scopes.append(top)
scope.sources.update(sources)
@ -407,8 +470,6 @@ def _add_table_sources(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name]
elif source_name in scope.sources:
raise OptimizeError(f"Duplicate table name: {source_name}")
else:
sources[source_name] = table