Merging upstream version 16.7.7.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
4d512c00f3
commit
70bf18533e
59 changed files with 16125 additions and 15681 deletions
|
@ -47,6 +47,17 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
|
|||
}
|
||||
|
||||
|
||||
# Projections in the outer query that are instances of these types can be replaced
|
||||
# without getting wrapped in parentheses, because the precedence won't be altered.
|
||||
SAFE_TO_REPLACE_UNWRAPPED = (
|
||||
exp.Column,
|
||||
exp.EQ,
|
||||
exp.Func,
|
||||
exp.NEQ,
|
||||
exp.Paren,
|
||||
)
|
||||
|
||||
|
||||
def merge_ctes(expression, leave_tables_isolated=False):
|
||||
scopes = traverse_scope(expression)
|
||||
|
||||
|
@ -293,8 +304,17 @@ def _merge_expressions(outer_scope, inner_scope, alias):
|
|||
if not projection_name:
|
||||
continue
|
||||
columns_to_replace = outer_columns.get(projection_name, [])
|
||||
|
||||
expression = expression.unalias()
|
||||
must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)
|
||||
|
||||
for column in columns_to_replace:
|
||||
column.replace(expression.unalias().copy())
|
||||
# Ensures we don't alter the intended operator precedence if there's additional
|
||||
# context surrounding the outer expression (i.e. it's not a simple projection).
|
||||
if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression:
|
||||
expression = exp.paren(expression, copy=False)
|
||||
|
||||
column.replace(expression.copy())
|
||||
|
||||
|
||||
def _merge_where(outer_scope, inner_scope, from_or_join):
|
||||
|
|
|
@ -170,9 +170,11 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if not isinstance(expression, exp.Select):
|
||||
return
|
||||
|
||||
alias_to_expression: t.Dict[str, exp.Expression] = {}
|
||||
alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
|
||||
|
||||
def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
|
||||
def replace_columns(
|
||||
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
|
||||
) -> None:
|
||||
if not node:
|
||||
return
|
||||
|
||||
|
@ -180,7 +182,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if not isinstance(column, exp.Column):
|
||||
continue
|
||||
table = resolver.get_table(column.name) if resolve_table and not column.table else None
|
||||
alias_expr = alias_to_expression.get(column.name)
|
||||
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
|
||||
double_agg = (
|
||||
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
|
||||
if alias_expr
|
||||
|
@ -190,16 +192,20 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if table and (not alias_expr or double_agg):
|
||||
column.set("table", table)
|
||||
elif not column.table and alias_expr and not double_agg:
|
||||
column.replace(alias_expr.copy())
|
||||
if isinstance(alias_expr, exp.Literal):
|
||||
if literal_index:
|
||||
column.replace(exp.Literal.number(i))
|
||||
else:
|
||||
column.replace(alias_expr.copy())
|
||||
|
||||
for projection in scope.selects:
|
||||
for i, projection in enumerate(scope.selects):
|
||||
replace_columns(projection)
|
||||
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = projection.this
|
||||
alias_to_expression[projection.alias] = (projection.this, i + 1)
|
||||
|
||||
replace_columns(expression.args.get("where"))
|
||||
replace_columns(expression.args.get("group"))
|
||||
replace_columns(expression.args.get("group"), literal_index=True)
|
||||
replace_columns(expression.args.get("having"), resolve_table=True)
|
||||
replace_columns(expression.args.get("qualify"), resolve_table=True)
|
||||
scope.clear_cache()
|
||||
|
@ -255,27 +261,39 @@ def _expand_order_by(scope: Scope, resolver: Resolver):
|
|||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
|
||||
|
||||
for ordered in ordereds:
|
||||
ordered.set("this", selects.get(ordered.this, ordered.this))
|
||||
ordered = ordered.this
|
||||
|
||||
ordered.replace(
|
||||
exp.to_identifier(_select_by_pos(scope, ordered).alias)
|
||||
if ordered.is_int
|
||||
else selects.get(ordered, ordered)
|
||||
)
|
||||
|
||||
|
||||
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
|
||||
new_nodes = []
|
||||
for node in expressions:
|
||||
if node.is_int:
|
||||
try:
|
||||
select = scope.selects[int(node.name) - 1]
|
||||
except IndexError:
|
||||
raise OptimizeError(f"Unknown output column: {node.name}")
|
||||
if isinstance(select, exp.Alias):
|
||||
select = select.this
|
||||
new_nodes.append(select.copy())
|
||||
scope.clear_cache()
|
||||
select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
|
||||
|
||||
if isinstance(select, exp.Literal):
|
||||
new_nodes.append(node)
|
||||
else:
|
||||
new_nodes.append(select.copy())
|
||||
scope.clear_cache()
|
||||
else:
|
||||
new_nodes.append(node)
|
||||
|
||||
return new_nodes
|
||||
|
||||
|
||||
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
|
||||
try:
|
||||
return scope.selects[int(node.this) - 1].assert_is(exp.Alias)
|
||||
except IndexError:
|
||||
raise OptimizeError(f"Unknown output column: {node.name}")
|
||||
|
||||
|
||||
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
|
||||
"""Disambiguate columns, ensuring each column specifies a source"""
|
||||
for column in scope.columns:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import itertools
|
||||
import logging
|
||||
import typing as t
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
@ -7,6 +8,8 @@ from sqlglot import exp
|
|||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import find_new_name
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
||||
class ScopeType(Enum):
|
||||
ROOT = auto()
|
||||
|
@ -85,6 +88,7 @@ class Scope:
|
|||
self._external_columns = None
|
||||
self._join_hints = None
|
||||
self._pivots = None
|
||||
self._references = None
|
||||
|
||||
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
|
||||
"""Branch from the current scope to a new, inner scope"""
|
||||
|
@ -264,14 +268,19 @@ class Scope:
|
|||
self._columns = []
|
||||
for column in columns + external_columns:
|
||||
ancestor = column.find_ancestor(
|
||||
exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
|
||||
exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
|
||||
)
|
||||
if (
|
||||
not ancestor
|
||||
or column.table
|
||||
or isinstance(ancestor, exp.Select)
|
||||
or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
|
||||
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
|
||||
or (
|
||||
isinstance(ancestor, exp.Order)
|
||||
and (
|
||||
isinstance(ancestor.parent, exp.Window)
|
||||
or column.name not in named_selects
|
||||
)
|
||||
)
|
||||
):
|
||||
self._columns.append(column)
|
||||
|
||||
|
@ -289,15 +298,9 @@ class Scope:
|
|||
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
|
||||
"""
|
||||
if self._selected_sources is None:
|
||||
referenced_names = []
|
||||
|
||||
for table in self.tables:
|
||||
referenced_names.append((table.alias_or_name, table))
|
||||
for expression in itertools.chain(self.derived_tables, self.udtfs):
|
||||
referenced_names.append((expression.alias, expression.unnest()))
|
||||
result = {}
|
||||
|
||||
for name, node in referenced_names:
|
||||
for name, node in self.references:
|
||||
if name in result:
|
||||
raise OptimizeError(f"Alias already used: {name}")
|
||||
if name in self.sources:
|
||||
|
@ -306,6 +309,23 @@ class Scope:
|
|||
self._selected_sources = result
|
||||
return self._selected_sources
|
||||
|
||||
@property
|
||||
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
|
||||
if self._references is None:
|
||||
self._references = []
|
||||
|
||||
for table in self.tables:
|
||||
self._references.append((table.alias_or_name, table))
|
||||
for expression in itertools.chain(self.derived_tables, self.udtfs):
|
||||
self._references.append(
|
||||
(
|
||||
expression.alias,
|
||||
expression if expression.args.get("pivots") else expression.unnest(),
|
||||
)
|
||||
)
|
||||
|
||||
return self._references
|
||||
|
||||
@property
|
||||
def cte_sources(self):
|
||||
"""
|
||||
|
@ -378,9 +398,7 @@ class Scope:
|
|||
def pivots(self):
|
||||
if not self._pivots:
|
||||
self._pivots = [
|
||||
pivot
|
||||
for node in self.tables + self.derived_tables
|
||||
for pivot in node.args.get("pivots") or []
|
||||
pivot for _, node in self.references for pivot in node.args.get("pivots") or []
|
||||
]
|
||||
|
||||
return self._pivots
|
||||
|
@ -536,7 +554,11 @@ def _traverse_scope(scope):
|
|||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
else:
|
||||
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
|
||||
logger.warning(
|
||||
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
|
||||
)
|
||||
return
|
||||
|
||||
yield scope
|
||||
|
||||
|
||||
|
@ -576,6 +598,8 @@ def _traverse_ctes(scope):
|
|||
if isinstance(union, exp.Union):
|
||||
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
|
||||
|
||||
child_scope = None
|
||||
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
cte.this,
|
||||
|
@ -593,7 +617,8 @@ def _traverse_ctes(scope):
|
|||
child_scope.add_source(alias, recursive_scope)
|
||||
|
||||
# append the final child_scope yielded
|
||||
scope.cte_scopes.append(child_scope)
|
||||
if child_scope:
|
||||
scope.cte_scopes.append(child_scope)
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
@ -634,6 +659,9 @@ def _traverse_tables(scope):
|
|||
sources[source_name] = expression
|
||||
continue
|
||||
|
||||
if not isinstance(expression, exp.DerivedTable):
|
||||
continue
|
||||
|
||||
if isinstance(expression, exp.UDTF):
|
||||
lateral_sources = sources
|
||||
scope_type = ScopeType.UDTF
|
||||
|
|
|
@ -400,6 +400,7 @@ def simplify_parens(expression):
|
|||
or not isinstance(this, exp.Binary)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
):
|
||||
return expression.this
|
||||
return expression
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue