1
0
Fork 0
sqlglot/sqlglot/optimizer/scope.py
Daniel Baumann 49aa147013
Merging upstream version 23.10.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 21:31:23 +01:00

846 lines
27 KiB
Python

from __future__ import annotations
import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get
# Shared "sqlglot" logger; used by _traverse_scope to warn about expressions it cannot traverse.
logger = logging.getLogger("sqlglot")
class ScopeType(Enum):
    """Kind of a `Scope`, describing how it relates to its parent scope."""

    ROOT = auto()  # the scope built directly from the traversed query (default for Scope)
    SUBQUERY = auto()  # e.g. SELECT * FROM x WHERE a IN (SELECT ...)
    DERIVED_TABLE = auto()  # e.g. SELECT * FROM (SELECT ...) AS y
    CTE = auto()  # a query defined in a WITH clause
    UNION = auto()  # one operand of an exp.Union expression
    UDTF = auto()  # a user defined tabular function source, e.g. LATERAL / UNNEST
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTES
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
    ):
        """
        Create a scope for `expression`.

        Lateral and CTE sources are merged into `sources`. Note that a non-empty `sources`
        dict passed in is updated in place (`branch` passes a copy).
        """
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed property so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # Inner scopes can see every CTE visible here, plus any newly passed ones.
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (without entering child scopes) and bucket nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(
                node.parent, (exp.From, exp.Join, exp.Subquery)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without descending into child scopes.

        Args:
            bfs (bool): breadth-first if True, depth-first otherwise.
            prune ((node) -> bool): optional callable; returning True stops descent below `node`.
        """
        # Forward the caller's prune; it used to be dropped (always None).
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.Union]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # Unqualified columns whose nearest such ancestor is QUALIFY/HAVING/a hint/a
                # star, or a non-window ORDER BY naming a projection, are skipped — presumably
                # because they may refer to SELECT aliases rather than to a source.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope use the same name.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Tables are keyed by alias-or-name, derived tables and UDTFs by alias. Derived tables
        and UDTFs are unnested unless they carry pivots.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.Union):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other node-list properties, so hints are found even
        # when this is the first property accessed after construction or clear_cache().
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """Pivots attached to any of this scope's references (tables, derived tables, UDTFs)."""
        # Check against None so an empty result is cached instead of recomputed every call.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        # NOTE: unlike add_source/remove_source, this does not clear cached properties.
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # `result` lists parents before children; reversing puts children first.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(lambda: 0)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    If `expression` is a DDL statement wrapping a query, scopes are built for the
    wrapped query instead. Any CTEs attached to the DDL are moved onto that query,
    so the input tree is mutated.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # We ignore the DDL expression and build a scope for its query instead
        ddl_with = expression.args.get("with")
        expression = expression.expression

        # If the DDL has CTEs attached, we need to add them to the query, or
        # prepend them if the query itself already has CTEs attached to it
        if ddl_with:
            ddl_with.pop()
            query_ctes = expression.ctes
            if not query_ctes:
                expression.set("with", ddl_with)
            else:
                query_with = expression.args["with"]
                # The merged WITH must stay RECURSIVE if either side declared it; copying
                # only the DDL's flag would drop a RECURSIVE declared on the query itself.
                query_with.set("recursive", ddl_with.recursive or query_with.recursive)
                query_with.set("expressions", [*ddl_with.expressions, *query_ctes])

    if isinstance(expression, exp.Query):
        return list(_traverse_scope(Scope(expression)))

    return []
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and return its root.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`, or None
        when `expression` produces no scopes at all.
    """
    all_scopes = traverse_scope(expression)
    if not all_scopes:
        return None
    return all_scopes[-1]
def _traverse_scope(scope):
    """
    Yield every scope nested in `scope`, children first, and then `scope` itself.

    Union scopes are the exception: `_traverse_union` yields the union scope itself,
    so it is not yielded again here. Unsupported expression types are logged and
    produce no scopes.
    """
    node = scope.expression

    if isinstance(node, exp.Select):
        children = _traverse_select(scope)
    elif isinstance(node, exp.Union):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return
    elif isinstance(node, exp.Subquery):
        children = _traverse_select(scope) if scope.is_root else _traverse_subqueries(scope)
    elif isinstance(node, exp.Table):
        children = _traverse_tables(scope)
    elif isinstance(node, exp.UDTF):
        children = _traverse_udtfs(scope)
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", node, type(node))
        return

    yield from children
    yield scope
def _traverse_select(scope):
    """Yield the child scopes of a SELECT: CTEs, then FROM/JOIN/LATERAL sources, then subqueries."""
    for traverse_part in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_part(scope)
def _traverse_union(scope):
    """
    Yield the scopes for every operand of the UNION rooted at `scope`.

    The union tree is walked with explicit stacks (left operand first) rather than by
    recursing per nested Union. Each nested Union operand gets its own UNION-type scope.
    Once both operands of a union have been traversed, its `union_scopes` is set to
    `[left, right]` and that union scope is yielded. `scope` itself is yielded last.
    """
    # Marks a union whose left operand has not finished yet. A private sentinel (not None)
    # keeps a degenerate operand that yields no scope from being mistaken for "pending".
    pending = object()

    # Unions whose operands are still being traversed (innermost last), and, in lockstep,
    # each one's finished left-operand scope (or `pending`).
    union_scope_stack = [scope]
    left_scope_stack = [pending]
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.Union):
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            left_scope_stack.append(pending)
            expression_stack.extend([expression.right, expression.left])
            continue

        child_scope = None
        for child_scope in _traverse_scope(new_scope):
            yield child_scope

        # Hand the finished operand to its enclosing union. If it was the right operand, that
        # union is complete and in turn becomes a finished operand of the next union up.
        # Tracking the left operand per union (instead of one shared variable) keeps
        # right-nested unions, e.g. `a UNION (b UNION c)` built without a Subquery wrapper,
        # paired with the correct parent.
        while union_scope_stack:
            if left_scope_stack[-1] is pending:
                left_scope_stack[-1] = child_scope
                break

            union_scope = union_scope_stack.pop()
            union_scope.union_scopes = [left_scope_stack.pop(), child_scope]
            yield union_scope
            child_scope = union_scope
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE defined on `scope` and register each CTE as a source.

    CTEs are processed in definition order. Each one is branched with the CTEs finished
    before it as `cte_sources`, so later CTEs can reference earlier ones. When done, all
    CTE scopes are added to both `scope.sources` and `scope.cte_sources`.
    """
    sources = {}

    # The WITH clause is the same for every CTE of this scope, so look it up once.
    with_ = scope.expression.args.get("with")
    is_recursive = bool(with_ and with_.recursive)

    for cte in scope.ctes:
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        if is_recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                # Placeholder so the recursive branch can resolve the CTE's own name; it is
                # replaced by the fully traversed scope below.
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        # append the final child_scope yielded (the CTE's outermost scope)
        if child_scope is not None:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether `expression` is a Subquery that introduces a new scope.

    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. An alias, when present, shadows all names
    under the Subquery, so an aliased Subquery always counts. Otherwise it counts only
    when it wraps an actual query.
    """
    if not isinstance(expression, exp.Subquery):
        return False
    if expression.alias:
        return True
    return isinstance(expression.this, exp.UNWRAPPED_QUERIES)
def _traverse_tables(scope):
    """
    Yield the scopes of derived tables and UDTFs used as sources by `scope`.

    Every FROM/JOIN/LATERAL source (plain tables, references to parent sources such as
    CTEs, derived tables and UDTFs) is registered in `scope.sources` once traversal ends.
    Derived-table and UDTF child scopes are also appended to `derived_table_scopes` /
    `udtf_scopes` and to `table_scopes`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    # The scope's own expression may be a Table (see the Table branch of _traverse_scope);
    # it is then registered as a source of itself.
    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` is extended while it is being iterated, so nested joins and nested
    # table constructs discovered below are visited by this same loop.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # Name already taken in this scope; find_new_name presumably derives an unused
                # name from table_name — confirm against sqlglot.helper.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs (e.g. laterals) receive the sources collected so far; branch() copies
            # the dict, so later additions are not visible to the UDTF scope.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            # Joins attached to the Subquery node belong to this (outer) scope.
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins.
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        # NOTE(review): if the traversal yields nothing, `child_scope` here is unbound or left
        # over from a previous iteration.
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope` (e.g. `WHERE x IN (SELECT ...)`).

    For each subquery, the last scope yielded (its outermost scope) is appended to
    `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        subquery_scope = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
        outermost = None
        for outermost in _traverse_scope(subquery_scope):
            yield outermost
        scope.subquery_scopes.append(outermost)
def _traverse_udtfs(scope):
    """
    Yield the scopes of derived tables nested directly inside a UDTF scope.

    Only `exp.Unnest` (its `expressions`) and `exp.Lateral` (its `this`) are inspected.
    Each derived table found gets a DERIVED_TABLE child scope. The last scope yielded for
    it is appended to `derived_table_scopes` and `table_scopes`, and is registered in
    `sources` under the derived table's alias.
    """
    udtf = scope.expression
    if isinstance(udtf, exp.Unnest):
        candidates = udtf.expressions
    elif isinstance(udtf, exp.Lateral):
        candidates = [udtf.this]
    else:
        candidates = []

    found_sources = {}
    for candidate in candidates:
        if not _is_derived_table(candidate):
            continue

        branch_scope = scope.branch(
            candidate,
            scope_type=ScopeType.DERIVED_TABLE,
            outer_columns=candidate.alias_column_names,
        )
        last = None
        for last in _traverse_scope(branch_scope):
            yield last

        # The alias is only registered when the traversal produced at least one scope.
        if last is not None:
            found_sources[candidate.alias] = last

        scope.derived_table_scopes.append(last)
        scope.table_scopes.append(last)

    scope.sources.update(found_sources)
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope are still yielded, but their subtrees are not entered.
    The exception is the joins, laterals and pivots attached to a Subquery/UDTF: those
    belong to the current scope and are walked.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the walk generator's prune callback.
    # Whenever we set it to True, we exclude a subtree from traversal. This relies on
    # `walk` consulting `prune` for a node only after we resume from yielding it.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them.
                # The caller's prune applies there too.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of rebuilding the tuple for every node.
    types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope that matches at least one of the given types.

    Child scopes are NOT searched.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True for breadth-first order, False for depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    for first_match in matches:
        return first_match
    return None