1
0
Fork 0

Merging upstream version 23.13.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:33:25 +01:00
parent 63a75c51ff
commit 64041d1d66
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
85 changed files with 53899 additions and 50390 deletions

View file

@ -12,6 +12,8 @@ from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
class ScopeType(Enum):
ROOT = auto()
@ -495,25 +497,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
A list of the created scope instances
"""
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
# We ignore the DDL expression and build a scope for its query instead
ddl_with = expression.args.get("with")
expression = expression.expression
# If the DDL has CTEs attached, we need to add them to the query, or
# prepend them if the query itself already has CTEs attached to it
if ddl_with:
ddl_with.pop()
query_ctes = expression.ctes
if not query_ctes:
expression.set("with", ddl_with)
else:
expression.args["with"].set("recursive", ddl_with.recursive)
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
if isinstance(expression, exp.Query):
if isinstance(expression, TRAVERSABLES):
return list(_traverse_scope(Scope(expression)))
return []
@ -531,25 +516,37 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
expression = scope.expression
if isinstance(expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
elif isinstance(expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
elif isinstance(scope.expression, exp.Subquery):
elif isinstance(expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
else:
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
elif isinstance(expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
elif isinstance(expression, exp.UDTF):
yield from _traverse_udtfs(scope)
elif isinstance(expression, exp.DDL):
if isinstance(expression.expression, exp.Query):
yield from _traverse_ctes(scope)
yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
return
elif isinstance(expression, exp.DML):
yield from _traverse_ctes(scope)
for query in find_all_in_scope(expression, exp.Query):
# This check ensures we don't yield the CTE queries twice
if not isinstance(query.parent, exp.CTE):
yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
return
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
)
logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
return
yield scope
@ -749,7 +746,7 @@ def _traverse_udtfs(scope):
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
scope_type=ScopeType.SUBQUERY,
outer_columns=expression.alias_column_names,
)
):
@ -757,8 +754,7 @@ def _traverse_udtfs(scope):
top = child_scope
sources[expression.alias] = child_scope
scope.derived_table_scopes.append(top)
scope.table_scopes.append(top)
scope.subquery_scopes.append(top)
scope.sources.update(sources)