Merging upstream version 11.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
8c1c1864c5
commit
fb546b57e5
95 changed files with 32569 additions and 30081 deletions
|
@ -280,6 +280,9 @@ class TypeAnnotator:
|
|||
}
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
if not col.table:
|
||||
continue
|
||||
|
||||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
col.type = self.schema.get_column_type(source, col)
|
||||
|
|
|
@ -81,9 +81,7 @@ def eliminate_subqueries(expression):
|
|||
new_ctes.append(cte_scope.expression.parent)
|
||||
|
||||
# Now append the rest
|
||||
for scope in itertools.chain(
|
||||
root.union_scopes, root.subquery_scopes, root.derived_table_scopes
|
||||
):
|
||||
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
|
||||
for child_scope in scope.traverse():
|
||||
new_cte = _eliminate(child_scope, existing_ctes, taken)
|
||||
if new_cte:
|
||||
|
@ -99,7 +97,7 @@ def _eliminate(scope, existing_ctes, taken):
|
|||
if scope.is_union:
|
||||
return _eliminate_union(scope, existing_ctes, taken)
|
||||
|
||||
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
|
||||
if scope.is_derived_table:
|
||||
return _eliminate_derived_table(scope, existing_ctes, taken)
|
||||
|
||||
if scope.is_cte:
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import Schema, exp
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.canonicalize import canonicalize
|
||||
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
|
||||
|
@ -24,8 +30,8 @@ RULES = (
|
|||
isolate_table_selects,
|
||||
qualify_columns,
|
||||
expand_laterals,
|
||||
validate_qualify_columns,
|
||||
pushdown_projections,
|
||||
validate_qualify_columns,
|
||||
normalize,
|
||||
unnest_subqueries,
|
||||
expand_multi_table_selects,
|
||||
|
@ -40,22 +46,31 @@ RULES = (
|
|||
)
|
||||
|
||||
|
||||
def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
|
||||
def optimize(
|
||||
expression: str | exp.Expression,
|
||||
schema: t.Optional[dict | Schema] = None,
|
||||
db: t.Optional[str] = None,
|
||||
catalog: t.Optional[str] = None,
|
||||
dialect: DialectType = None,
|
||||
rules: t.Sequence[t.Callable] = RULES,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Rewrite a sqlglot AST into an optimized form.
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to optimize
|
||||
schema (dict|sqlglot.optimizer.Schema): database schema.
|
||||
expression: expression to optimize
|
||||
schema: database schema.
|
||||
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
|
||||
the following forms:
|
||||
1. {table: {col: type}}
|
||||
2. {db: {table: {col: type}}}
|
||||
3. {catalog: {db: {table: {col: type}}}}
|
||||
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
|
||||
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
|
||||
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
|
||||
rules (sequence): sequence of optimizer rules to use.
|
||||
db: specify the default database, as might be set by a `USE DATABASE db` statement
|
||||
catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement
|
||||
dialect: The dialect to parse the sql string.
|
||||
rules: sequence of optimizer rules to use.
|
||||
Many of the rules require tables and columns to be qualified.
|
||||
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
|
||||
what you're doing!
|
||||
|
@ -65,7 +80,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
|
|||
"""
|
||||
schema = ensure_schema(schema or sqlglot.schema)
|
||||
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
|
||||
expression = expression.copy()
|
||||
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
|
||||
for rule in rules:
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
rule_params = rule.__code__.co_varnames
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.helper import flatten
|
||||
from sqlglot.optimizer.qualify_columns import Resolver
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
# Sentinel value that means an outer query selecting ALL columns
|
||||
SELECT_ALL = object()
|
||||
|
@ -10,7 +13,7 @@ SELECT_ALL = object()
|
|||
DEFAULT_SELECTION = lambda: alias("1", "_")
|
||||
|
||||
|
||||
def pushdown_projections(expression):
|
||||
def pushdown_projections(expression, schema=None):
|
||||
"""
|
||||
Rewrite sqlglot AST to remove unused columns projections.
|
||||
|
||||
|
@ -27,9 +30,9 @@ def pushdown_projections(expression):
|
|||
sqlglot.Expression: optimized expression
|
||||
"""
|
||||
# Map of Scope to all columns being selected by outer queries.
|
||||
schema = ensure_schema(schema)
|
||||
referenced_columns = defaultdict(set)
|
||||
left_union = None
|
||||
right_union = None
|
||||
|
||||
# We build the scope tree (which is traversed in DFS postorder), then iterate
|
||||
# over the result in reverse order. This should ensure that the set of selected
|
||||
# columns for a particular scope are completely build by the time we get to it.
|
||||
|
@ -41,16 +44,20 @@ def pushdown_projections(expression):
|
|||
parent_selections = {SELECT_ALL}
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
left_union, right_union = scope.union_scopes
|
||||
referenced_columns[left_union] = parent_selections
|
||||
referenced_columns[right_union] = parent_selections
|
||||
left, right = scope.union_scopes
|
||||
referenced_columns[left] = parent_selections
|
||||
|
||||
if isinstance(scope.expression, exp.Select) and scope != right_union:
|
||||
removed_indexes = _remove_unused_selections(scope, parent_selections)
|
||||
# The left union is used for column names to select and if we remove columns from the left
|
||||
# we need to also remove those same columns in the right that were at the same position
|
||||
if scope is left_union:
|
||||
_remove_indexed_selections(right_union, removed_indexes)
|
||||
if any(select.is_star for select in right.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.selects):
|
||||
referenced_columns[right] = [
|
||||
right.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.selects)
|
||||
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
_remove_unused_selections(scope, parent_selections, schema)
|
||||
|
||||
# Group columns by source name
|
||||
selects = defaultdict(set)
|
||||
|
@ -68,8 +75,7 @@ def pushdown_projections(expression):
|
|||
return expression
|
||||
|
||||
|
||||
def _remove_unused_selections(scope, parent_selections):
|
||||
removed_indexes = []
|
||||
def _remove_unused_selections(scope, parent_selections, schema):
|
||||
order = scope.expression.args.get("order")
|
||||
|
||||
if order:
|
||||
|
@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections):
|
|||
else:
|
||||
order_refs = set()
|
||||
|
||||
new_selections = []
|
||||
new_selections = defaultdict(list)
|
||||
removed = False
|
||||
for i, selection in enumerate(scope.selects):
|
||||
if (
|
||||
SELECT_ALL in parent_selections
|
||||
or selection.alias_or_name in parent_selections
|
||||
or selection.alias_or_name in order_refs
|
||||
):
|
||||
new_selections.append(selection)
|
||||
star = False
|
||||
for selection in scope.selects:
|
||||
name = selection.alias_or_name
|
||||
|
||||
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
|
||||
new_selections[name].append(selection)
|
||||
else:
|
||||
removed_indexes.append(i)
|
||||
if selection.is_star:
|
||||
star = True
|
||||
removed = True
|
||||
|
||||
if star:
|
||||
resolver = Resolver(scope, schema)
|
||||
|
||||
for name in sorted(parent_selections):
|
||||
if name not in new_selections:
|
||||
new_selections[name].append(
|
||||
alias(exp.column(name, table=resolver.get_table(name)), name)
|
||||
)
|
||||
|
||||
# If there are no remaining selections, just select a single constant
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION())
|
||||
new_selections[""].append(DEFAULT_SELECTION())
|
||||
|
||||
scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
if removed:
|
||||
scope.clear_cache()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _remove_indexed_selections(scope, indexes_to_remove):
|
||||
new_selections = [
|
||||
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
|
||||
]
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION())
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
|
|
@ -27,17 +27,16 @@ def qualify_columns(expression, schema):
|
|||
schema = ensure_schema(schema)
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
resolver = _Resolver(scope, schema)
|
||||
resolver = Resolver(scope, schema)
|
||||
_pop_table_column_aliases(scope.ctes)
|
||||
_pop_table_column_aliases(scope.derived_tables)
|
||||
_expand_using(scope, resolver)
|
||||
_expand_group_by(scope, resolver)
|
||||
_qualify_columns(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver)
|
||||
_qualify_outputs(scope)
|
||||
|
||||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -48,7 +47,8 @@ def validate_qualify_columns(expression):
|
|||
if isinstance(scope.expression, exp.Select):
|
||||
unqualified_columns.extend(scope.unqualified_columns)
|
||||
if scope.external_columns and not scope.is_correlated_subquery:
|
||||
raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
|
||||
column = scope.external_columns[0]
|
||||
raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
|
||||
|
||||
if unqualified_columns:
|
||||
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
|
||||
|
@ -62,8 +62,6 @@ def _pop_table_column_aliases(derived_tables):
|
|||
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
||||
"""
|
||||
for derived_table in derived_tables:
|
||||
if isinstance(derived_table.unnest(), exp.UDTF):
|
||||
continue
|
||||
table_alias = derived_table.args.get("alias")
|
||||
if table_alias:
|
||||
table_alias.args.pop("columns", None)
|
||||
|
@ -206,7 +204,7 @@ def _qualify_columns(scope, resolver):
|
|||
|
||||
if column_table and column_table in scope.sources:
|
||||
source_columns = resolver.get_source_columns(column_table)
|
||||
if source_columns and column_name not in source_columns:
|
||||
if source_columns and column_name not in source_columns and "*" not in source_columns:
|
||||
raise OptimizeError(f"Unknown column: {column_name}")
|
||||
|
||||
if not column_table:
|
||||
|
@ -256,7 +254,7 @@ def _expand_stars(scope, resolver):
|
|||
tables = list(scope.selected_sources)
|
||||
_add_except_columns(expression, tables, except_columns)
|
||||
_add_replace_columns(expression, tables, replace_columns)
|
||||
elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
|
||||
elif expression.is_star:
|
||||
tables = [expression.table]
|
||||
_add_except_columns(expression.this, tables, except_columns)
|
||||
_add_replace_columns(expression.this, tables, replace_columns)
|
||||
|
@ -268,17 +266,16 @@ def _expand_stars(scope, resolver):
|
|||
if table not in scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {table}")
|
||||
columns = resolver.get_source_columns(table, only_visible=True)
|
||||
if not columns:
|
||||
raise OptimizeError(
|
||||
f"Table has no schema/columns. Cannot expand star for table: {table}."
|
||||
)
|
||||
table_id = id(table)
|
||||
for name in columns:
|
||||
if name not in except_columns.get(table_id, set()):
|
||||
alias_ = replace_columns.get(table_id, {}).get(name, name)
|
||||
column = exp.column(name, table)
|
||||
new_selections.append(alias(column, alias_) if alias_ != name else column)
|
||||
|
||||
if columns and "*" not in columns:
|
||||
table_id = id(table)
|
||||
for name in columns:
|
||||
if name not in except_columns.get(table_id, set()):
|
||||
alias_ = replace_columns.get(table_id, {}).get(name, name)
|
||||
column = exp.column(name, table)
|
||||
new_selections.append(alias(column, alias_) if alias_ != name else column)
|
||||
else:
|
||||
return
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
|
@ -316,7 +313,7 @@ def _qualify_outputs(scope):
|
|||
if isinstance(selection, exp.Subquery):
|
||||
if not selection.output_name:
|
||||
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
|
||||
elif not isinstance(selection, exp.Alias):
|
||||
elif not isinstance(selection, exp.Alias) and not selection.is_star:
|
||||
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
|
||||
alias_.set("this", selection)
|
||||
selection = alias_
|
||||
|
@ -329,7 +326,7 @@ def _qualify_outputs(scope):
|
|||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
class _Resolver:
|
||||
class Resolver:
|
||||
"""
|
||||
Helper for resolving columns.
|
||||
|
||||
|
@ -361,7 +358,9 @@ class _Resolver:
|
|||
|
||||
if not table:
|
||||
sources_without_schema = tuple(
|
||||
source for source, columns in self._get_all_source_columns().items() if not columns
|
||||
source
|
||||
for source, columns in self._get_all_source_columns().items()
|
||||
if not columns or "*" in columns
|
||||
)
|
||||
if len(sources_without_schema) == 1:
|
||||
return sources_without_schema[0]
|
||||
|
@ -397,7 +396,8 @@ class _Resolver:
|
|||
def _get_all_source_columns(self):
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {
|
||||
k: self.get_source_columns(k) for k in self.scope.selected_sources
|
||||
k: self.get_source_columns(k)
|
||||
for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
|
||||
}
|
||||
return self._source_columns
|
||||
|
||||
|
@ -436,7 +436,7 @@ class _Resolver:
|
|||
Find the unique columns in a list of columns.
|
||||
|
||||
Example:
|
||||
>>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
|
||||
>>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
|
||||
['a', 'c']
|
||||
|
||||
This is necessary because duplicate column names are ambiguous.
|
||||
|
|
|
@ -28,7 +28,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
next_name = lambda: f"_q_{next(sequence)}"
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
for derived_table in scope.ctes + scope.derived_tables:
|
||||
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
|
||||
if not derived_table.args.get("alias"):
|
||||
alias_ = f"_q_{next(sequence)}"
|
||||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
|
|
|
@ -26,6 +26,10 @@ class Scope:
|
|||
SELECT * FROM x {"x": Table(this="x")}
|
||||
SELECT * FROM x AS y {"y": Table(this="x")}
|
||||
SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
|
||||
lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
|
||||
For example:
|
||||
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
|
||||
The LATERAL VIEW EXPLODE gets x as a source.
|
||||
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list of it's alias of this scope, this is that list of columns.
|
||||
For example:
|
||||
|
@ -34,8 +38,10 @@ class Scope:
|
|||
parent (Scope): Parent scope
|
||||
scope_type (ScopeType): Type of this scope, relative to it's parent
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries
|
||||
cte_scopes = (list[Scope]) List of all child scopes for CTEs
|
||||
derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables
|
||||
cte_scopes (list[Scope]): List of all child scopes for CTEs
|
||||
derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
|
||||
udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
|
||||
table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
|
||||
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
|
||||
a list of the left and right child scopes.
|
||||
"""
|
||||
|
@ -47,22 +53,28 @@ class Scope:
|
|||
outer_column_list=None,
|
||||
parent=None,
|
||||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
):
|
||||
self.expression = expression
|
||||
self.sources = sources or {}
|
||||
self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
|
||||
self.sources.update(self.lateral_sources)
|
||||
self.outer_column_list = outer_column_list or []
|
||||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
self.subquery_scopes = []
|
||||
self.derived_table_scopes = []
|
||||
self.table_scopes = []
|
||||
self.cte_scopes = []
|
||||
self.union_scopes = []
|
||||
self.udtf_scopes = []
|
||||
self.clear_cache()
|
||||
|
||||
def clear_cache(self):
|
||||
self._collected = False
|
||||
self._raw_columns = None
|
||||
self._derived_tables = None
|
||||
self._udtfs = None
|
||||
self._tables = None
|
||||
self._ctes = None
|
||||
self._subqueries = None
|
||||
|
@ -86,6 +98,7 @@ class Scope:
|
|||
self._ctes = []
|
||||
self._subqueries = []
|
||||
self._derived_tables = []
|
||||
self._udtfs = []
|
||||
self._raw_columns = []
|
||||
self._join_hints = []
|
||||
|
||||
|
@ -99,7 +112,7 @@ class Scope:
|
|||
elif isinstance(node, exp.JoinHint):
|
||||
self._join_hints.append(node)
|
||||
elif isinstance(node, exp.UDTF):
|
||||
self._derived_tables.append(node)
|
||||
self._udtfs.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
|
@ -199,6 +212,17 @@ class Scope:
|
|||
self._ensure_collected()
|
||||
return self._derived_tables
|
||||
|
||||
@property
|
||||
def udtfs(self):
|
||||
"""
|
||||
List of "User Defined Tabular Functions" in this scope.
|
||||
|
||||
Returns:
|
||||
list[exp.UDTF]: UDTFs
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._udtfs
|
||||
|
||||
@property
|
||||
def subqueries(self):
|
||||
"""
|
||||
|
@ -227,7 +251,9 @@ class Scope:
|
|||
columns = self._raw_columns
|
||||
|
||||
external_columns = [
|
||||
column for scope in self.subquery_scopes for column in scope.external_columns
|
||||
column
|
||||
for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
|
||||
for column in scope.external_columns
|
||||
]
|
||||
|
||||
named_selects = set(self.expression.named_selects)
|
||||
|
@ -262,9 +288,8 @@ class Scope:
|
|||
|
||||
for table in self.tables:
|
||||
referenced_names.append((table.alias_or_name, table))
|
||||
for derived_table in self.derived_tables:
|
||||
referenced_names.append((derived_table.alias, derived_table.unnest()))
|
||||
|
||||
for expression in itertools.chain(self.derived_tables, self.udtfs):
|
||||
referenced_names.append((expression.alias, expression.unnest()))
|
||||
result = {}
|
||||
|
||||
for name, node in referenced_names:
|
||||
|
@ -414,7 +439,7 @@ class Scope:
|
|||
Scope: scope instances in depth-first-search post-order
|
||||
"""
|
||||
for child_scope in itertools.chain(
|
||||
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
|
||||
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
|
||||
):
|
||||
yield from child_scope.traverse()
|
||||
yield self
|
||||
|
@ -480,24 +505,23 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_select(scope)
|
||||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
_set_udtf_scope(scope)
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
else:
|
||||
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
|
||||
yield scope
|
||||
|
||||
|
||||
def _traverse_select(scope):
|
||||
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
|
||||
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
|
||||
yield from _traverse_ctes(scope)
|
||||
yield from _traverse_tables(scope)
|
||||
yield from _traverse_subqueries(scope)
|
||||
_add_table_sources(scope)
|
||||
|
||||
|
||||
def _traverse_union(scope):
|
||||
yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
|
||||
yield from _traverse_ctes(scope)
|
||||
|
||||
# The last scope to be yield should be the top most scope
|
||||
left = None
|
||||
|
@ -511,44 +535,84 @@ def _traverse_union(scope):
|
|||
scope.union_scopes = [left, right]
|
||||
|
||||
|
||||
def _set_udtf_scope(scope):
|
||||
parent = scope.expression.parent
|
||||
from_ = parent.args.get("from")
|
||||
|
||||
if not from_:
|
||||
return
|
||||
|
||||
for table in from_.expressions:
|
||||
if isinstance(table, exp.Table):
|
||||
scope.tables.append(table)
|
||||
elif isinstance(table, exp.Subquery):
|
||||
scope.subqueries.append(table)
|
||||
_add_table_sources(scope)
|
||||
_traverse_subqueries(scope)
|
||||
|
||||
|
||||
def _traverse_derived_tables(derived_tables, scope, scope_type):
|
||||
def _traverse_ctes(scope):
|
||||
sources = {}
|
||||
is_cte = scope_type == ScopeType.CTE
|
||||
|
||||
for derived_table in derived_tables:
|
||||
for cte in scope.ctes:
|
||||
recursive_scope = None
|
||||
|
||||
# if the scope is a recursive cte, it must be in the form of
|
||||
# base_case UNION recursive. thus the recursive scope is the first
|
||||
# section of the union.
|
||||
if is_cte and scope.expression.args["with"].recursive:
|
||||
union = derived_table.this
|
||||
if scope.expression.args["with"].recursive:
|
||||
union = cte.this
|
||||
|
||||
if isinstance(union, exp.Union):
|
||||
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
|
||||
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
|
||||
chain_sources=sources if scope_type == ScopeType.CTE else None,
|
||||
outer_column_list=derived_table.alias_column_names,
|
||||
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
|
||||
cte.this,
|
||||
chain_sources=sources,
|
||||
outer_column_list=cte.alias_column_names,
|
||||
scope_type=ScopeType.CTE,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
|
||||
alias = cte.alias
|
||||
sources[alias] = child_scope
|
||||
|
||||
if recursive_scope:
|
||||
child_scope.add_source(alias, recursive_scope)
|
||||
|
||||
# append the final child_scope yielded
|
||||
scope.cte_scopes.append(child_scope)
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
sources = {}
|
||||
|
||||
# Traverse FROMs, JOINs, and LATERALs in the order they are defined
|
||||
expressions = []
|
||||
from_ = scope.expression.args.get("from")
|
||||
if from_:
|
||||
expressions.extend(from_.expressions)
|
||||
|
||||
for join in scope.expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
|
||||
expressions.extend(scope.expression.args.get("laterals") or [])
|
||||
|
||||
for expression in expressions:
|
||||
if isinstance(expression, exp.Table):
|
||||
table_name = expression.name
|
||||
source_name = expression.alias_or_name
|
||||
|
||||
if table_name in scope.sources:
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table.
|
||||
sources[source_name] = scope.sources[table_name]
|
||||
else:
|
||||
sources[source_name] = expression
|
||||
continue
|
||||
|
||||
if isinstance(expression, exp.UDTF):
|
||||
lateral_sources = sources
|
||||
scope_type = ScopeType.UDTF
|
||||
scopes = scope.udtf_scopes
|
||||
else:
|
||||
lateral_sources = None
|
||||
scope_type = ScopeType.DERIVED_TABLE
|
||||
scopes = scope.derived_table_scopes
|
||||
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
expression,
|
||||
lateral_sources=lateral_sources,
|
||||
outer_column_list=expression.alias_column_names,
|
||||
scope_type=scope_type,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
|
@ -557,36 +621,12 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
|
|||
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
|
||||
# Until then, this means that only a single, unaliased derived table is allowed (rather,
|
||||
# the latest one wins.
|
||||
alias = derived_table.alias
|
||||
alias = expression.alias
|
||||
sources[alias] = child_scope
|
||||
|
||||
if recursive_scope:
|
||||
child_scope.add_source(alias, recursive_scope)
|
||||
|
||||
# append the final child_scope yielded
|
||||
if is_cte:
|
||||
scope.cte_scopes.append(child_scope)
|
||||
else:
|
||||
scope.derived_table_scopes.append(child_scope)
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _add_table_sources(scope):
|
||||
sources = {}
|
||||
for table in scope.tables:
|
||||
table_name = table.name
|
||||
|
||||
if table.alias:
|
||||
source_name = table.alias
|
||||
else:
|
||||
source_name = table_name
|
||||
|
||||
if table_name in scope.sources:
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table.
|
||||
scope.sources[source_name] = scope.sources[table_name]
|
||||
else:
|
||||
sources[source_name] = table
|
||||
scopes.append(child_scope)
|
||||
scope.table_scopes.append(child_scope)
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True):
|
|||
|
||||
if node is expression:
|
||||
continue
|
||||
elif isinstance(node, exp.CTE):
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
prune = True
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
if (
|
||||
isinstance(node, exp.CTE)
|
||||
or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.Subqueryable)
|
||||
):
|
||||
prune = True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue