1
0
Fork 0
sqlglot/sqlglot/optimizer/scope.py
Daniel Baumann 66e2d714bf
Merging upstream version 6.2.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 14:40:43 +01:00

485 lines
15 KiB
Python

import itertools
from copy import copy
from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
class ScopeType(Enum):
    """Kind of a scope, describing how it relates to its parent scope."""

    # Top-level scope; the default type given to a newly constructed Scope.
    ROOT = auto()
    # Scope of a subquery nested inside another query.
    SUBQUERY = auto()
    # Scope of a derived table, e.g. `FROM (SELECT ...)`.
    DERIVED_TABLE = auto()
    # Scope of a common table expression.
    CTE = auto()
    # Scope of the left or right operand of a union.
    UNION = auto()
    # Scope of an UNNEST expression.
    UNNEST = auto()
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
    ):
        """
        Args:
            expression (exp.Expression): root expression of this scope
            sources (dict): mapping of source name to source; defaults to an empty dict
            outer_column_list (list[str]): column list the outer query gives this scope's
                alias; defaults to an empty list
            parent (Scope): parent scope, if any
            scope_type (ScopeType): type of this scope, relative to its parent
        """
        self.expression = expression
        self.sources = sources or {}
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so that it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None

    def branch(self, expression, scope_type, add_sources=None, **kwargs):
        """
        Branch from the current scope to a new, inner scope.

        The new scope starts with a shallow copy of this scope's sources, so
        later changes to either mapping do not affect the other.

        Args:
            expression (exp.Expression): root expression of the inner scope;
                `expression.unnest()` is what gets stored on the new scope
            scope_type (ScopeType): type of the inner scope
            add_sources (dict): extra sources to make visible to the inner scope
            **kwargs: forwarded to the `Scope` constructor

        Returns:
            Scope: the new scope, with this scope as its parent
        """
        sources = copy(self.sources)
        if add_sources:
            sources.update(add_sources)
        return Scope(
            expression=expression.unnest(),
            sources=sources,
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope's expression once and bucket the nodes that belong to it."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._raw_columns = []

        # We'll use this variable to pass state into the dfs generator.
        # Whenever we set it to True, we exclude a subtree from traversal.
        prune = False

        for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
            prune = False

            if node is self.expression:
                continue
            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table):
                self._tables.append(node)
            elif isinstance(node, (exp.Unnest, exp.Lateral)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.CTE):
                # CTEs, derived tables and subqueries are recorded but not descended into.
                self._ctes.append(node)
                prune = True
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
                prune = True
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)
                prune = True

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless its results are already cached."""
        if not self._collected:
            self._collect()

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]

            named_outputs = {e.alias_or_name for e in self.expression.expressions}

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
                # A column below a QUALIFY, ORDER or hint node is kept only when it is
                # table-qualified, or when it is outside a hint and its name does not
                # match one of this scope's select outputs.
                if (
                    not ancestor
                    or column.table
                    or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)
        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append(
                    (
                        table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
                        table,
                    )
                )
            for derived_table in self.derived_tables:
                referenced_names.append((derived_table.alias, derived_table.unnest()))

            result = {}

            # Only names that resolve to a known source are reported.
            for name, node in referenced_names:
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return []
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
        return self._external_columns

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source

        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_unnest(self):
        """Determine if this scope is an unnest"""
        return self.scope_type == ScopeType.UNNEST

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A falsy `old_name` is looked up as "". If no source is registered under
        the old name, `new_name` is mapped to an empty list.

        Args:
            old_name (str): current name of the source
            new_name (str): name to register the source under
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # `selected_sources` and `external_columns` are computed from `self.sources`,
        # so cached values must be dropped, as `add_source` and `remove_source` do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
        ):
            yield from child_scope.traverse()
        yield self
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse

    Returns:
        list[Scope]: scope instances
    """
    root_scope = Scope(expression)
    # Drain the generator eagerly so every scope is fully populated on return.
    return list(_traverse_scope(root_scope))
def build_scope(expression):
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for

    Returns:
        Scope: root scope
    """
    scopes = traverse_scope(expression)
    # The traversal emits the outermost scope last.
    return scopes[-1]
def _traverse_scope(scope):
    """
    Yield the scopes nested inside `scope`, then `scope` itself.

    Raises:
        OptimizeError: if the scope's expression is not a Select, Union,
            Lateral, Unnest or Subquery.
    """
    expression = scope.expression

    # Pick the generator that walks this kind of expression; nothing runs
    # until it is iterated below.
    if isinstance(expression, exp.Select):
        nested = _traverse_select(scope)
    elif isinstance(expression, exp.Union):
        nested = _traverse_union(scope)
    elif isinstance(expression, (exp.Lateral, exp.Unnest)):
        # These are not descended into.
        nested = ()
    elif isinstance(expression, exp.Subquery):
        nested = _traverse_subqueries(scope)
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")

    yield from nested
    yield scope
def _traverse_select(scope):
    """Yield the child scopes of a SELECT scope, then register its table sources."""
    # CTEs come first, then subqueries, then derived tables.
    for cte_scope in _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE):
        yield cte_scope

    for subquery_scope in _traverse_subqueries(scope):
        yield subquery_scope

    for derived_scope in _traverse_derived_tables(
        scope.derived_tables, scope, scope_type=ScopeType.DERIVED_TABLE
    ):
        yield derived_scope

    _add_table_sources(scope)
def _traverse_union(scope):
    """Yield the child scopes of a UNION scope and record its two operand scopes."""
    yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)

    # The last scope to be yielded for each operand is that operand's top-most scope.
    left_top = None
    left_branch = scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
    for left_top in _traverse_scope(left_branch):
        yield left_top

    right_top = None
    right_branch = scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
    for right_top in _traverse_scope(right_branch):
        yield right_top

    scope.union_scopes = [left_top, right_top]
def _traverse_derived_tables(derived_tables, scope, scope_type):
    """
    Yield the scopes of each derived table (or CTE) and register them on `scope`.

    Args:
        derived_tables (list[exp.Expression]): CTE, subquery, UNNEST or LATERAL nodes
        scope (Scope): scope that owns the derived tables
        scope_type (ScopeType): type given to the child scopes (UNNEST nodes
            always get `ScopeType.UNNEST`)
    """
    is_cte = scope_type == ScopeType.CTE
    new_sources = {}

    for derived_table in derived_tables:
        # UNNEST and LATERAL are scoped as-is; everything else is scoped by its inner expression.
        if isinstance(derived_table, (exp.Unnest, exp.Lateral)):
            inner_expression = derived_table
        else:
            inner_expression = derived_table.this

        branch = scope.branch(
            inner_expression,
            # CTEs defined so far are passed into the scopes of the CTEs that follow.
            add_sources=new_sources if is_cte else None,
            outer_column_list=derived_table.alias_column_names,
            scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
        )

        # The last scope yielded is the top-most scope of this derived table.
        top = None
        for top in _traverse_scope(branch):
            yield top

        # Tables without aliases will be set as "".
        # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
        # Until then, this means that only a single, unaliased derived table is allowed (rather,
        # the latest one wins).
        new_sources[derived_table.alias] = top

        child_scopes = scope.cte_scopes if is_cte else scope.derived_table_scopes
        child_scopes.append(top)

    scope.sources.update(new_sources)
def _add_table_sources(scope):
    """Register every table referenced in `scope` as a source, keyed by alias or name."""
    real_tables = {}

    for table in scope.tables:
        name = table.name
        parent = table.parent
        key = parent.alias if isinstance(parent, exp.Alias) else name

        if name not in scope.sources:
            real_tables[key] = table
        else:
            # This is a reference to a parent source (e.g. a CTE), not an actual table.
            scope.sources[key] = scope.sources[name]

    scope.sources.update(real_tables)
def _traverse_subqueries(scope):
    """Yield the scopes of each subquery in `scope` and record each subquery's top-most scope."""
    for subquery in scope.subqueries:
        branch = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        # The last scope yielded is the subquery's own (top-most) scope.
        top = None
        for top in _traverse_scope(branch):
            yield top

        scope.subquery_scopes.append(top)