Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.generator import cached_generator
  8from sqlglot.helper import while_changing
  9from sqlglot.optimizer.scope import find_all_in_scope
 10from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 11
 12logger = logging.getLogger("sqlglot")
 13
 14
 15def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 16    """
 17    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 22        >>> normalize(expression, dnf=False).sql()
 23        '(x OR z) AND (y OR z)'
 24
 25    Args:
 26        expression: expression to normalize
 27        dnf: rewrite in disjunctive normal form instead.
 28        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 29    Returns:
 30        sqlglot.Expression: normalized expression
 31    """
 32    generate = cached_generator()
 33
 34    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
 35        if isinstance(node, exp.Connector):
 36            if normalized(node, dnf=dnf):
 37                continue
 38            root = node is expression
 39            original = node.copy()
 40
 41            node.transform(rewrite_between, copy=False)
 42            distance = normalization_distance(node, dnf=dnf)
 43
 44            if distance > max_distance:
 45                logger.info(
 46                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 47                )
 48                return expression
 49
 50            try:
 51                node = node.replace(
 52                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
 53                )
 54            except OptimizeError as e:
 55                logger.info(e)
 56                node.replace(original)
 57                if root:
 58                    return original
 59                return expression
 60
 61            if root:
 62                expression = node
 63
 64    return expression
 65
 66
 67def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
 68    """
 69    Checks whether a given expression is in a normal form of interest.
 70
 71    Example:
 72        >>> from sqlglot import parse_one
 73        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
 74        True
 75        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
 76        True
 77        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
 78        False
 79
 80    Args:
 81        expression: The expression to check if it's normalized.
 82        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
 83            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
 84    """
 85    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 86    return not any(
 87        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
 88    )
 89
 90
 91def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
 92    """
 93    The difference in the number of predicates between a given expression and its normalized form.
 94
 95    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 96
 97    Example:
 98        >>> import sqlglot
 99        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
100        >>> normalization_distance(expression)
101        4
102
103    Args:
104        expression: The expression to compute the normalization distance for.
105        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
106            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
107
108    Returns:
109        The normalization distance.
110    """
111    return sum(_predicate_lengths(expression, dnf)) - (
112        sum(1 for _ in expression.find_all(exp.Connector)) + 1
113    )
114
115
def _predicate_lengths(expression, dnf):
    """
    Returns a tuple of predicate lengths when expanded to normalized form.

    (A AND B) OR C -> (2, 2) because len(A OR C), len(B OR C).

    Args:
        expression: the expression to measure; it is unnested first.
        dnf: measure against disjunctive normal form when true, conjunctive otherwise.

    Returns:
        A tuple of ints. Anything that is not a connector counts as a single
        predicate of length 1.
    """
    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        return (1,)

    left, right = expression.args.values()

    # Compute each side exactly once. Written inline in a nested comprehension,
    # the right-hand recursion would be re-run for every element of the left
    # result, repeating the whole subtree traversal.
    left_lengths = _predicate_lengths(left, dnf)
    right_lengths = _predicate_lengths(right, dnf)

    if isinstance(expression, exp.And if dnf else exp.Or):
        # This connector has to be distributed: every left clause pairs with
        # every right clause, and the paired lengths add up.
        return tuple(a + b for a in left_lengths for b in right_lengths)

    # This connector is already of the target's outer kind; clauses are just
    # listed side by side.
    return left_lengths + right_lengths
134
135
def distributive_law(expression, dnf, max_distance, generate):
    """
    Apply one pass of the distributive law, children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: target disjunctive normal form when true, conjunctive otherwise.
        max_distance: upper bound on the normalization distance.
        generate: forwarded to `_distribute`.

    Returns:
        The rewritten expression, or `expression` itself when nothing applies.

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))

    # `from_*` is the connector being distributed, `to_*` the connector that
    # ends up on the outside.
    if dnf:
        from_exp, to_exp = exp.And, exp.Or
        from_func, to_func = exp.and_, exp.or_
    else:
        from_exp, to_exp = exp.Or, exp.And
        from_func, to_func = exp.or_, exp.and_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_matches = isinstance(a, to_exp)
    b_matches = isinstance(b, to_exp)

    if a_matches and b_matches:
        # Both sides qualify: the operand holding more connectors is passed first.
        a_connectors = sum(1 for _ in a.find_all(exp.Connector))
        b_connectors = sum(1 for _ in b.find_all(exp.Connector))
        first, second = (a, b) if a_connectors > b_connectors else (b, a)
    elif a_matches:
        first, second = b, a
    elif b_matches:
        first, second = a, b
    else:
        return expression

    return _distribute(first, second, from_func, to_func, generate)
168
169
def _distribute(a, b, from_func, to_func, generate):
    """
    Distribute over the two operands of `b`.

    An operand `c` is turned into
    `to_func(from_func(c, b.left), from_func(c, b.right))`, with each side passed
    through `flatten` and `uniq_sort`. When `a` is a connector this is applied to
    each of its children in place and `a` is returned; otherwise it is applied to
    `a` itself and the newly built expression is returned.
    """

    def expand(operand):
        with_left = uniq_sort(flatten(from_func(operand, b.left)), generate)
        with_right = uniq_sort(flatten(from_func(operand, b.right)), generate)
        return to_func(with_left, with_right, copy=False)

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, expand)
    return a
logger = <Logger sqlglot (WARNING)>
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Convert the boolean connectors of a sqlglot AST to conjunctive or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: produce disjunctive normal form instead of conjunctive normal form.
        max_distance (int): largest estimated distance from cnf/dnf for which the
            conversion is still attempted

    Returns:
        sqlglot.Expression: normalized expression
    """
    generate = cached_generator()

    def stop_at_connector(e, *_):
        # Used as the walk's prune predicate.
        return isinstance(e, exp.Connector)

    # The walk is materialized up front; the loop body mutates the tree.
    for node, *_ in tuple(expression.walk(prune=stop_at_connector)):
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        is_root = node is expression
        # Pristine copy used to undo a partially applied rewrite (see except branch).
        backup = node.copy()

        node.transform(rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            rewritten = while_changing(
                node, lambda e: distributive_law(e, dnf, max_distance, generate)
            )
            node = node.replace(rewritten)
        except OptimizeError as e:
            # distributive_law raises when the distance exceeds max_distance
            # during the rewrite; put the untouched copy back and give up.
            logger.info(e)
            node.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = node

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:

sqlglot.Expression: normalized expression

def normalized(expression: sqlglot.expressions.Expression, dnf: bool = False) -> bool:
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Tell whether an expression already is in the requested normal form.

    The form is violated as soon as one connector of the "inner" kind (OR for DNF,
    AND for CNF) found in the expression's scope has an ancestor of the "outer"
    kind (AND for DNF, OR for CNF).

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        outer_kind, inner_kind = exp.And, exp.Or
    else:
        outer_kind, inner_kind = exp.Or, exp.And

    for connector in find_all_in_scope(expression, inner_kind):
        if connector.find_ancestor(outer_kind):
            return False

    return True

Checks whether a given expression is in a normal form of interest.

Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
  • expression: The expression to check if it's normalized.
  • dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance(expression: sqlglot.expressions.Expression, dnf: bool = False) -> int:
 92def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
 93    """
 94    The difference in the number of predicates between a given expression and its normalized form.
 95
 96    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 97
 98    Example:
 99        >>> import sqlglot
100        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
101        >>> normalization_distance(expression)
102        4
103
104    Args:
105        expression: The expression to compute the normalization distance for.
106        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
107            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
108
109    Returns:
110        The normalization distance.
111    """
112    return sum(_predicate_lengths(expression, dnf)) - (
113        sum(1 for _ in expression.find_all(exp.Connector)) + 1
114    )

The difference in the number of predicates between a given expression and its normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression: The expression to compute the normalization distance for.
  • dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
Returns:

The normalization distance.

def distributive_law(expression, dnf, max_distance, generate):
def distributive_law(expression, dnf, max_distance, generate):
    """
    Apply one pass of the distributive law, children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: target disjunctive normal form when true, conjunctive otherwise.
        max_distance: upper bound on the normalization distance.
        generate: forwarded to `_distribute`.

    Returns:
        The rewritten expression, or `expression` itself when nothing applies.

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))

    # `from_*` is the connector being distributed, `to_*` the connector that
    # ends up on the outside.
    if dnf:
        from_exp, to_exp = exp.And, exp.Or
        from_func, to_func = exp.and_, exp.or_
    else:
        from_exp, to_exp = exp.Or, exp.And
        from_func, to_func = exp.or_, exp.and_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_matches = isinstance(a, to_exp)
    b_matches = isinstance(b, to_exp)

    if a_matches and b_matches:
        # Both sides qualify: the operand holding more connectors is passed first.
        a_connectors = sum(1 for _ in a.find_all(exp.Connector))
        b_connectors = sum(1 for _ in b.find_all(exp.Connector))
        first, second = (a, b) if a_connectors > b_connectors else (b, a)
    elif a_matches:
        first, second = b, a
    elif b_matches:
        first, second = a, b
    else:
        return expression

    return _distribute(first, second, from_func, to_func, generate)

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)