Merging upstream version 26.16.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f03ef3fd88
commit
1e2a8571aa
110 changed files with 62370 additions and 61414 deletions
|
@ -5,7 +5,7 @@ import typing as t
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.helper import find_new_name
|
||||
from sqlglot.helper import find_new_name, seq_get
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
@ -217,6 +217,7 @@ def _mergeable(
|
|||
and not _is_a_window_expression_in_unmergable_operation()
|
||||
and not _is_recursive()
|
||||
and not (inner_select.args.get("order") and outer_scope.is_union)
|
||||
and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from sqlglot.optimizer.qualify_columns import Resolver
|
|||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get
|
||||
|
||||
# Sentinel value that means an outer query selecting ALL columns
|
||||
SELECT_ALL = object()
|
||||
|
@ -92,7 +93,13 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
# Push the selected columns down to the next scope
|
||||
for name, (node, source) in scope.selected_sources.items():
|
||||
if isinstance(source, Scope):
|
||||
columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
|
||||
select = seq_get(source.expression.selects, 0)
|
||||
|
||||
if scope.pivots or isinstance(select, exp.QueryTransform):
|
||||
columns = {SELECT_ALL}
|
||||
else:
|
||||
columns = selects.get(name) or set()
|
||||
|
||||
referenced_columns[source].update(columns)
|
||||
|
||||
column_aliases = node.alias_column_names
|
||||
|
|
|
@ -770,7 +770,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
|
|||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
|
||||
):
|
||||
if selection is None:
|
||||
if selection is None or isinstance(selection, exp.QueryTransform):
|
||||
break
|
||||
|
||||
if isinstance(selection, exp.Subquery):
|
||||
|
@ -787,7 +787,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
|
|||
|
||||
new_selections.append(selection)
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if new_selections and isinstance(scope.expression, exp.Select):
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
|
@ -945,7 +945,14 @@ class Resolver:
|
|||
else:
|
||||
columns = set_op.named_selects
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
select = seq_get(source.expression.selects, 0)
|
||||
|
||||
if isinstance(select, exp.QueryTransform):
|
||||
# https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
|
||||
schema = select.args.get("schema")
|
||||
columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
|
||||
node, _ = self.scope.selected_sources.get(name) or (None, None)
|
||||
if isinstance(node, Scope):
|
||||
|
|
|
@ -54,10 +54,10 @@ def qualify_tables(
|
|||
|
||||
def _qualify(table: exp.Table) -> None:
|
||||
if isinstance(table.this, exp.Identifier):
|
||||
if not table.args.get("db"):
|
||||
table.set("db", db)
|
||||
if not table.args.get("catalog") and table.args.get("db"):
|
||||
table.set("catalog", catalog)
|
||||
if db and not table.args.get("db"):
|
||||
table.set("db", db.copy())
|
||||
if catalog and not table.args.get("catalog") and table.args.get("db"):
|
||||
table.set("catalog", catalog.copy())
|
||||
|
||||
if (db or catalog) and not isinstance(expression, exp.Query):
|
||||
for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
|
||||
|
@ -148,6 +148,7 @@ def qualify_tables(
|
|||
if table_alias:
|
||||
for p in exp.COLUMN_PARTS[1:]:
|
||||
column.set(p, None)
|
||||
column.set("table", table_alias)
|
||||
|
||||
column.set("table", table_alias.copy())
|
||||
|
||||
return expression
|
||||
|
|
|
@ -40,7 +40,6 @@ def simplify(
|
|||
expression: exp.Expression,
|
||||
constant_propagation: bool = False,
|
||||
dialect: DialectType = None,
|
||||
max_depth: t.Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Rewrite sqlglot AST to simplify expressions.
|
||||
|
@ -54,114 +53,99 @@ def simplify(
|
|||
Args:
|
||||
expression: expression to simplify
|
||||
constant_propagation: whether the constant propagation rule should be used
|
||||
max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
|
||||
Returns:
|
||||
sqlglot.Expression: simplified expression
|
||||
"""
|
||||
|
||||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
def _simplify(expression, root=True):
|
||||
if (
|
||||
max_depth
|
||||
and isinstance(expression, exp.Connector)
|
||||
and not isinstance(expression.parent, exp.Connector)
|
||||
):
|
||||
depth = connector_depth(expression)
|
||||
if depth > max_depth:
|
||||
logger.info(
|
||||
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
|
||||
)
|
||||
return expression
|
||||
def _simplify(expression):
|
||||
pre_transformation_stack = [expression]
|
||||
post_transformation_stack = []
|
||||
|
||||
if expression.meta.get(FINAL):
|
||||
return expression
|
||||
while pre_transformation_stack:
|
||||
node = pre_transformation_stack.pop()
|
||||
|
||||
# group by expressions cannot be simplified, for example
|
||||
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
||||
# the projection must exactly match the group by key
|
||||
group = expression.args.get("group")
|
||||
if node.meta.get(FINAL):
|
||||
continue
|
||||
|
||||
if group and hasattr(expression, "selects"):
|
||||
groups = set(group.expressions)
|
||||
group.meta[FINAL] = True
|
||||
# group by expressions cannot be simplified, for example
|
||||
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
||||
# the projection must exactly match the group by key
|
||||
group = node.args.get("group")
|
||||
|
||||
for e in expression.selects:
|
||||
for node in e.walk():
|
||||
if node in groups:
|
||||
e.meta[FINAL] = True
|
||||
break
|
||||
if group and hasattr(node, "selects"):
|
||||
groups = set(group.expressions)
|
||||
group.meta[FINAL] = True
|
||||
|
||||
having = expression.args.get("having")
|
||||
if having:
|
||||
for node in having.walk():
|
||||
if node in groups:
|
||||
having.meta[FINAL] = True
|
||||
break
|
||||
for s in node.selects:
|
||||
for n in s.walk():
|
||||
if n in groups:
|
||||
s.meta[FINAL] = True
|
||||
break
|
||||
|
||||
# Pre-order transformations
|
||||
node = expression
|
||||
node = rewrite_between(node)
|
||||
node = uniq_sort(node, root)
|
||||
node = absorb_and_eliminate(node, root)
|
||||
node = simplify_concat(node)
|
||||
node = simplify_conditionals(node)
|
||||
having = node.args.get("having")
|
||||
if having:
|
||||
for n in having.walk():
|
||||
if n in groups:
|
||||
having.meta[FINAL] = True
|
||||
break
|
||||
|
||||
if constant_propagation:
|
||||
node = propagate_constants(node, root)
|
||||
parent = node.parent
|
||||
root = node is expression
|
||||
|
||||
exp.replace_children(node, lambda e: _simplify(e, False))
|
||||
new_node = rewrite_between(node)
|
||||
new_node = uniq_sort(new_node, root)
|
||||
new_node = absorb_and_eliminate(new_node, root)
|
||||
new_node = simplify_concat(new_node)
|
||||
new_node = simplify_conditionals(new_node)
|
||||
|
||||
# Post-order transformations
|
||||
node = simplify_not(node)
|
||||
node = flatten(node)
|
||||
node = simplify_connectors(node, root)
|
||||
node = remove_complements(node, root)
|
||||
node = simplify_coalesce(node, dialect)
|
||||
node.parent = expression.parent
|
||||
node = simplify_literals(node, root)
|
||||
node = simplify_equality(node)
|
||||
node = simplify_parens(node)
|
||||
node = simplify_datetrunc(node, dialect)
|
||||
node = sort_comparison(node)
|
||||
node = simplify_startswith(node)
|
||||
if constant_propagation:
|
||||
new_node = propagate_constants(new_node, root)
|
||||
|
||||
if root:
|
||||
expression.replace(node)
|
||||
return node
|
||||
if new_node is not node:
|
||||
node.replace(new_node)
|
||||
|
||||
pre_transformation_stack.extend(
|
||||
n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
|
||||
)
|
||||
post_transformation_stack.append((new_node, parent))
|
||||
|
||||
while post_transformation_stack:
|
||||
node, parent = post_transformation_stack.pop()
|
||||
root = node is expression
|
||||
|
||||
# Resets parent, arg_key, index pointers– this is needed because some of the
|
||||
# previous transformations mutate the AST, leading to an inconsistent state
|
||||
for k, v in tuple(node.args.items()):
|
||||
node.set(k, v)
|
||||
|
||||
# Post-order transformations
|
||||
new_node = simplify_not(node)
|
||||
new_node = flatten(new_node)
|
||||
new_node = simplify_connectors(new_node, root)
|
||||
new_node = remove_complements(new_node, root)
|
||||
new_node = simplify_coalesce(new_node, dialect)
|
||||
|
||||
new_node.parent = parent
|
||||
|
||||
new_node = simplify_literals(new_node, root)
|
||||
new_node = simplify_equality(new_node)
|
||||
new_node = simplify_parens(new_node)
|
||||
new_node = simplify_datetrunc(new_node, dialect)
|
||||
new_node = sort_comparison(new_node)
|
||||
new_node = simplify_startswith(new_node)
|
||||
|
||||
if new_node is not node:
|
||||
node.replace(new_node)
|
||||
|
||||
return new_node
|
||||
|
||||
expression = while_changing(expression, _simplify)
|
||||
remove_where_true(expression)
|
||||
return expression
|
||||
|
||||
|
||||
def connector_depth(expression: exp.Expression) -> int:
|
||||
"""
|
||||
Determine the maximum depth of a tree of Connectors.
|
||||
|
||||
For example:
|
||||
>>> from sqlglot import parse_one
|
||||
>>> connector_depth(parse_one("a AND b AND c AND d"))
|
||||
3
|
||||
"""
|
||||
stack = deque([(expression, 0)])
|
||||
max_depth = 0
|
||||
|
||||
while stack:
|
||||
expression, depth = stack.pop()
|
||||
|
||||
if not isinstance(expression, exp.Connector):
|
||||
continue
|
||||
|
||||
depth += 1
|
||||
max_depth = max(depth, max_depth)
|
||||
|
||||
stack.append((expression.left, depth))
|
||||
stack.append((expression.right, depth))
|
||||
|
||||
return max_depth
|
||||
|
||||
|
||||
def catch(*exceptions):
|
||||
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue