1
0
Fork 0

Merging upstream version 23.13.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:33:25 +01:00
parent 63a75c51ff
commit 64041d1d66
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
85 changed files with 53899 additions and 50390 deletions

View file

@ -1,11 +1,11 @@
# ruff: noqa: F401
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.optimizer import RULES as RULES, optimize as optimize
from sqlglot.optimizer.scope import (
Scope,
build_scope,
find_all_in_scope,
find_in_scope,
traverse_scope,
walk_in_scope,
Scope as Scope,
build_scope as build_scope,
find_all_in_scope as find_all_in_scope,
find_in_scope as find_in_scope,
traverse_scope as traverse_scope,
walk_in_scope as walk_in_scope,
)

View file

@ -89,7 +89,8 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)
if new_ctes:
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
query = expression.expression if isinstance(expression, exp.DDL) else expression
query.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression

View file

@ -12,6 +12,8 @@ from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
class ScopeType(Enum):
ROOT = auto()
@ -495,25 +497,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
A list of the created scope instances
"""
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
# We ignore the DDL expression and build a scope for its query instead
ddl_with = expression.args.get("with")
expression = expression.expression
# If the DDL has CTEs attached, we need to add them to the query, or
# prepend them if the query itself already has CTEs attached to it
if ddl_with:
ddl_with.pop()
query_ctes = expression.ctes
if not query_ctes:
expression.set("with", ddl_with)
else:
expression.args["with"].set("recursive", ddl_with.recursive)
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
if isinstance(expression, exp.Query):
if isinstance(expression, TRAVERSABLES):
return list(_traverse_scope(Scope(expression)))
return []
@ -531,25 +516,37 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
expression = scope.expression
if isinstance(expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
elif isinstance(expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
elif isinstance(scope.expression, exp.Subquery):
elif isinstance(expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
else:
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
elif isinstance(expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
elif isinstance(expression, exp.UDTF):
yield from _traverse_udtfs(scope)
elif isinstance(expression, exp.DDL):
if isinstance(expression.expression, exp.Query):
yield from _traverse_ctes(scope)
yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
return
elif isinstance(expression, exp.DML):
yield from _traverse_ctes(scope)
for query in find_all_in_scope(expression, exp.Query):
# This check ensures we don't yield the CTE queries twice
if not isinstance(query.parent, exp.CTE):
yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
return
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
)
logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
return
yield scope
@ -749,7 +746,7 @@ def _traverse_udtfs(scope):
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
scope_type=ScopeType.SUBQUERY,
outer_columns=expression.alias_column_names,
)
):
@ -757,8 +754,7 @@ def _traverse_udtfs(scope):
top = child_scope
sources[expression.alias] = child_scope
scope.derived_table_scopes.append(top)
scope.table_scopes.append(top)
scope.subquery_scopes.append(top)
scope.sources.update(sources)

View file

@ -224,6 +224,8 @@ def flatten(expression):
def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
if left == right:
if isinstance(expression, exp.Xor):
return exp.false()
return left
if isinstance(expression, exp.And):
if is_false(left) or is_false(right):
@ -365,10 +367,17 @@ def uniq_sort(expression, root=True):
C AND A AND B AND B -> A AND B AND C
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
if isinstance(expression, exp.Xor):
result_func = exp.xor
# Do not deduplicate XOR as A XOR A != A if A == True
deduped = None
arr = tuple((gen(e), e) for e in flattened)
else:
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
@ -378,7 +387,7 @@ def uniq_sort(expression, root=True):
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
if deduped and len(deduped) < len(flattened):
expression = result_func(*deduped.values(), copy=False)
return expression