1
0
Fork 0

Merging upstream version 26.16.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-04-25 07:27:01 +02:00
parent f03ef3fd88
commit 1e2a8571aa
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
110 changed files with 62370 additions and 61414 deletions

View file

@ -5,7 +5,7 @@ import typing as t
from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.helper import find_new_name, seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope
if t.TYPE_CHECKING:
@ -217,6 +217,7 @@ def _mergeable(
and not _is_a_window_expression_in_unmergable_operation()
and not _is_recursive()
and not (inner_select.args.get("order") and outer_scope.is_union)
and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform)
)

View file

@ -5,6 +5,7 @@ from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
@ -92,7 +93,13 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
# Push the selected columns down to the next scope
for name, (node, source) in scope.selected_sources.items():
if isinstance(source, Scope):
columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
select = seq_get(source.expression.selects, 0)
if scope.pivots or isinstance(select, exp.QueryTransform):
columns = {SELECT_ALL}
else:
columns = selects.get(name) or set()
referenced_columns[source].update(columns)
column_aliases = node.alias_column_names

View file

@ -770,7 +770,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
):
if selection is None:
if selection is None or isinstance(selection, exp.QueryTransform):
break
if isinstance(selection, exp.Subquery):
@ -787,7 +787,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections.append(selection)
if isinstance(scope.expression, exp.Select):
if new_selections and isinstance(scope.expression, exp.Select):
scope.expression.set("expressions", new_selections)
@ -945,7 +945,14 @@ class Resolver:
else:
columns = set_op.named_selects
else:
columns = source.expression.named_selects
select = seq_get(source.expression.selects, 0)
if isinstance(select, exp.QueryTransform):
# https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
schema = select.args.get("schema")
columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
else:
columns = source.expression.named_selects
node, _ = self.scope.selected_sources.get(name) or (None, None)
if isinstance(node, Scope):

View file

@ -54,10 +54,10 @@ def qualify_tables(
def _qualify(table: exp.Table) -> None:
if isinstance(table.this, exp.Identifier):
if not table.args.get("db"):
table.set("db", db)
if not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog)
if db and not table.args.get("db"):
table.set("db", db.copy())
if catalog and not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog.copy())
if (db or catalog) and not isinstance(expression, exp.Query):
for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
@ -148,6 +148,7 @@ def qualify_tables(
if table_alias:
for p in exp.COLUMN_PARTS[1:]:
column.set(p, None)
column.set("table", table_alias)
column.set("table", table_alias.copy())
return expression

View file

@ -40,7 +40,6 @@ def simplify(
expression: exp.Expression,
constant_propagation: bool = False,
dialect: DialectType = None,
max_depth: t.Optional[int] = None,
):
"""
Rewrite sqlglot AST to simplify expressions.
@ -54,114 +53,99 @@ def simplify(
Args:
expression: expression to simplify
constant_propagation: whether the constant propagation rule should be used
max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
Returns:
sqlglot.Expression: simplified expression
"""
dialect = Dialect.get_or_raise(dialect)
def _simplify(expression, root=True):
if (
max_depth
and isinstance(expression, exp.Connector)
and not isinstance(expression.parent, exp.Connector)
):
depth = connector_depth(expression)
if depth > max_depth:
logger.info(
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
)
return expression
def _simplify(expression):
pre_transformation_stack = [expression]
post_transformation_stack = []
if expression.meta.get(FINAL):
return expression
while pre_transformation_stack:
node = pre_transformation_stack.pop()
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
group = expression.args.get("group")
if node.meta.get(FINAL):
continue
if group and hasattr(expression, "selects"):
groups = set(group.expressions)
group.meta[FINAL] = True
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
group = node.args.get("group")
for e in expression.selects:
for node in e.walk():
if node in groups:
e.meta[FINAL] = True
break
if group and hasattr(node, "selects"):
groups = set(group.expressions)
group.meta[FINAL] = True
having = expression.args.get("having")
if having:
for node in having.walk():
if node in groups:
having.meta[FINAL] = True
break
for s in node.selects:
for n in s.walk():
if n in groups:
s.meta[FINAL] = True
break
# Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
having = node.args.get("having")
if having:
for n in having.walk():
if n in groups:
having.meta[FINAL] = True
break
if constant_propagation:
node = propagate_constants(node, root)
parent = node.parent
root = node is expression
exp.replace_children(node, lambda e: _simplify(e, False))
new_node = rewrite_between(node)
new_node = uniq_sort(new_node, root)
new_node = absorb_and_eliminate(new_node, root)
new_node = simplify_concat(new_node)
new_node = simplify_conditionals(new_node)
# Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
node = remove_complements(node, root)
node = simplify_coalesce(node, dialect)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
node = simplify_datetrunc(node, dialect)
node = sort_comparison(node)
node = simplify_startswith(node)
if constant_propagation:
new_node = propagate_constants(new_node, root)
if root:
expression.replace(node)
return node
if new_node is not node:
node.replace(new_node)
pre_transformation_stack.extend(
n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
)
post_transformation_stack.append((new_node, parent))
while post_transformation_stack:
node, parent = post_transformation_stack.pop()
root = node is expression
# Resets parent, arg_key, index pointers this is needed because some of the
# previous transformations mutate the AST, leading to an inconsistent state
for k, v in tuple(node.args.items()):
node.set(k, v)
# Post-order transformations
new_node = simplify_not(node)
new_node = flatten(new_node)
new_node = simplify_connectors(new_node, root)
new_node = remove_complements(new_node, root)
new_node = simplify_coalesce(new_node, dialect)
new_node.parent = parent
new_node = simplify_literals(new_node, root)
new_node = simplify_equality(new_node)
new_node = simplify_parens(new_node)
new_node = simplify_datetrunc(new_node, dialect)
new_node = sort_comparison(new_node)
new_node = simplify_startswith(new_node)
if new_node is not node:
node.replace(new_node)
return new_node
expression = while_changing(expression, _simplify)
remove_where_true(expression)
return expression
def connector_depth(expression: exp.Expression) -> int:
"""
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
"""
stack = deque([(expression, 0)])
max_depth = 0
while stack:
expression, depth = stack.pop()
if not isinstance(expression, exp.Connector):
continue
depth += 1
max_depth = max(depth, max_depth)
stack.append((expression.left, depth))
stack.append((expression.right, depth))
return max_depth
def catch(*exceptions):
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""