2025-02-13 15:48:10 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
from sqlglot import exp
|
2025-02-13 15:48:10 +01:00
|
|
|
from sqlglot.errors import OptimizeError
|
2025-02-13 06:15:54 +01:00
|
|
|
from sqlglot.helper import while_changing
|
2025-02-13 21:08:10 +01:00
|
|
|
from sqlglot.optimizer.scope import find_all_in_scope
|
2025-02-13 15:57:23 +01:00
|
|
|
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
|
2025-02-13 15:48:10 +01:00
|
|
|
|
|
|
|
logger = logging.getLogger("sqlglot")
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:48:10 +01:00
|
|
|
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite the boolean connectors of a sqlglot AST into conjunctive normal form (CNF)
    or, when requested, disjunctive normal form (DNF).

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead of conjunctive normal form.
        max_distance (int): the largest estimated distance from cnf/dnf for which the
            conversion is still attempted.

    Returns:
        sqlglot.Expression: the normalized expression. When the estimated distance exceeds
        `max_distance`, the conversion is skipped (and logged) instead.
    """
    # Materialize the walk first: nodes get replaced while we iterate. Pruning at
    # connectors means the walk does not descend below the first connector it meets.
    visited = tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector)))

    for node, *_ in visited:
        if not isinstance(node, exp.Connector):
            continue
        if normalized(node, dnf=dnf):
            continue

        is_root = node is expression
        # Kept so the subtree can be restored if the conversion is aborted midway.
        backup = node.copy()

        node.transform(rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            # Apply the distributive law repeatedly until the tree stops changing.
            node = node.replace(
                while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
            )
        except OptimizeError as e:
            # distributive_law gave up part-way through: put the untouched copy back.
            logger.info(e)
            node.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = node

    return expression
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 21:08:10 +01:00
|
|
|
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Tell whether an expression already is in the requested normal form.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).

    Returns:
        True if no outer-level connector (OR for DNF, AND for CNF) found in scope has an
        ancestor of the opposite connector type, False otherwise.
    """
    if dnf:
        forbidden_ancestor, outer_connector = exp.And, exp.Or
    else:
        forbidden_ancestor, outer_connector = exp.Or, exp.And

    # A single outer-level connector nested under the other connector type is enough
    # to break the normal form, so stop at the first offender.
    for connector in find_all_in_scope(expression, outer_connector):
        if connector.find_ancestor(forbidden_ancestor):
            return False

    return True
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 21:08:10 +01:00
|
|
|
|
|
|
|
def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
    """
    Estimate how many predicates normalizing the expression would add.

    The value is the difference in the number of predicates between the given expression
    and its normalized form. It serves as an estimate of the cost of the conversion, which
    is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).

    Returns:
        The normalization distance.
    """
    # Total number of predicate slots once the expression is fully expanded.
    expanded_size = sum(_predicate_lengths(expression, dnf))

    # Number of connectors in the expression, plus one, is subtracted as the current size.
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    current_size = connector_count + 1

    return expanded_size - current_size
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
|
|
|
def _predicate_lengths(expression, dnf):
    """
    Returns a tuple of predicate lengths when expanded to normalized form.

    (A AND B) OR C -> (2, 2) because len(A OR C), len(B OR C).

    Args:
        expression: The expression to measure; it is unnested before inspection.
        dnf: True to measure against disjunctive normal form, False for conjunctive.

    Returns:
        A tuple with one length per clause of the expanded form. Anything that is not
        a connector counts as a single predicate, i.e. `(1,)`.
    """
    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        return (1,)

    left, right = expression.args.values()

    # Recurse into each side exactly once. Evaluating the right side inside the nested
    # generator below would redo the whole recursive walk for every left-hand entry.
    left_lengths = _predicate_lengths(left, dnf)
    right_lengths = _predicate_lengths(right, dnf)

    if isinstance(expression, exp.And if dnf else exp.Or):
        # This connector has to be distributed: every left clause pairs with every
        # right clause, and each resulting clause holds the predicates of both.
        return tuple(a + b for a in left_lengths for b in right_lengths)

    # The connector already matches the target form: clauses are simply concatenated.
    return left_lengths + right_lengths
|
|
|
|
|
|
|
|
|
2025-02-13 21:16:09 +01:00
|
|
|
def distributive_law(expression, dnf, max_distance):
    """
    Apply one round of the distributive law to move an expression towards normal form.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: The expression to rewrite; its children are rewritten in place.
        dnf: True to target disjunctive normal form, False for conjunctive normal form.
        max_distance: Limit on the estimated normalization distance.

    Returns:
        The rewritten expression, or the input itself when nothing needs distributing.

    Raises:
        OptimizeError: If the estimated normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Rewrite the children first so this level works on already-processed operands.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))

    # `from_*` is the connector being distributed, `to_*` the one it is distributed over.
    if dnf:
        to_exp, from_exp = exp.Or, exp.And
        to_func, from_func = exp.or_, exp.and_
    else:
        to_exp, from_exp = exp.And, exp.Or
        to_func, from_func = exp.and_, exp.or_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_is_target = isinstance(a, to_exp)
    b_is_target = isinstance(b, to_exp)

    if a_is_target and b_is_target:
        # Both operands qualify: the one with more connectors goes first.
        a_connectors = sum(1 for _ in a.find_all(exp.Connector))
        b_connectors = sum(1 for _ in b.find_all(exp.Connector))
        if a_connectors > b_connectors:
            return _distribute(a, b, from_func, to_func)
        return _distribute(b, a, from_func, to_func)

    if a_is_target:
        return _distribute(b, a, from_func, to_func)

    if b_is_target:
        return _distribute(a, b, from_func, to_func)

    return expression
|
|
|
|
|
|
|
|
|
2025-02-13 21:16:09 +01:00
|
|
|
def _distribute(a, b, from_func, to_func):
    """
    Distribute `a` over the two sides of the connector `b`.

    Args:
        a: The expression being distributed. If it is a connector, its children are
            rewritten in place; otherwise a new expression is built from it.
        b: The connector whose `left` and `right` sides are each combined with `a`.
        from_func: Builder used to combine an operand with one side of `b`.
        to_func: Builder used to join the two combined halves.

    Returns:
        `a` itself (with its children replaced) when it is a connector, otherwise the
        newly built expression.
    """

    def _spread(operand):
        # Pair the operand with each side of `b`, then join both halves.
        # NOTE(review): flatten/uniq_sort presumably tidy each half (nesting, duplicates,
        # ordering) -- they are imported helpers, confirm against their definitions.
        left_half = uniq_sort(flatten(from_func(operand, b.left)))
        right_half = uniq_sort(flatten(from_func(operand, b.right)))
        return to_func(left_half, right_half, copy=False)

    if not isinstance(a, exp.Connector):
        return _spread(a)

    exp.replace_children(a, _spread)
    return a
|