Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
ebba7c6a18
commit
d26905e4af
187 changed files with 86502 additions and 71397 deletions
|
@ -8,7 +8,7 @@ from enum import Enum, auto
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import ensure_collection, find_new_name
|
||||
from sqlglot.helper import ensure_collection, find_new_name, seq_get
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
@ -38,11 +38,11 @@ class Scope:
|
|||
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
|
||||
The LATERAL VIEW EXPLODE gets x as a source.
|
||||
cte_sources (dict[str, Scope]): Sources from CTES
|
||||
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list of it's alias of this scope, this is that list of columns.
|
||||
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list for the alias of this scope, this is that list of columns.
|
||||
For example:
|
||||
SELECT * FROM (SELECT ...) AS y(col1, col2)
|
||||
The inner query would have `["col1", "col2"]` for its `outer_column_list`
|
||||
The inner query would have `["col1", "col2"]` for its `outer_columns`
|
||||
parent (Scope): Parent scope
|
||||
scope_type (ScopeType): Type of this scope, relative to its parent
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries
|
||||
|
@ -58,7 +58,7 @@ class Scope:
|
|||
self,
|
||||
expression,
|
||||
sources=None,
|
||||
outer_column_list=None,
|
||||
outer_columns=None,
|
||||
parent=None,
|
||||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
|
@ -70,7 +70,7 @@ class Scope:
|
|||
self.cte_sources = cte_sources or {}
|
||||
self.sources.update(self.lateral_sources)
|
||||
self.sources.update(self.cte_sources)
|
||||
self.outer_column_list = outer_column_list or []
|
||||
self.outer_columns = outer_columns or []
|
||||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
self.subquery_scopes = []
|
||||
|
@ -119,10 +119,11 @@ class Scope:
|
|||
self._raw_columns = []
|
||||
self._join_hints = []
|
||||
|
||||
for node, parent, _ in self.walk(bfs=False):
|
||||
for node in self.walk(bfs=False):
|
||||
if node is self.expression:
|
||||
continue
|
||||
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
|
||||
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
|
||||
self._tables.append(node)
|
||||
|
@ -132,10 +133,8 @@ class Scope:
|
|||
self._udtfs.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
elif (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and _is_derived_table(node)
|
||||
elif _is_derived_table(node) and isinstance(
|
||||
node.parent, (exp.From, exp.Join, exp.Subquery)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.UNWRAPPED_QUERIES):
|
||||
|
@ -438,11 +437,21 @@ class Scope:
|
|||
Yields:
|
||||
Scope: scope instances in depth-first-search post-order
|
||||
"""
|
||||
for child_scope in itertools.chain(
|
||||
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
|
||||
):
|
||||
yield from child_scope.traverse()
|
||||
yield self
|
||||
stack = [self]
|
||||
result = []
|
||||
while stack:
|
||||
scope = stack.pop()
|
||||
result.append(scope)
|
||||
stack.extend(
|
||||
itertools.chain(
|
||||
scope.cte_scopes,
|
||||
scope.union_scopes,
|
||||
scope.table_scopes,
|
||||
scope.subquery_scopes,
|
||||
)
|
||||
)
|
||||
|
||||
yield from reversed(result)
|
||||
|
||||
def ref_count(self):
|
||||
"""
|
||||
|
@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
|
|||
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
|
||||
|
||||
Args:
|
||||
expression (exp.Expression): expression to traverse
|
||||
expression: Expression to traverse
|
||||
|
||||
Returns:
|
||||
list[Scope]: scope instances
|
||||
A list of the created scope instances
|
||||
"""
|
||||
if isinstance(expression, exp.Query) or (
|
||||
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
|
||||
):
|
||||
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
|
||||
# We ignore the DDL expression and build a scope for its query instead
|
||||
ddl_with = expression.args.get("with")
|
||||
expression = expression.expression
|
||||
|
||||
# If the DDL has CTEs attached, we need to add them to the query, or
|
||||
# prepend them if the query itself already has CTEs attached to it
|
||||
if ddl_with:
|
||||
ddl_with.pop()
|
||||
query_ctes = expression.ctes
|
||||
if not query_ctes:
|
||||
expression.set("with", ddl_with)
|
||||
else:
|
||||
expression.args["with"].set("recursive", ddl_with.recursive)
|
||||
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
|
||||
|
||||
if isinstance(expression, exp.Query):
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
return []
|
||||
|
@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
|
|||
Build a scope tree.
|
||||
|
||||
Args:
|
||||
expression (exp.Expression): expression to build the scope tree for
|
||||
expression: Expression to build the scope tree for.
|
||||
|
||||
Returns:
|
||||
Scope: root scope
|
||||
The root scope
|
||||
"""
|
||||
scopes = traverse_scope(expression)
|
||||
if scopes:
|
||||
return scopes[-1]
|
||||
return None
|
||||
return seq_get(traverse_scope(expression), -1)
|
||||
|
||||
|
||||
def _traverse_scope(scope):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
yield from _traverse_select(scope)
|
||||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_ctes(scope)
|
||||
yield from _traverse_union(scope)
|
||||
return
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
if scope.is_root:
|
||||
yield from _traverse_select(scope)
|
||||
|
@ -523,8 +546,6 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
yield from _traverse_udtfs(scope)
|
||||
elif isinstance(scope.expression, exp.DDL):
|
||||
yield from _traverse_ddl(scope)
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
|
||||
|
@ -541,30 +562,38 @@ def _traverse_select(scope):
|
|||
|
||||
|
||||
def _traverse_union(scope):
|
||||
yield from _traverse_ctes(scope)
|
||||
prev_scope = None
|
||||
union_scope_stack = [scope]
|
||||
expression_stack = [scope.expression.right, scope.expression.left]
|
||||
|
||||
# The last scope to be yield should be the top most scope
|
||||
left = None
|
||||
for left in _traverse_scope(
|
||||
scope.branch(
|
||||
scope.expression.left,
|
||||
outer_column_list=scope.outer_column_list,
|
||||
while expression_stack:
|
||||
expression = expression_stack.pop()
|
||||
union_scope = union_scope_stack[-1]
|
||||
|
||||
new_scope = union_scope.branch(
|
||||
expression,
|
||||
outer_columns=union_scope.outer_columns,
|
||||
scope_type=ScopeType.UNION,
|
||||
)
|
||||
):
|
||||
yield left
|
||||
|
||||
right = None
|
||||
for right in _traverse_scope(
|
||||
scope.branch(
|
||||
scope.expression.right,
|
||||
outer_column_list=scope.outer_column_list,
|
||||
scope_type=ScopeType.UNION,
|
||||
)
|
||||
):
|
||||
yield right
|
||||
if isinstance(expression, exp.Union):
|
||||
yield from _traverse_ctes(new_scope)
|
||||
|
||||
scope.union_scopes = [left, right]
|
||||
union_scope_stack.append(new_scope)
|
||||
expression_stack.extend([expression.right, expression.left])
|
||||
continue
|
||||
|
||||
for scope in _traverse_scope(new_scope):
|
||||
yield scope
|
||||
|
||||
if prev_scope:
|
||||
union_scope_stack.pop()
|
||||
union_scope.union_scopes = [prev_scope, scope]
|
||||
prev_scope = union_scope
|
||||
|
||||
yield union_scope
|
||||
else:
|
||||
prev_scope = scope
|
||||
|
||||
|
||||
def _traverse_ctes(scope):
|
||||
|
@ -588,7 +617,7 @@ def _traverse_ctes(scope):
|
|||
scope.branch(
|
||||
cte.this,
|
||||
cte_sources=sources,
|
||||
outer_column_list=cte.alias_column_names,
|
||||
outer_columns=cte.alias_column_names,
|
||||
scope_type=ScopeType.CTE,
|
||||
)
|
||||
):
|
||||
|
@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
|
|||
as it doesn't introduce a new scope. If an alias is present, it shadows all names
|
||||
under the Subquery, so that's one exception to this rule.
|
||||
"""
|
||||
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
|
||||
return isinstance(expression, exp.Subquery) and bool(
|
||||
expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
|
||||
)
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
|
@ -681,7 +712,7 @@ def _traverse_tables(scope):
|
|||
scope.branch(
|
||||
expression,
|
||||
lateral_sources=lateral_sources,
|
||||
outer_column_list=expression.alias_column_names,
|
||||
outer_columns=expression.alias_column_names,
|
||||
scope_type=scope_type,
|
||||
)
|
||||
):
|
||||
|
@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
|
|||
|
||||
sources = {}
|
||||
for expression in expressions:
|
||||
if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
|
||||
if _is_derived_table(expression):
|
||||
top = None
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
expression,
|
||||
scope_type=ScopeType.DERIVED_TABLE,
|
||||
outer_column_list=expression.alias_column_names,
|
||||
outer_columns=expression.alias_column_names,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
|
@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
|
|||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _traverse_ddl(scope):
|
||||
yield from _traverse_ctes(scope)
|
||||
|
||||
query_scope = scope.branch(
|
||||
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
|
||||
)
|
||||
query_scope._collect()
|
||||
query_scope._ctes = scope.ctes + query_scope._ctes
|
||||
|
||||
yield from _traverse_scope(query_scope)
|
||||
|
||||
|
||||
def walk_in_scope(expression, bfs=True, prune=None):
|
||||
"""
|
||||
Returns a generator object which visits all nodes in the syntax tree, stopping at
|
||||
|
@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
|
|||
# Whenever we set it to True, we exclude a subtree from traversal.
|
||||
crossed_scope_boundary = False
|
||||
|
||||
for node, parent, key in expression.walk(
|
||||
bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
|
||||
for node in expression.walk(
|
||||
bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
|
||||
):
|
||||
crossed_scope_boundary = False
|
||||
|
||||
yield node, parent, key
|
||||
yield node
|
||||
|
||||
if node is expression:
|
||||
continue
|
||||
if (
|
||||
isinstance(node, exp.CTE)
|
||||
or (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and _is_derived_table(node)
|
||||
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
|
||||
)
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.UNWRAPPED_QUERIES)
|
||||
):
|
||||
crossed_scope_boundary = True
|
||||
|
@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
|
|||
Yields:
|
||||
exp.Expression: nodes
|
||||
"""
|
||||
for expression, *_ in walk_in_scope(expression, bfs=bfs):
|
||||
for expression in walk_in_scope(expression, bfs=bfs):
|
||||
if isinstance(expression, tuple(ensure_collection(expression_types))):
|
||||
yield expression
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue