sqlglot.optimizer.normalize
1from __future__ import annotations 2 3import logging 4 5from sqlglot import exp 6from sqlglot.errors import OptimizeError 7from sqlglot.generator import cached_generator 8from sqlglot.helper import while_changing 9from sqlglot.optimizer.scope import find_all_in_scope 10from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort 11 12logger = logging.getLogger("sqlglot") 13 14 15def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): 16 """ 17 Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 18 19 Example: 20 >>> import sqlglot 21 >>> expression = sqlglot.parse_one("(x AND y) OR z") 22 >>> normalize(expression, dnf=False).sql() 23 '(x OR z) AND (y OR z)' 24 25 Args: 26 expression: expression to normalize 27 dnf: rewrite in disjunctive normal form instead. 28 max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion 29 Returns: 30 sqlglot.Expression: normalized expression 31 """ 32 generate = cached_generator() 33 34 for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): 35 if isinstance(node, exp.Connector): 36 if normalized(node, dnf=dnf): 37 continue 38 root = node is expression 39 original = node.copy() 40 41 node.transform(rewrite_between, copy=False) 42 distance = normalization_distance(node, dnf=dnf) 43 44 if distance > max_distance: 45 logger.info( 46 f"Skipping normalization because distance {distance} exceeds max {max_distance}" 47 ) 48 return expression 49 50 try: 51 node = node.replace( 52 while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) 53 ) 54 except OptimizeError as e: 55 logger.info(e) 56 node.replace(original) 57 if root: 58 return original 59 return expression 60 61 if root: 62 expression = node 63 64 return expression 65 66 67def normalized(expression: exp.Expression, dnf: bool = False) -> bool: 68 """ 69 Checks whether a given expression is in a normal form of interest. 
70 71 Example: 72 >>> from sqlglot import parse_one 73 >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) 74 True 75 >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default 76 True 77 >>> normalized(parse_one("a AND (b OR c)"), dnf=True) 78 False 79 80 Args: 81 expression: The expression to check if it's normalized. 82 dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). 83 Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). 84 """ 85 ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) 86 return not any( 87 connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) 88 ) 89 90 91def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int: 92 """ 93 The difference in the number of predicates between a given expression and its normalized form. 94 95 This is used as an estimate of the cost of the conversion which is exponential in complexity. 96 97 Example: 98 >>> import sqlglot 99 >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") 100 >>> normalization_distance(expression) 101 4 102 103 Args: 104 expression: The expression to compute the normalization distance for. 105 dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). 106 Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). 107 108 Returns: 109 The normalization distance. 110 """ 111 return sum(_predicate_lengths(expression, dnf)) - ( 112 sum(1 for _ in expression.find_all(exp.Connector)) + 1 113 ) 114 115 116def _predicate_lengths(expression, dnf): 117 """ 118 Returns a list of predicate lengths when expanded to normalized form. 119 120 (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
121 """ 122 expression = expression.unnest() 123 124 if not isinstance(expression, exp.Connector): 125 return (1,) 126 127 left, right = expression.args.values() 128 129 if isinstance(expression, exp.And if dnf else exp.Or): 130 return tuple( 131 a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) 132 ) 133 return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) 134 135 136def distributive_law(expression, dnf, max_distance, generate): 137 """ 138 x OR (y AND z) -> (x OR y) AND (x OR z) 139 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) 140 """ 141 if normalized(expression, dnf=dnf): 142 return expression 143 144 distance = normalization_distance(expression, dnf=dnf) 145 146 if distance > max_distance: 147 raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") 148 149 exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) 150 to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) 151 152 if isinstance(expression, from_exp): 153 a, b = expression.unnest_operands() 154 155 from_func = exp.and_ if from_exp == exp.And else exp.or_ 156 to_func = exp.and_ if to_exp == exp.And else exp.or_ 157 158 if isinstance(a, to_exp) and isinstance(b, to_exp): 159 if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): 160 return _distribute(a, b, from_func, to_func, generate) 161 return _distribute(b, a, from_func, to_func, generate) 162 if isinstance(a, to_exp): 163 return _distribute(b, a, from_func, to_func, generate) 164 if isinstance(b, to_exp): 165 return _distribute(a, b, from_func, to_func, generate) 166 167 return expression 168 169 170def _distribute(a, b, from_func, to_func, generate): 171 if isinstance(a, exp.Connector): 172 exp.replace_children( 173 a, 174 lambda c: to_func( 175 uniq_sort(flatten(from_func(c, b.left)), generate), 176 uniq_sort(flatten(from_func(c, b.right)), generate), 177 
copy=False, 178 ), 179 ) 180 else: 181 a = to_func( 182 uniq_sort(flatten(from_func(a, b.left)), generate), 183 uniq_sort(flatten(from_func(a, b.right)), generate), 184 copy=False, 185 ) 186 187 return a
logger =
<Logger sqlglot (WARNING)>
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    generate = cached_generator()

    def _is_connector(candidate, *_):
        return isinstance(candidate, exp.Connector)

    def _apply_distributive_law(candidate):
        return distributive_law(candidate, dnf, max_distance, generate)

    # Materialize the walk first: nodes get replaced while we iterate over them.
    visited = tuple(expression.walk(prune=_is_connector))

    for node, *_ in visited:
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        is_root = node is expression
        # Kept so that an abandoned rewrite can be rolled back.
        backup = node.copy()

        node.transform(rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            node = node.replace(while_changing(node, _apply_distributive_law))
        except OptimizeError as error:
            logger.info(error)
            node.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = node

    return expression
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Checks whether a given expression is in a normal form of interest.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        outer_type, inner_type = exp.And, exp.Or
    else:
        outer_type, inner_type = exp.Or, exp.And

    # The form is violated as soon as one inner-type connector sits below an
    # outer-type connector (e.g. an AND underneath an OR when checking CNF).
    for connector in find_all_in_scope(expression, inner_type):
        if connector.find_ancestor(outer_type):
            return False

    return True
Checks whether a given expression is in a normal form of interest.
Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
- expression: The expression to check if it's normalized.
- dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to measure the distance to Disjunctive Normal Form (DNF).
            Default: False, i.e. the distance to Conjunctive Normal Form (CNF) is measured.

    Returns:
        The normalization distance.
    """
    # Total number of predicates once the expression is expanded to normal form.
    expanded_size = sum(_predicate_lengths(expression, dnf))

    # A tree with N binary connectors has N + 1 leaves.
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    current_size = connector_count + 1

    return expanded_size - current_size
The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression: The expression to compute the normalization distance for.
- dnf: Whether to measure the distance to Disjunctive Normal Form (DNF). Default: False, i.e. the distance to Conjunctive Normal Form (CNF) is measured.
Returns:
The normalization distance.
def
distributive_law(expression, dnf, max_distance, generate):
def distributive_law(expression, dnf, max_distance, generate):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Raises:
        OptimizeError: if the estimated normalization distance exceeds max_distance.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    def _recurse(child):
        return distributive_law(child, dnf, max_distance, generate)

    # Children are rewritten first, in place.
    exp.replace_children(expression, _recurse)

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
        to_func, from_func = exp.or_, exp.and_
    else:
        to_exp, from_exp = exp.And, exp.Or
        to_func, from_func = exp.and_, exp.or_

    if not isinstance(expression, from_exp):
        return expression

    left, right = expression.unnest_operands()
    left_is_target = isinstance(left, to_exp)
    right_is_target = isinstance(right, to_exp)

    if left_is_target and right_is_target:
        # Both sides can be distributed over: the side with more connectors is
        # passed as the first argument of _distribute.
        left_connectors = len(tuple(left.find_all(exp.Connector)))
        right_connectors = len(tuple(right.find_all(exp.Connector)))
        if left_connectors > right_connectors:
            return _distribute(left, right, from_func, to_func, generate)
        return _distribute(right, left, from_func, to_func, generate)

    if left_is_target:
        return _distribute(right, left, from_func, to_func, generate)

    if right_is_target:
        return _distribute(left, right, from_func, to_func, generate)

    return expression
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)