Adding upstream version 6.0.4.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
d01130b3f1
commit
527597d2af
122 changed files with 23162 additions and 0 deletions
438
sqlglot/optimizer/scope.py
Normal file
438
sqlglot/optimizer/scope.py
Normal file
|
@ -0,0 +1,438 @@
|
|||
from copy import copy
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
|
||||
|
||||
class ScopeType(Enum):
|
||||
ROOT = auto()
|
||||
SUBQUERY = auto()
|
||||
DERIVED_TABLE = auto()
|
||||
CTE = auto()
|
||||
UNION = auto()
|
||||
UNNEST = auto()
|
||||
|
||||
|
||||
class Scope:
|
||||
"""
|
||||
Selection scope.
|
||||
|
||||
Attributes:
|
||||
expression (exp.Select|exp.Union): Root expression of this scope
|
||||
sources (dict[str, exp.Table|Scope]): Mapping of source name to either
|
||||
a Table expression or another Scope instance. For example:
|
||||
SELECT * FROM x {"x": Table(this="x")}
|
||||
SELECT * FROM x AS y {"y": Table(this="x")}
|
||||
SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
|
||||
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list of it's alias of this scope, this is that list of columns.
|
||||
For example:
|
||||
SELECT * FROM (SELECT ...) AS y(col1, col2)
|
||||
The inner query would have `["col1", "col2"]` for its `outer_column_list`
|
||||
parent (Scope): Parent scope
|
||||
scope_type (ScopeType): Type of this scope, relative to it's parent
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries.
|
||||
This does not include derived tables or CTEs.
|
||||
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
|
||||
a tuple of the left and right child scopes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
sources=None,
|
||||
outer_column_list=None,
|
||||
parent=None,
|
||||
scope_type=ScopeType.ROOT,
|
||||
):
|
||||
self.expression = expression
|
||||
self.sources = sources or {}
|
||||
self.outer_column_list = outer_column_list or []
|
||||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
self.subquery_scopes = []
|
||||
self.union = None
|
||||
self.clear_cache()
|
||||
|
||||
def clear_cache(self):
|
||||
self._collected = False
|
||||
self._raw_columns = None
|
||||
self._derived_tables = None
|
||||
self._tables = None
|
||||
self._ctes = None
|
||||
self._subqueries = None
|
||||
self._selected_sources = None
|
||||
self._columns = None
|
||||
self._external_columns = None
|
||||
|
||||
def branch(self, expression, scope_type, add_sources=None, **kwargs):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
sources = copy(self.sources)
|
||||
if add_sources:
|
||||
sources.update(add_sources)
|
||||
return Scope(
|
||||
expression=expression.unnest(),
|
||||
sources=sources,
|
||||
parent=self,
|
||||
scope_type=scope_type,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _collect(self):
|
||||
self._tables = []
|
||||
self._ctes = []
|
||||
self._subqueries = []
|
||||
self._derived_tables = []
|
||||
self._raw_columns = []
|
||||
|
||||
# We'll use this variable to pass state into the dfs generator.
|
||||
# Whenever we set it to True, we exclude a subtree from traversal.
|
||||
prune = False
|
||||
|
||||
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
|
||||
prune = False
|
||||
|
||||
if node is self.expression:
|
||||
continue
|
||||
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table):
|
||||
self._tables.append(node)
|
||||
elif isinstance(node, (exp.Unnest, exp.Lateral)):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subquery) and isinstance(
|
||||
parent, (exp.From, exp.Join)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
self._subqueries.append(node)
|
||||
prune = True
|
||||
|
||||
self._collected = True
|
||||
|
||||
def _ensure_collected(self):
|
||||
if not self._collected:
|
||||
self._collect()
|
||||
|
||||
def replace(self, old, new):
|
||||
"""
|
||||
Replace `old` with `new`.
|
||||
|
||||
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
|
||||
|
||||
Args:
|
||||
old (exp.Expression): old node
|
||||
new (exp.Expression): new node
|
||||
"""
|
||||
old.replace(new)
|
||||
self.clear_cache()
|
||||
|
||||
@property
|
||||
def tables(self):
|
||||
"""
|
||||
List of tables in this scope.
|
||||
|
||||
Returns:
|
||||
list[exp.Table]: tables
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._tables
|
||||
|
||||
@property
|
||||
def ctes(self):
|
||||
"""
|
||||
List of CTEs in this scope.
|
||||
|
||||
Returns:
|
||||
list[exp.CTE]: ctes
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._ctes
|
||||
|
||||
@property
|
||||
def derived_tables(self):
|
||||
"""
|
||||
List of derived tables in this scope.
|
||||
|
||||
For example:
|
||||
SELECT * FROM (SELECT ...) <- that's a derived table
|
||||
|
||||
Returns:
|
||||
list[exp.Subquery]: derived tables
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._derived_tables
|
||||
|
||||
@property
|
||||
def subqueries(self):
|
||||
"""
|
||||
List of subqueries in this scope.
|
||||
|
||||
For example:
|
||||
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
|
||||
|
||||
Returns:
|
||||
list[exp.Subqueryable]: subqueries
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._subqueries
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
"""
|
||||
List of columns in this scope.
|
||||
|
||||
Returns:
|
||||
list[exp.Column]: Column instances in this scope, plus any
|
||||
Columns that reference this scope from correlated subqueries.
|
||||
"""
|
||||
if self._columns is None:
|
||||
self._ensure_collected()
|
||||
columns = self._raw_columns
|
||||
|
||||
external_columns = [
|
||||
column
|
||||
for scope in self.subquery_scopes
|
||||
for column in scope.external_columns
|
||||
]
|
||||
|
||||
named_outputs = {e.alias_or_name for e in self.expression.expressions}
|
||||
|
||||
self._columns = [
|
||||
c
|
||||
for c in columns + external_columns
|
||||
if not (
|
||||
c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
|
||||
)
|
||||
]
|
||||
return self._columns
|
||||
|
||||
@property
|
||||
def selected_sources(self):
|
||||
"""
|
||||
Mapping of nodes and sources that are actually selected from in this scope.
|
||||
|
||||
That is, all tables in a schema are selectable at any point. But a
|
||||
table only becomes a selected source if it's included in a FROM or JOIN clause.
|
||||
|
||||
Returns:
|
||||
dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
|
||||
"""
|
||||
if self._selected_sources is None:
|
||||
referenced_names = []
|
||||
|
||||
for table in self.tables:
|
||||
referenced_names.append(
|
||||
(
|
||||
table.parent.alias
|
||||
if isinstance(table.parent, exp.Alias)
|
||||
else table.name,
|
||||
table,
|
||||
)
|
||||
)
|
||||
for derived_table in self.derived_tables:
|
||||
referenced_names.append((derived_table.alias, derived_table.unnest()))
|
||||
|
||||
result = {}
|
||||
|
||||
for name, node in referenced_names:
|
||||
if name in self.sources:
|
||||
result[name] = (node, self.sources[name])
|
||||
|
||||
self._selected_sources = result
|
||||
return self._selected_sources
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
"""
|
||||
Select expressions of this scope.
|
||||
|
||||
For example, for the following expression:
|
||||
SELECT 1 as a, 2 as b FROM x
|
||||
|
||||
The outputs are the "1 as a" and "2 as b" expressions.
|
||||
|
||||
Returns:
|
||||
list[exp.Expression]: expressions
|
||||
"""
|
||||
if isinstance(self.expression, exp.Union):
|
||||
return []
|
||||
return self.expression.selects
|
||||
|
||||
@property
|
||||
def external_columns(self):
|
||||
"""
|
||||
Columns that appear to reference sources in outer scopes.
|
||||
|
||||
Returns:
|
||||
list[exp.Column]: Column instances that don't reference
|
||||
sources in the current scope.
|
||||
"""
|
||||
if self._external_columns is None:
|
||||
self._external_columns = [
|
||||
c for c in self.columns if c.table not in self.selected_sources
|
||||
]
|
||||
return self._external_columns
|
||||
|
||||
def source_columns(self, source_name):
|
||||
"""
|
||||
Get all columns in the current scope for a particular source.
|
||||
|
||||
Args:
|
||||
source_name (str): Name of the source
|
||||
Returns:
|
||||
list[exp.Column]: Column instances that reference `source_name`
|
||||
"""
|
||||
return [column for column in self.columns if column.table == source_name]
|
||||
|
||||
@property
|
||||
def is_subquery(self):
|
||||
"""Determine if this scope is a subquery"""
|
||||
return self.scope_type == ScopeType.SUBQUERY
|
||||
|
||||
@property
|
||||
def is_unnest(self):
|
||||
"""Determine if this scope is an unnest"""
|
||||
return self.scope_type == ScopeType.UNNEST
|
||||
|
||||
@property
|
||||
def is_correlated_subquery(self):
|
||||
"""Determine if this scope is a correlated subquery"""
|
||||
return bool(self.is_subquery and self.external_columns)
|
||||
|
||||
def rename_source(self, old_name, new_name):
|
||||
"""Rename a source in this scope"""
|
||||
columns = self.sources.pop(old_name or "", [])
|
||||
self.sources[new_name] = columns
|
||||
|
||||
|
||||
def traverse_scope(expression):
|
||||
"""
|
||||
Traverse an expression by it's "scopes".
|
||||
|
||||
"Scope" represents the current context of a Select statement.
|
||||
|
||||
This is helpful for optimizing queries, where we need more information than
|
||||
the expression tree itself. For example, we might care about the source
|
||||
names within a subquery. Returns a list because a generator could result in
|
||||
incomplete properties which is confusing.
|
||||
|
||||
Examples:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
|
||||
>>> scopes = traverse_scope(expression)
|
||||
>>> scopes[0].expression.sql(), list(scopes[0].sources)
|
||||
('SELECT a FROM x', ['x'])
|
||||
>>> scopes[1].expression.sql(), list(scopes[1].sources)
|
||||
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
|
||||
|
||||
Args:
|
||||
expression (exp.Expression): expression to traverse
|
||||
Returns:
|
||||
List[Scope]: scope instances
|
||||
"""
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
|
||||
def _traverse_scope(scope):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
yield from _traverse_select(scope)
|
||||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
|
||||
pass
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
else:
|
||||
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
|
||||
yield scope
|
||||
|
||||
|
||||
def _traverse_select(scope):
|
||||
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
|
||||
yield from _traverse_subqueries(scope)
|
||||
yield from _traverse_derived_tables(
|
||||
scope.derived_tables, scope, ScopeType.DERIVED_TABLE
|
||||
)
|
||||
_add_table_sources(scope)
|
||||
|
||||
|
||||
def _traverse_union(scope):
|
||||
yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
|
||||
|
||||
# The last scope to be yield should be the top most scope
|
||||
left = None
|
||||
for left in _traverse_scope(
|
||||
scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
|
||||
):
|
||||
yield left
|
||||
|
||||
right = None
|
||||
for right in _traverse_scope(
|
||||
scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
|
||||
):
|
||||
yield right
|
||||
|
||||
scope.union = (left, right)
|
||||
|
||||
|
||||
def _traverse_derived_tables(derived_tables, scope, scope_type):
|
||||
sources = {}
|
||||
|
||||
for derived_table in derived_tables:
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
derived_table
|
||||
if isinstance(derived_table, (exp.Unnest, exp.Lateral))
|
||||
else derived_table.this,
|
||||
add_sources=sources if scope_type == ScopeType.CTE else None,
|
||||
outer_column_list=derived_table.alias_column_names,
|
||||
scope_type=ScopeType.UNNEST
|
||||
if isinstance(derived_table, exp.Unnest)
|
||||
else scope_type,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
# Tables without aliases will be set as ""
|
||||
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
|
||||
# Until then, this means that only a single, unaliased derived table is allowed (rather,
|
||||
# the latest one wins.
|
||||
sources[derived_table.alias] = child_scope
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _add_table_sources(scope):
|
||||
sources = {}
|
||||
for table in scope.tables:
|
||||
table_name = table.name
|
||||
|
||||
if isinstance(table.parent, exp.Alias):
|
||||
source_name = table.parent.alias
|
||||
else:
|
||||
source_name = table_name
|
||||
|
||||
if table_name in scope.sources:
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table.
|
||||
scope.sources[source_name] = scope.sources[table_name]
|
||||
elif source_name in scope.sources:
|
||||
raise OptimizeError(f"Duplicate table name: {source_name}")
|
||||
else:
|
||||
sources[source_name] = table
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _traverse_subqueries(scope):
|
||||
for subquery in scope.subqueries:
|
||||
top = None
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
|
||||
):
|
||||
yield child_scope
|
||||
top = child_scope
|
||||
scope.subquery_scopes.append(top)
|
Loading…
Add table
Add a link
Reference in a new issue