1
0
Fork 0

Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:30:28 +01:00
parent ebba7c6a18
commit d26905e4af
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
187 changed files with 86502 additions and 71397 deletions

View file

@ -8,7 +8,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name
from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
@ -38,11 +38,11 @@ class Scope:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTES
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
The inner query would have `["col1", "col2"]` for its `outer_columns`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
@ -58,7 +58,7 @@ class Scope:
self,
expression,
sources=None,
outer_column_list=None,
outer_columns=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
@ -70,7 +70,7 @@ class Scope:
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.outer_columns = outer_columns or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
@ -119,10 +119,11 @@ class Scope:
self._raw_columns = []
self._join_hints = []
for node, parent, _ in self.walk(bfs=False):
for node in self.walk(bfs=False):
if node is self.expression:
continue
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
@ -132,10 +133,8 @@ class Scope:
self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
elif (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
elif _is_derived_table(node) and isinstance(
node.parent, (exp.From, exp.Join, exp.Subquery)
):
self._derived_tables.append(node)
elif isinstance(node, exp.UNWRAPPED_QUERIES):
@ -438,11 +437,21 @@ class Scope:
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
stack = [self]
result = []
while stack:
scope = stack.pop()
result.append(scope)
stack.extend(
itertools.chain(
scope.cte_scopes,
scope.union_scopes,
scope.table_scopes,
scope.subquery_scopes,
)
)
yield from reversed(result)
def ref_count(self):
"""
@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Args:
expression (exp.Expression): expression to traverse
expression: Expression to traverse
Returns:
list[Scope]: scope instances
A list of the created scope instances
"""
if isinstance(expression, exp.Query) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
):
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
# We ignore the DDL expression and build a scope for its query instead
ddl_with = expression.args.get("with")
expression = expression.expression
# If the DDL has CTEs attached, we need to add them to the query, or
# prepend them if the query itself already has CTEs attached to it
if ddl_with:
ddl_with.pop()
query_ctes = expression.ctes
if not query_ctes:
expression.set("with", ddl_with)
else:
expression.args["with"].set("recursive", ddl_with.recursive)
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
if isinstance(expression, exp.Query):
return list(_traverse_scope(Scope(expression)))
return []
@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
Build a scope tree.
Args:
expression (exp.Expression): expression to build the scope tree for
expression: Expression to build the scope tree for.
Returns:
Scope: root scope
The root scope
"""
scopes = traverse_scope(expression)
if scopes:
return scopes[-1]
return None
return seq_get(traverse_scope(expression), -1)
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
elif isinstance(scope.expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
@ -523,8 +546,6 @@ def _traverse_scope(scope):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
yield from _traverse_udtfs(scope)
elif isinstance(scope.expression, exp.DDL):
yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@ -541,30 +562,38 @@ def _traverse_select(scope):
def _traverse_union(scope):
yield from _traverse_ctes(scope)
prev_scope = None
union_scope_stack = [scope]
expression_stack = [scope.expression.right, scope.expression.left]
# The last scope to be yield should be the top most scope
left = None
for left in _traverse_scope(
scope.branch(
scope.expression.left,
outer_column_list=scope.outer_column_list,
while expression_stack:
expression = expression_stack.pop()
union_scope = union_scope_stack[-1]
new_scope = union_scope.branch(
expression,
outer_columns=union_scope.outer_columns,
scope_type=ScopeType.UNION,
)
):
yield left
right = None
for right in _traverse_scope(
scope.branch(
scope.expression.right,
outer_column_list=scope.outer_column_list,
scope_type=ScopeType.UNION,
)
):
yield right
if isinstance(expression, exp.Union):
yield from _traverse_ctes(new_scope)
scope.union_scopes = [left, right]
union_scope_stack.append(new_scope)
expression_stack.extend([expression.right, expression.left])
continue
for scope in _traverse_scope(new_scope):
yield scope
if prev_scope:
union_scope_stack.pop()
union_scope.union_scopes = [prev_scope, scope]
prev_scope = union_scope
yield union_scope
else:
prev_scope = scope
def _traverse_ctes(scope):
@ -588,7 +617,7 @@ def _traverse_ctes(scope):
scope.branch(
cte.this,
cte_sources=sources,
outer_column_list=cte.alias_column_names,
outer_columns=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
return isinstance(expression, exp.Subquery) and bool(
expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
)
def _traverse_tables(scope):
@ -681,7 +712,7 @@ def _traverse_tables(scope):
scope.branch(
expression,
lateral_sources=lateral_sources,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
scope_type=scope_type,
)
):
@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
sources = {}
for expression in expressions:
if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
if _is_derived_table(expression):
top = None
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
)
):
yield child_scope
@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
scope.sources.update(sources)
def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes
yield from _traverse_scope(query_scope)
def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
# Whenever we set it to True, we exclude a subtree from traversal.
crossed_scope_boundary = False
for node, parent, key in expression.walk(
bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
for node in expression.walk(
bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
):
crossed_scope_boundary = False
yield node, parent, key
yield node
if node is expression:
continue
if (
isinstance(node, exp.CTE)
or (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
Yields:
exp.Expression: nodes
"""
for expression, *_ in walk_in_scope(expression, bfs=bfs):
for expression in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression