Merging upstream version 23.13.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
63a75c51ff
commit
64041d1d66
85 changed files with 53899 additions and 50390 deletions
|
@ -1,11 +1,11 @@
|
|||
# ruff: noqa: F401
|
||||
|
||||
from sqlglot.optimizer.optimizer import RULES, optimize
|
||||
from sqlglot.optimizer.optimizer import RULES as RULES, optimize as optimize
|
||||
from sqlglot.optimizer.scope import (
|
||||
Scope,
|
||||
build_scope,
|
||||
find_all_in_scope,
|
||||
find_in_scope,
|
||||
traverse_scope,
|
||||
walk_in_scope,
|
||||
Scope as Scope,
|
||||
build_scope as build_scope,
|
||||
find_all_in_scope as find_all_in_scope,
|
||||
find_in_scope as find_in_scope,
|
||||
traverse_scope as traverse_scope,
|
||||
walk_in_scope as walk_in_scope,
|
||||
)
|
||||
|
|
|
@ -89,7 +89,8 @@ def eliminate_subqueries(expression):
|
|||
new_ctes.append(new_cte)
|
||||
|
||||
if new_ctes:
|
||||
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
|
||||
query = expression.expression if isinstance(expression, exp.DDL) else expression
|
||||
query.set("with", exp.With(expressions=new_ctes, recursive=recursive))
|
||||
|
||||
return expression
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ from sqlglot.helper import ensure_collection, find_new_name, seq_get
|
|||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
|
||||
|
||||
|
||||
class ScopeType(Enum):
|
||||
ROOT = auto()
|
||||
|
@ -495,25 +497,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
|
|||
Returns:
|
||||
A list of the created scope instances
|
||||
"""
|
||||
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
|
||||
# We ignore the DDL expression and build a scope for its query instead
|
||||
ddl_with = expression.args.get("with")
|
||||
expression = expression.expression
|
||||
|
||||
# If the DDL has CTEs attached, we need to add them to the query, or
|
||||
# prepend them if the query itself already has CTEs attached to it
|
||||
if ddl_with:
|
||||
ddl_with.pop()
|
||||
query_ctes = expression.ctes
|
||||
if not query_ctes:
|
||||
expression.set("with", ddl_with)
|
||||
else:
|
||||
expression.args["with"].set("recursive", ddl_with.recursive)
|
||||
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
|
||||
|
||||
if isinstance(expression, exp.Query):
|
||||
if isinstance(expression, TRAVERSABLES):
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
return []
|
||||
|
||||
|
||||
|
@ -531,25 +516,37 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
|
|||
|
||||
|
||||
def _traverse_scope(scope):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
expression = scope.expression
|
||||
|
||||
if isinstance(expression, exp.Select):
|
||||
yield from _traverse_select(scope)
|
||||
elif isinstance(scope.expression, exp.Union):
|
||||
elif isinstance(expression, exp.Union):
|
||||
yield from _traverse_ctes(scope)
|
||||
yield from _traverse_union(scope)
|
||||
return
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
elif isinstance(expression, exp.Subquery):
|
||||
if scope.is_root:
|
||||
yield from _traverse_select(scope)
|
||||
else:
|
||||
yield from _traverse_subqueries(scope)
|
||||
elif isinstance(scope.expression, exp.Table):
|
||||
elif isinstance(expression, exp.Table):
|
||||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
elif isinstance(expression, exp.UDTF):
|
||||
yield from _traverse_udtfs(scope)
|
||||
elif isinstance(expression, exp.DDL):
|
||||
if isinstance(expression.expression, exp.Query):
|
||||
yield from _traverse_ctes(scope)
|
||||
yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
|
||||
return
|
||||
elif isinstance(expression, exp.DML):
|
||||
yield from _traverse_ctes(scope)
|
||||
for query in find_all_in_scope(expression, exp.Query):
|
||||
# This check ensures we don't yield the CTE queries twice
|
||||
if not isinstance(query.parent, exp.CTE):
|
||||
yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
|
||||
return
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
|
||||
)
|
||||
logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
|
||||
return
|
||||
|
||||
yield scope
|
||||
|
@ -749,7 +746,7 @@ def _traverse_udtfs(scope):
|
|||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
expression,
|
||||
scope_type=ScopeType.DERIVED_TABLE,
|
||||
scope_type=ScopeType.SUBQUERY,
|
||||
outer_columns=expression.alias_column_names,
|
||||
)
|
||||
):
|
||||
|
@ -757,8 +754,7 @@ def _traverse_udtfs(scope):
|
|||
top = child_scope
|
||||
sources[expression.alias] = child_scope
|
||||
|
||||
scope.derived_table_scopes.append(top)
|
||||
scope.table_scopes.append(top)
|
||||
scope.subquery_scopes.append(top)
|
||||
|
||||
scope.sources.update(sources)
|
||||
|
||||
|
|
|
@ -224,6 +224,8 @@ def flatten(expression):
|
|||
def simplify_connectors(expression, root=True):
|
||||
def _simplify_connectors(expression, left, right):
|
||||
if left == right:
|
||||
if isinstance(expression, exp.Xor):
|
||||
return exp.false()
|
||||
return left
|
||||
if isinstance(expression, exp.And):
|
||||
if is_false(left) or is_false(right):
|
||||
|
@ -365,10 +367,17 @@ def uniq_sort(expression, root=True):
|
|||
C AND A AND B AND B -> A AND B AND C
|
||||
"""
|
||||
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
||||
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
|
||||
flattened = tuple(expression.flatten())
|
||||
deduped = {gen(e): e for e in flattened}
|
||||
arr = tuple(deduped.items())
|
||||
|
||||
if isinstance(expression, exp.Xor):
|
||||
result_func = exp.xor
|
||||
# Do not deduplicate XOR as A XOR A != A if A == True
|
||||
deduped = None
|
||||
arr = tuple((gen(e), e) for e in flattened)
|
||||
else:
|
||||
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
|
||||
deduped = {gen(e): e for e in flattened}
|
||||
arr = tuple(deduped.items())
|
||||
|
||||
# check if the operands are already sorted, if not sort them
|
||||
# A AND C AND B -> A AND B AND C
|
||||
|
@ -378,7 +387,7 @@ def uniq_sort(expression, root=True):
|
|||
break
|
||||
else:
|
||||
# we didn't have to sort but maybe we need to dedup
|
||||
if len(deduped) < len(flattened):
|
||||
if deduped and len(deduped) < len(flattened):
|
||||
expression = result_func(*deduped.values(), copy=False)
|
||||
|
||||
return expression
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue