1
0
Fork 0

Merging upstream version 6.2.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:43:32 +01:00
parent 87ba722f7f
commit a62bbc24c3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 361 additions and 98 deletions

View file

@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Returns:
sqlglot.Expression: optimized expression
"""
merge_ctes(expression, leave_tables_isolated)
merge_derived_tables(expression, leave_tables_isolated)
expression = merge_ctes(expression, leave_tables_isolated)
expression = merge_derived_tables(expression, leave_tables_isolated)
return expression
@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
column.replace(expression.unalias())
column.replace(expression.unalias().copy())
def _merge_where(outer_scope, inner_scope, from_or_join):

View file

@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
def qualify_columns(expression, schema):
"""
@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, SKIP_QUALIFY):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
if isinstance(derived_table, SKIP_QUALIFY):
if isinstance(derived_table, exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_unnest:
if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
@ -296,7 +294,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")

View file

@ -1,5 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
from sqlglot import exp
@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
UNNEST = auto()
UDTF = auto()
class Scope:
@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
def branch(self, expression, scope_type, add_sources=None, **kwargs):
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
sources = copy(self.sources)
if add_sources:
sources.update(add_sources)
return Scope(
expression=expression.unnest(),
sources=sources,
sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
prune = False
for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
elif isinstance(node, (exp.Unnest, exp.Lateral)):
elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
prune = True
self._collected = True
@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
def walk(self, bfs=True):
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
def find_all(self, *expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
def replace(self, old, new):
"""
Replace `old` with `new`.
@ -246,6 +270,16 @@ class Scope:
self._selected_sources = result
return self._selected_sources
@property
def cte_sources(self):
"""
Sources that are CTEs.
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
@property
def selects(self):
"""
@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
return self.scope_type == ScopeType.UNNEST
def is_udtf(self):
"""Determine if this scope is a UDTF (User Defined Table Function)"""
return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
nodes that start child scopes.
Args:
expression (exp.Expression):
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
"""
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
prune = False
yield node, parent, key
if node is expression:
continue
elif isinstance(node, exp.CTE):
prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
prune = True
elif isinstance(node, exp.Subqueryable):
prune = True