Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
 11
 12def simplify(expression):
 13    """
 14    Rewrite sqlglot AST to simplify expressions.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 19        >>> simplify(expression).sql()
 20        'TRUE'
 21
 22    Args:
 23        expression (sqlglot.Expression): expression to simplify
 24    Returns:
 25        sqlglot.Expression: simplified expression
 26    """
 27
 28    generate = cached_generator()
 29
 30    def _simplify(expression, root=True):
 31        if expression.meta.get("final"):
 32            return expression
 33        node = expression
 34        node = rewrite_between(node)
 35        node = uniq_sort(node, generate, root)
 36        node = absorb_and_eliminate(node, root)
 37        exp.replace_children(node, lambda e: _simplify(e, False))
 38        node = simplify_not(node)
 39        node = flatten(node)
 40        node = simplify_connectors(node, root)
 41        node = remove_compliments(node, root)
 42        node.parent = expression.parent
 43        node = simplify_literals(node, root)
 44        node = simplify_parens(node)
 45        if root:
 46            expression.replace(node)
 47        return node
 48
 49    expression = while_changing(expression, _simplify)
 50    remove_where_true(expression)
 51    return expression
 52
 53
 54def rewrite_between(expression: exp.Expression) -> exp.Expression:
 55    """Rewrite x between y and z to x >= y AND x <= z.
 56
 57    This is done because comparison simplification is only done on lt/lte/gt/gte.
 58    """
 59    if isinstance(expression, exp.Between):
 60        return exp.and_(
 61            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 62            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 63            copy=False,
 64        )
 65    return expression
 66
 67
 68def simplify_not(expression):
 69    """
 70    Demorgan's Law
 71    NOT (x OR y) -> NOT x AND NOT y
 72    NOT (x AND y) -> NOT x OR NOT y
 73    """
 74    if isinstance(expression, exp.Not):
 75        if is_null(expression.this):
 76            return exp.null()
 77        if isinstance(expression.this, exp.Paren):
 78            condition = expression.this.unnest()
 79            if isinstance(condition, exp.And):
 80                return exp.or_(
 81                    exp.not_(condition.left, copy=False),
 82                    exp.not_(condition.right, copy=False),
 83                    copy=False,
 84                )
 85            if isinstance(condition, exp.Or):
 86                return exp.and_(
 87                    exp.not_(condition.left, copy=False),
 88                    exp.not_(condition.right, copy=False),
 89                    copy=False,
 90                )
 91            if is_null(condition):
 92                return exp.null()
 93        if always_true(expression.this):
 94            return exp.false()
 95        if is_false(expression.this):
 96            return exp.true()
 97        if isinstance(expression.this, exp.Not):
 98            # double negation
 99            # NOT NOT x -> x
100            return expression.this.this
101    return expression
102
103
def flatten(expression):
    """
    Inline directly nested connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            # Swap the (possibly parenthesized) operand for the unwrapped
            # same-type connector it contains.
            operand.replace(inner)
    return expression
115
116
def simplify_connectors(expression, root=True):
    """Fold constant operands of AND/OR chains and merge redundant comparisons.

    Follows SQL three-valued logic:
        x AND x -> x, FALSE AND x -> FALSE, NULL AND NULL -> NULL,
        TRUE AND x -> x
        x OR x -> x, TRUE OR x -> TRUE, FALSE OR x -> x,
        NULL OR NULL -> NULL, NULL OR FALSE -> NULL

    Pairs that are not constants are handed to _simplify_comparison.
    Non-connector nodes are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x evaluates to FALSE when x is FALSE, so NULL AND
            # <non-constant> must not be folded to NULL. Only NULL AND NULL is
            # safe (FALSE was handled above, TRUE AND NULL is handled below).
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
153
154
# Upper-bound ("x < n", "x <= n") and lower-bound ("x > n", "x >= n") node types.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison type _simplify_comparison is able to merge.
COMPARISONS = LT_LTE + GT_GTE + (exp.EQ, exp.NEQ)

# Maps each ordering comparison to the type obtained by swapping its operands
# (a < b is b > a, a <= b is b >= a), in both directions.
INVERSE_COMPARISONS = {
    op: mirrored
    for lower, upper in zip(LT_LTE, GT_GTE)
    for op, mirrored in ((lower, upper), (upper, lower))
}
171
172
def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against literal bounds.

    ``left`` and ``right`` are operands of ``expression``, which is an AND, or an
    OR when ``or_`` is True. Both must be comparisons that share a column and
    whose other sides are both number literals or both string literals.
    Example: ``x > 1 AND x > 3 -> x > 3``.

    Returns:
        The single comparison (or FALSE) equivalent to the pair, or None when
        the pair cannot be merged. TRUE is never produced because the column
        could be NULL.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # One side holds only shared columns (e.g. a = b AND b = a), so
                # there is no literal bound to compare. Signal "no
                # simplification" with None. Returning the connector itself
                # would make _flat_simplify splice the whole chain back in as
                # one of its own operands.
                return None

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Same bound, but one comparison is strict and the other inclusive.
                mixed_same_bound = av == bv and type(a) is not type(b)

                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if mixed_same_bound:
                        # x < 5 AND x <= 5 -> x < 5;  x < 5 OR x <= 5 -> x <= 5
                        strict = a if isinstance(a, exp.LT) else b
                        inclusive = b if strict is a else a
                        return inclusive if or_ else strict
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if mixed_same_bound:
                        # x > 5 AND x >= 5 -> x > 5;  x > 5 OR x >= 5 -> x >= 5
                        strict = a if isinstance(a, exp.GT) else b
                        inclusive = b if strict is a else a
                        return inclusive if or_ else strict
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
232
233
def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    replacement = exp.false() if isinstance(expression, exp.And) else exp.true()
    ordered_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(x, y) for x, y in ordered_pairs):
        return replacement
    return expression
248
249
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their generated SQL text. Operands with equal text are
    duplicates: the last occurrence's node is kept, at the first occurrence's
    position. Ordering is by that text. The original node is returned when
    nothing changes.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    rebuild = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    unique = {generate(operand): operand for operand in operands}
    keyed = tuple(unique.items())
    sql_keys = [sql for sql, _ in keyed]

    if any(current < previous for previous, current in zip(sql_keys, sql_keys[1:])):
        # Keys are unique, so sorting the (sql, node) pairs never compares nodes.
        return rebuild(*(node for _, node in sorted(keyed)), copy=False)

    if len(unique) < len(operands):
        # Already in order; only duplicates need dropping.
        return rebuild(*unique.values(), copy=False)

    return expression
274
275
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place via ``replace``; the connector node itself is
    always returned. The TRUE/FALSE placeholders substituted here are constant
    operands that simplify_connectors knows how to fold away.
    """
    # NOTE(review): same_parent presumably means "parent is the same connector
    # type", i.e. this node is an inner link of an already-flattened chain.
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The opposite connector: OR operands inside an AND, AND operands inside an OR.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # permutations() snapshots the operands up front, so nodes replaced below
        # may still show up as a later `a` or `b`.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb: one of a's operands is NOT b (A AND (NOT A OR B)); that
                # operand becomes the identity of `kind` (TRUE inside AND,
                # FALSE inside OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # absorb: b's operands are a strict subset of a's (A AND (A OR B));
                # a becomes the identity of the outer connector (TRUE for an
                # outer AND, FALSE for an outer OR).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate: a and b share one operand and differ by a
                    # complemented one ((A AND B) OR (A AND NOT B)); both are
                    # replaced by the shared operand.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
314
315
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary nodes and negate numeric literals.

    Non-connector binary nodes go through _flat_simplify with _simplify_binary.
    A unary minus applied to a number literal becomes one number literal with
    the sign flipped. Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    return expression
328
329
def _simplify_binary(expression, a, b):
    """Try to fold ``a <op> b`` for a non-connector binary node.

    Handles:
    - IS [NOT] NULL on a literal or NULL;
    - NULL propagation;
    - arithmetic and comparisons between number literals;
    - comparisons between string literals;
    - DATE/DATETIME casts plus or minus an INTERVAL.

    Returns:
        The folded expression, or None when nothing can be simplified.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        # NULL-safe comparisons never yield NULL, so NULL propagation does not apply.
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises DivisionByZero (InvalidOperation
            # for 0 / 0), which would abort the whole simplification; leave the
            # expression for the engine to evaluate.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        # Compare with None: a zero-length relativedelta is falsy but still valid.
        if a is not None and b is not None:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a is not None and b is not None and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
388
389
def simplify_parens(expression):
    """Remove parentheses that do not affect evaluation order.

    Parentheses around a SELECT are always kept. Otherwise the wrapped node is
    returned in place of the Paren, unless all of these hold:
    - the parent is a Condition or Binary;
    - the inner node is a non-predicate Binary;
    - it is not the associative Add-in-Add or Mul-in-Mul case.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
    )
    return inner if removable else expression
406
407
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn always-true joins into CROSS joins.

    A join whose ON condition is always true becomes a CROSS join with no side.
    Joins that carry USING or a join method are left untouched. The tree is
    modified in place.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if (
            not always_true(join.args.get("on"))
            or join.args.get("using")
            or join.args.get("method")
        ):
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
421
422
def always_true(expression):
    """Return True when ``expression`` is a constant that SQL treats as true.

    TRUE qualifies, and so does any literal except a numeric zero. Zero is
    excluded because, in dialects that allow numeric predicates, ``NOT 0`` is
    TRUE and ``WHERE 0`` keeps no rows; folding 0 to TRUE would change results.
    String literals are still considered true, as before. ``None`` and every
    other node return False.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation on unparseable number text: keep the
            # previous behavior of treating the literal as true.
            return True
    return True
427
428
def is_complement(a, b):
    """Return True when ``b`` is a NOT node wrapping an expression equal to ``a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
431
432
def is_false(a: exp.Expression) -> bool:
    """Return True only for the boolean literal FALSE (exact type match, not NULL or 0)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
435
436
def is_null(a: exp.Expression) -> bool:
    """Return True only when ``a`` is exactly an ``exp.Null`` node (subclasses excluded)."""
    return a.__class__ is exp.Null
439
440
def eval_boolean(expression, a, b):
    """Evaluate a comparison node on two already-extracted Python values.

    Returns the TRUE/FALSE literal for =/IS, <>, >, >=, < and <=, or None for
    any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
455
456
def extract_date(cast):
    """Parse ``CAST('<iso text>' AS DATE|DATETIME)`` into a Python date/datetime.

    Returns None for other target types, or when the cast operand is not ISO
    text (for example when the cast wraps an identifier).
    """
    target = cast.args["to"].this
    if target == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
467
468
def extract_interval(interval):
    """Convert an ``INTERVAL <n> <unit>`` node into a dateutil ``relativedelta``.

    Supported units are year, month, week and day, case-insensitive and
    singular or plural.

    Returns:
        The relativedelta, or None in any of these cases:
        - python-dateutil is not installed;
        - the amount is not an integer (e.g. ``INTERVAL x DAY`` or
          ``INTERVAL '1.5' DAY``);
        - the unit is unsupported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except (TypeError, ValueError):
        # Non-literal or fractional amounts cannot be folded at compile time.
        return None

    unit = interval.text("unit").lower()
    if unit.endswith("s"):
        unit = unit[:-1]

    keyword = {"year": "years", "month": "months", "week": "weeks", "day": "days"}.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
487
488
def date_literal(date):
    """Build ``CAST('<date>' AS DATE)``, or ``AS DATETIME`` for datetime values."""
    # datetime.datetime subclasses datetime.date, so it must be checked explicitly.
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
494
495
def boolean_literal(condition):
    """Map a Python truth value to the SQL literal TRUE or FALSE."""
    factory = exp.true if condition else exp.false
    return factory()
498
499
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of a binary chain.

    ``simplifier(expression, a, b)`` is tried on pairs of operands. A truthy
    result replaces both operands and is retried against the remaining ones.
    If any pair merged, the chain is rebuilt left-deep using ``expression``'s
    class.

    Args:
        expression: the binary node heading the chain.
        simplifier: callable returning a merged node, or a falsy value for
            "no change".
        root: whether ``expression`` is the root being simplified.

    Returns:
        The rebuilt chain (possibly a single operand), or ``expression`` itself
        when nothing merged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                # A simplifier handing back the chain itself means "nothing to
                # do". Treating it as a merged operand would nest the chain
                # inside itself.
                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # One SQL generator shared by every pass; uniq_sort uses it to compare
    # operands by their SQL text.
    generate = cached_generator()

    def _simplify(expression, root=True):
        # Nodes flagged as final are returned untouched.
        if expression.meta.get("final"):
            return expression

        # Passes applied before the children are visited.
        current = rewrite_between(expression)
        current = uniq_sort(current, generate, root)
        current = absorb_and_eliminate(current, root)

        # Recurse into every child; only the outermost call runs with root=True.
        exp.replace_children(current, lambda child: _simplify(child, False))

        # Passes applied once the children have been simplified.
        current = simplify_not(current)
        current = flatten(current)
        current = simplify_connectors(current, root)
        current = remove_compliments(current, root)
        # The passes above may hand back a brand-new node; reattach it to the
        # original parent so parent-aware passes (e.g. simplify_parens) see it.
        current.parent = expression.parent
        current = simplify_literals(current, root)
        current = simplify_parens(current)

        if root:
            expression.replace(current)
        return current

    # while_changing presumably re-runs _simplify until the tree stops changing.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:
  • sqlglot.Expression: simplified expression

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only understands <, <=, > and >=, so BETWEEN is
    expanded into that pair of comparisons. Any other node is returned as-is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    # Each comparison gets its own copy of the tested expression.
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Simplify a NOT node.

    NOT NULL -> NULL
    NOT (x AND y) -> NOT x OR NOT y      (De Morgan)
    NOT (x OR y) -> NOT x AND NOT y      (De Morgan)
    NOT <always-true literal> -> FALSE
    NOT FALSE -> TRUE
    NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: swap the connector and negate both sides.
            swapped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return swapped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
    """
    Inline directly nested connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            # Swap the (possibly parenthesized) operand for the unwrapped
            # same-type connector it contains.
            operand.replace(inner)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """Fold constant operands of AND/OR chains and merge redundant comparisons.

    Follows SQL three-valued logic:
        x AND x -> x, FALSE AND x -> FALSE, NULL AND NULL -> NULL,
        TRUE AND x -> x
        x OR x -> x, TRUE OR x -> TRUE, FALSE OR x -> x,
        NULL OR NULL -> NULL, NULL OR FALSE -> NULL

    Pairs that are not constants are handed to _simplify_comparison.
    Non-connector nodes are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x evaluates to FALSE when x is FALSE, so NULL AND
            # <non-constant> must not be folded to NULL. Only NULL AND NULL is
            # safe (FALSE was handled above, TRUE AND NULL is handled below).
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
INVERSE_COMPARISONS = {<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    replacement = exp.false() if isinstance(expression, exp.And) else exp.true()
    ordered_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(x, y) for x, y in ordered_pairs):
        return replacement
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their generated SQL text. Operands with equal text are
    duplicates: the last occurrence's node is kept, at the first occurrence's
    position. Ordering is by that text. The original node is returned when
    nothing changes.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    rebuild = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    unique = {generate(operand): operand for operand in operands}
    keyed = tuple(unique.items())
    sql_keys = [sql for sql, _ in keyed]

    if any(current < previous for previous, current in zip(sql_keys, sql_keys[1:])):
        # Keys are unique, so sorting the (sql, node) pairs never compares nodes.
        return rebuild(*(node for _, node in sorted(keyed)), copy=False)

    if len(unique) < len(operands):
        # Already in order; only duplicates need dropping.
        return rebuild(*unique.values(), copy=False)

    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place via ``replace``; the connector node itself is
    always returned. The TRUE/FALSE placeholders substituted here are constant
    operands that simplify_connectors knows how to fold away.
    """
    # NOTE(review): same_parent presumably means "parent is the same connector
    # type", i.e. this node is an inner link of an already-flattened chain.
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The opposite connector: OR operands inside an AND, AND operands inside an OR.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # permutations() snapshots the operands up front, so nodes replaced below
        # may still show up as a later `a` or `b`.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb: one of a's operands is NOT b (A AND (NOT A OR B)); that
                # operand becomes the identity of `kind` (TRUE inside AND,
                # FALSE inside OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # absorb: b's operands are a strict subset of a's (A AND (A OR B));
                # a becomes the identity of the outer connector (TRUE for an
                # outer AND, FALSE for an outer OR).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate: a and b share one operand and differ by a
                    # complemented one ((A AND B) OR (A AND NOT B)); both are
                    # replaced by the shared operand.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A

def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary nodes and negate numeric literals.

    Non-connector binary nodes go through _flat_simplify with _simplify_binary.
    A unary minus applied to a number literal becomes one number literal with
    the sign flipped. Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    return expression
def simplify_parens(expression):
    """Remove parentheses that do not affect evaluation order.

    Parentheses around a SELECT are always kept. Otherwise the wrapped node is
    returned in place of the Paren, unless all of these hold:
    - the parent is a Condition or Binary;
    - the inner node is a non-predicate Binary;
    - it is not the associative Add-in-Add or Mul-in-Mul case.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
    )
    return inner if removable else expression
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn always-true joins into CROSS joins.

    A join whose ON condition is always true becomes a CROSS join with no side.
    Joins that carry USING or a join method are left untouched. The tree is
    modified in place.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if (
            not always_true(join.args.get("on"))
            or join.args.get("using")
            or join.args.get("method")
        ):
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """Return True when ``expression`` is a constant that SQL treats as true.

    TRUE qualifies, and so does any literal except a numeric zero. Zero is
    excluded because, in dialects that allow numeric predicates, ``NOT 0`` is
    TRUE and ``WHERE 0`` keeps no rows; folding 0 to TRUE would change results.
    String literals are still considered true, as before. ``None`` and every
    other node return False.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation on unparseable number text: keep the
            # previous behavior of treating the literal as true.
            return True
    return True
def is_complement(a, b):
    """Return True when ``b`` is a NOT node wrapping an expression equal to ``a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True only for the boolean literal FALSE (exact type match, not NULL or 0)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True only when ``a`` is exactly an ``exp.Null`` node (subclasses excluded)."""
    return a.__class__ is exp.Null
def eval_boolean(expression, a, b):
    """Evaluate a comparison node on two already-extracted Python values.

    Returns the TRUE/FALSE literal for =/IS, <>, >, >=, < and <=, or None for
    any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
    """Parse ``CAST('<iso text>' AS DATE|DATETIME)`` into a Python date/datetime.

    Returns None for other target types, or when the cast operand is not ISO
    text (for example when the cast wraps an identifier).
    """
    target = cast.args["to"].this
    if target == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
def extract_interval(interval):
    """Convert an ``INTERVAL <n> <unit>`` node into a dateutil ``relativedelta``.

    Supported units are year, month, week and day, case-insensitive and
    singular or plural.

    Returns:
        The relativedelta, or None in any of these cases:
        - python-dateutil is not installed;
        - the amount is not an integer (e.g. ``INTERVAL x DAY`` or
          ``INTERVAL '1.5' DAY``);
        - the unit is unsupported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except (TypeError, ValueError):
        # Non-literal or fractional amounts cannot be folded at compile time.
        return None

    unit = interval.text("unit").lower()
    if unit.endswith("s"):
        unit = unit[:-1]

    keyword = {"year": "years", "month": "months", "week": "weeks", "day": "days"}.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
    """Build ``CAST('<date>' AS DATE)``, or ``AS DATETIME`` for datetime values."""
    # datetime.datetime subclasses datetime.date, so it must be checked explicitly.
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def boolean_literal(condition):
    """Map a Python truth value to the SQL literal TRUE or FALSE."""
    factory = exp.true if condition else exp.false
    return factory()