Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
 11
 12def simplify(expression):
 13    """
 14    Rewrite sqlglot AST to simplify expressions.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 19        >>> simplify(expression).sql()
 20        'TRUE'
 21
 22    Args:
 23        expression (sqlglot.Expression): expression to simplify
 24    Returns:
 25        sqlglot.Expression: simplified expression
 26    """
 27
 28    generate = cached_generator()
 29
 30    def _simplify(expression, root=True):
 31        if expression.meta.get("final"):
 32            return expression
 33        node = expression
 34        node = rewrite_between(node)
 35        node = uniq_sort(node, generate, root)
 36        node = absorb_and_eliminate(node, root)
 37        exp.replace_children(node, lambda e: _simplify(e, False))
 38        node = simplify_not(node)
 39        node = flatten(node)
 40        node = simplify_connectors(node, root)
 41        node = remove_compliments(node, root)
 42        node.parent = expression.parent
 43        node = simplify_literals(node, root)
 44        node = simplify_parens(node)
 45        if root:
 46            expression.replace(node)
 47        return node
 48
 49    expression = while_changing(expression, _simplify)
 50    remove_where_true(expression)
 51    return expression
 52
 53
 54def rewrite_between(expression: exp.Expression) -> exp.Expression:
 55    """Rewrite x between y and z to x >= y AND x <= z.
 56
 57    This is done because comparison simplification is only done on lt/lte/gt/gte.
 58    """
 59    if isinstance(expression, exp.Between):
 60        return exp.and_(
 61            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 62            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 63            copy=False,
 64        )
 65    return expression
 66
 67
 68def simplify_not(expression):
 69    """
 70    Demorgan's Law
 71    NOT (x OR y) -> NOT x AND NOT y
 72    NOT (x AND y) -> NOT x OR NOT y
 73    """
 74    if isinstance(expression, exp.Not):
 75        if is_null(expression.this):
 76            return exp.null()
 77        if isinstance(expression.this, exp.Paren):
 78            condition = expression.this.unnest()
 79            if isinstance(condition, exp.And):
 80                return exp.or_(
 81                    exp.not_(condition.left, copy=False),
 82                    exp.not_(condition.right, copy=False),
 83                    copy=False,
 84                )
 85            if isinstance(condition, exp.Or):
 86                return exp.and_(
 87                    exp.not_(condition.left, copy=False),
 88                    exp.not_(condition.right, copy=False),
 89                    copy=False,
 90                )
 91            if is_null(condition):
 92                return exp.null()
 93        if always_true(expression.this):
 94            return exp.false()
 95        if is_false(expression.this):
 96            return exp.true()
 97        if isinstance(expression.this, exp.Not):
 98            # double negation
 99            # NOT NOT x -> x
100            return expression.this.this
101    return expression
102
103
def flatten(expression):
    """
    Inline (possibly parenthesized) child connectors of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
115
116
def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND / OR chains and merge compatible comparisons.

    AND: any FALSE operand gives FALSE; NULL AND NULL gives NULL; always-true
    operands are dropped.
    OR: any always-true operand gives TRUE; FALSE operands are dropped; NULL OR
    NULL and NULL OR FALSE give NULL.
    Remaining operand pairs are passed to _simplify_comparison
    (e.g. x > 1 AND x > 2 -> x > 2).

    ``NULL AND x`` is deliberately *not* folded to NULL. When x is FALSE the
    result is FALSE, so the rewrite is only valid when both sides are NULL.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND FALSE is FALSE (handled above), and NULL AND x depends on x,
            # so only a pair of NULLs can be folded to NULL.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
153
154
# Ordering comparisons grouped by direction; used by _simplify_comparison.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison _simplify_comparison knows how to merge.
COMPARISONS = (*LT_LTE, *GT_GTE, exp.EQ, exp.NEQ)

# Mirror image of each ordering comparison: (a < b) is (b > a), (a <= b) is (b >= a).
INVERSE_COMPARISONS = dict(zip(LT_LTE + GT_GTE, GT_GTE + LT_LTE))
171
172
def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against literals.

    Called for each operand pair of an AND (``or_=False``) or an OR (``or_=True``).
    Examples: ``x > 1 AND x > 2 -> x > 2``, ``x > 1 OR x > 2 -> x > 1`` and
    ``x = 1 AND x < 1 -> FALSE``.

    Returns:
        The expression that replaces the pair (one of the comparisons, or FALSE),
        or None when the pair cannot be merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # Every operand is a shared column (e.g. ``a = b AND a < b``), so
                # there is no literal to compare. Returning the connector itself
                # here would make the caller splice it into its own operand list.
                return None

            # Make sure the comparison is always of the form x > 1 instead of
            # 1 < x. Build the flipped node from copies so the original operands
            # keep their parent links when no simplification happens.
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](
                    this=lr.copy(), expression=ll.copy()
                )
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](
                    this=rr.copy(), expression=rl.copy()
                )

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
232
233
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    (The function name keeps its historical spelling; it removes complements.)
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    complement = exp.false() if isinstance(expression, exp.And) else exp.true()
    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(x, y) for x, y in pairs):
        return complement
    return expression
248
249
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when the operands are out of order or contain
    duplicates; otherwise the original expression is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by SQL text: a later duplicate replaces the value but keeps the key's position.
    by_sql = {generate(operand): operand for operand in operands}
    entries = tuple(by_sql.items())

    keys = [key for key, _ in entries]
    already_sorted = all(prev <= cur for prev, cur in zip(keys, keys[1:]))

    if not already_sorted:
        return build(*(operand for _, operand in sorted(entries)), copy=False)
    if len(by_sql) < len(operands):
        # Already in order, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression
274
275
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place through ``replace``. The same (possibly
    mutated) ``expression`` object is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # The connector we look for among the operands: OR inside an AND, AND inside an OR.
    inner = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = inner == exp.And

    def _inner_identity():
        # Neutral element of the inner connector: substituting it removes one of its operands.
        return exp.true() if inner_is_and else exp.false()

    def _outer_identity():
        # Neutral element of the outer connector: substituting it drops a whole operand.
        return exp.false() if inner_is_and else exp.true()

    for candidate, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(candidate, inner):
            continue

        first_op, second_op = candidate.unnest_operands()

        # absorption
        if is_complement(other, first_op):
            first_op.replace(_inner_identity())
        elif is_complement(other, second_op):
            second_op.replace(_inner_identity())
        elif (
            set(other.flatten()) if isinstance(other, inner) else {other}
        ) < set(candidate.flatten()):
            candidate.replace(_outer_identity())
        elif isinstance(other, inner):
            # elimination
            other_ops = other.unnest_operands()
            other_first, other_second = other_ops

            if first_op in other_ops and (
                is_complement(second_op, other_first) or is_complement(second_op, other_second)
            ):
                candidate.replace(first_op)
                other.replace(first_op)
            elif second_op in other_ops and (
                is_complement(first_op, other_first) or is_complement(first_op, other_second)
            ):
                candidate.replace(second_op)
                other.replace(second_op)

    return expression
314
315
def simplify_literals(expression, root=True):
    """Fold constant sub-expressions.

    Binary operators other than AND/OR have their operands folded pairwise by
    _simplify_binary (through _flat_simplify). A negated numeric literal becomes a
    literal with the sign flipped: ``-(5)`` -> ``-5`` and ``-(-5)`` -> ``5``.
    """
    if isinstance(expression, exp.Binary):
        if isinstance(expression, exp.Connector):
            return expression
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
328
329
def _simplify_binary(expression, a, b):
    """Fold a binary operation whose two operands are constants.

    Handles:
    - IS [NOT] NULL applied to a literal or to NULL;
    - NULL propagation for ordinary operators (null-safe comparisons are skipped);
    - arithmetic and comparisons on numeric literals;
    - comparisons on string literals;
    - DATE/DATETIME casts plus or minus an interval.

    Returns:
        The folded expression, or None when the operands cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Division by zero is an engine-specific error (or NULL). Leave it to
            # the engine instead of raising ZeroDivisionError from the optimizer.
            if not b:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
388
389
def simplify_parens(expression):
    """Remove parentheses that are not needed for how the expression is read.

    Subquery parentheses are always kept. Otherwise the parentheses are dropped
    in any of these cases:
    - the parent is neither a condition nor a binary operator;
    - the inner expression is a predicate or is not a binary operator;
    - the grouping is associative (``+`` inside ``+``, ``*`` inside ``*``);
    - ``*`` sits inside ``+`` or ``-``.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    outer = expression.parent
    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
407
408
def remove_where_true(expression):
    """Drop WHERE clauses and inner-join conditions that are always true.

    A WHERE whose condition is always true is removed. A join whose ON
    condition is always true, and which has no USING or join method, becomes a
    CROSS JOIN, but only for inner joins. Outer and semi/anti joins are left
    alone because they are not equivalent to a CROSS JOIN. For example,
    ``LEFT JOIN t ON TRUE`` (often ``LEFT JOIN LATERAL ... ON TRUE``) still
    returns the unmatched left rows when ``t`` is empty.

    Args:
        expression: root expression; it is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            # only inner joins are equivalent to a cross join when ON is TRUE
            and not join.text("side")
            and join.text("kind").upper() in ("", "INNER", "CROSS")
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")
423
def always_true(expression):
    """Return True when ``expression`` is a constant that is always true.

    Only the boolean literal TRUE and numeric literals other than zero qualify.
    Zero (``0``, ``0.0``) is false in SQL. The truthiness of string literals is
    dialect-dependent (MySQL, for instance, coerces ``'abc'`` to 0), so strings
    are never treated as always true. Anything else, including ``None`` (e.g. a
    join without an ON clause), returns False.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal) and expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # Not a plain decimal number (e.g. hex); don't assume anything.
            return False
    return False
428
429
def is_complement(a, b):
    """Return True when ``b`` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
432
433
def is_false(a: exp.Expression) -> bool:
    """Return True only for the boolean literal FALSE (an exact exp.Boolean)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
436
437
def is_null(a: exp.Expression) -> bool:
    """Return True only for the NULL literal (an exact exp.Null, not a subclass)."""
    node_type = type(a)
    return node_type is exp.Null
440
441
def eval_boolean(expression, a, b):
    """Evaluate a comparison of two Python values and return it as TRUE/FALSE.

    Supports =, IS, <>, >, >=, < and <=. Returns None for any other operator.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
456
457
def extract_date(cast):
    """Return the Python date/datetime denoted by a CAST to DATE or DATETIME.

    Returns None when the cast target is any other type. It also returns None
    when ``fromisoformat`` rejects the casted value, e.g. because the cast wraps
    an identifier rather than an ISO-formatted string.
    """
    target = cast.args["to"].this
    try:
        if target == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
468
469
def extract_interval(interval):
    """Convert an ``INTERVAL <n> <unit>`` expression into a ``relativedelta``.

    Supported units are YEAR, MONTH, WEEK and DAY (case-insensitive). Returns
    None in any of these cases:
    - ``python-dateutil`` is not installed;
    - the amount is not an integer (e.g. ``INTERVAL x DAY`` or ``INTERVAL '1 day'``);
    - the unit is not supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Not a plain integer amount: leave the expression unsimplified
        # instead of crashing the optimizer.
        return None

    unit = interval.text("unit").lower()
    keyword = {"year": "years", "month": "months", "week": "weeks", "day": "days"}.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
488
489
def date_literal(date):
    """Build ``CAST('<iso text>' AS DATE)``, or ``AS DATETIME`` for a datetime."""
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), type_name)
495
496
def boolean_literal(condition):
    """Map a Python truth value to the SQL literal TRUE or FALSE."""
    if condition:
        return exp.true()
    return exp.false()
499
500
def _flat_simplify(expression, simplifier, root=True):
    """Simplify the flattened operand list of a binary chain pairwise.

    ``simplifier(expression, a, b)`` is tried on pairs of operands. A truthy
    result, other than ``expression`` itself, replaces both operands and is
    retried against the rest. When anything merged, a new left-deep chain of the
    same class is built from the remaining operands. Otherwise ``expression`` is
    returned unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                # A simplifier returning the whole chain means "nothing to merge".
                # Treating it as a merge would nest the chain inside itself.
                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The tree is simplified repeatedly (via ``while_changing``) until a pass no
    longer changes it. Afterwards, always-true WHERE / join conditions are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    def _simplify(node, root=True):
        # Nodes whose metadata marks them "final" are returned untouched.
        if node.meta.get("final"):
            return node

        # Passes that run before the children are visited.
        rewritten = absorb_and_eliminate(
            uniq_sort(rewrite_between(node), generate, root), root
        )

        # Recurse into every child; only the outermost call runs with root=True.
        exp.replace_children(rewritten, lambda child: _simplify(child, False))

        # Passes that run on the already-simplified children.
        rewritten = remove_compliments(
            simplify_connectors(flatten(simplify_not(rewritten)), root), root
        )
        # Re-attach to the original parent; simplify_parens inspects `.parent`.
        rewritten.parent = node.parent
        rewritten = simplify_parens(simplify_literals(rewritten, root))

        if root:
            node.replace(rewritten)
        return rewritten

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand ``x BETWEEN low AND high`` into ``x >= low AND x <= high``.

    Comparison simplification only understands <, <=, > and >=, so BETWEEN is
    lowered into that form first. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite `x BETWEEN y AND z` to `x >= y AND x <= z`.

This is done because comparison simplification only handles <, <=, > and >=.

def simplify_not(expression):
    """
    Push NOT inwards and fold NOT over constants.

    De Morgan's laws (applied when the operand is parenthesized):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Constant folding:
        NOT NULL -> NULL
        NOT <always-true constant> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # Flip the connector and negate both sides.
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return flipped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
    """
    Inline (possibly parenthesized) child connectors of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND / OR chains and merge compatible comparisons.

    AND: any FALSE operand gives FALSE; NULL AND NULL gives NULL; always-true
    operands are dropped.
    OR: any always-true operand gives TRUE; FALSE operands are dropped; NULL OR
    NULL and NULL OR FALSE give NULL.
    Remaining operand pairs are passed to _simplify_comparison
    (e.g. x > 1 AND x > 2 -> x > 2).

    ``NULL AND x`` is deliberately *not* folded to NULL. When x is FALSE the
    result is FALSE, so the rewrite is only valid when both sides are NULL.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND FALSE is FALSE (handled above), and NULL AND x depends on x,
            # so only a pair of NULLs can be folded to NULL.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
# Ordering comparisons grouped by direction.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)
# Mirror image of each ordering comparison: (a < b) is (b > a), (a <= b) is (b >= a).
INVERSE_COMPARISONS = {exp.LT: exp.GT, exp.GT: exp.LT, exp.LTE: exp.GTE, exp.GTE: exp.LTE}
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    (The function name keeps its historical spelling; it removes complements.)
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    complement = exp.false() if isinstance(expression, exp.And) else exp.true()
    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(x, y) for x, y in pairs):
        return complement
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when the operands are out of order or contain
    duplicates; otherwise the original expression is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by SQL text: a later duplicate replaces the value but keeps the key's position.
    by_sql = {generate(operand): operand for operand in operands}
    entries = tuple(by_sql.items())

    keys = [key for key, _ in entries]
    already_sorted = all(prev <= cur for prev, cur in zip(keys, keys[1:]))

    if not already_sorted:
        return build(*(operand for _, operand in sorted(entries)), copy=False)
    if len(by_sql) < len(operands):
        # Already in order, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place through ``replace``. The same (possibly
    mutated) ``expression`` object is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # The connector we look for among the operands: OR inside an AND, AND inside an OR.
    inner = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = inner == exp.And

    def _inner_identity():
        # Neutral element of the inner connector: substituting it removes one of its operands.
        return exp.true() if inner_is_and else exp.false()

    def _outer_identity():
        # Neutral element of the outer connector: substituting it drops a whole operand.
        return exp.false() if inner_is_and else exp.true()

    for candidate, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(candidate, inner):
            continue

        first_op, second_op = candidate.unnest_operands()

        # absorption
        if is_complement(other, first_op):
            first_op.replace(_inner_identity())
        elif is_complement(other, second_op):
            second_op.replace(_inner_identity())
        elif (
            set(other.flatten()) if isinstance(other, inner) else {other}
        ) < set(candidate.flatten()):
            candidate.replace(_outer_identity())
        elif isinstance(other, inner):
            # elimination
            other_ops = other.unnest_operands()
            other_first, other_second = other_ops

            if first_op in other_ops and (
                is_complement(second_op, other_first) or is_complement(second_op, other_second)
            ):
                candidate.replace(first_op)
                other.replace(first_op)
            elif second_op in other_ops and (
                is_complement(first_op, other_first) or is_complement(first_op, other_second)
            ):
                candidate.replace(second_op)
                other.replace(second_op)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_literals(expression, root=True):
    """Fold constant sub-expressions.

    Binary operators other than AND/OR have their operands folded pairwise by
    _simplify_binary (through _flat_simplify). A negated numeric literal becomes a
    literal with the sign flipped: ``-(5)`` -> ``-5`` and ``-(-5)`` -> ``5``.
    """
    if isinstance(expression, exp.Binary):
        if isinstance(expression, exp.Connector):
            return expression
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
    """Remove parentheses that are not needed for how the expression is read.

    Subquery parentheses are always kept. Otherwise the parentheses are dropped
    in any of these cases:
    - the parent is neither a condition nor a binary operator;
    - the inner expression is a predicate or is not a binary operator;
    - the grouping is associative (``+`` inside ``+``, ``*`` inside ``*``);
    - ``*`` sits inside ``+`` or ``-``.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    outer = expression.parent
    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def remove_where_true(expression):
    """Drop WHERE clauses and inner-join conditions that are always true.

    A WHERE whose condition is always true is removed. A join whose ON
    condition is always true, and which has no USING or join method, becomes a
    CROSS JOIN, but only for inner joins. Outer and semi/anti joins are left
    alone because they are not equivalent to a CROSS JOIN. For example,
    ``LEFT JOIN t ON TRUE`` (often ``LEFT JOIN LATERAL ... ON TRUE``) still
    returns the unmatched left rows when ``t`` is empty.

    Args:
        expression: root expression; it is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            # only inner joins are equivalent to a cross join when ON is TRUE
            and not join.text("side")
            and join.text("kind").upper() in ("", "INNER", "CROSS")
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")
def always_true(expression):
    """Return True when ``expression`` is a constant that is always true.

    Only the boolean literal TRUE and numeric literals other than zero qualify.
    Zero (``0``, ``0.0``) is false in SQL. The truthiness of string literals is
    dialect-dependent (MySQL, for instance, coerces ``'abc'`` to 0), so strings
    are never treated as always true. Anything else, including ``None`` (e.g. a
    join without an ON clause), returns False.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal) and expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # Not a plain decimal number (e.g. hex); don't assume anything.
            return False
    return False
def is_complement(a, b):
    """Return True when ``b`` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True only for the boolean literal FALSE (an exact exp.Boolean)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True only for the NULL literal (an exact exp.Null, not a subclass)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """Evaluate a comparison of two Python values and return it as TRUE/FALSE.

    Supports =, IS, <>, >, >=, < and <=. Returns None for any other operator.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
    """Return the Python date/datetime denoted by a CAST to DATE or DATETIME.

    Returns None when the cast target is any other type. It also returns None
    when ``fromisoformat`` rejects the casted value, e.g. because the cast wraps
    an identifier rather than an ISO-formatted string.
    """
    target = cast.args["to"].this
    try:
        if target == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
def extract_interval(interval):
    """Convert an ``INTERVAL <n> <unit>`` expression into a ``relativedelta``.

    Supported units are YEAR, MONTH, WEEK and DAY (case-insensitive). Returns
    None in any of these cases:
    - ``python-dateutil`` is not installed;
    - the amount is not an integer (e.g. ``INTERVAL x DAY`` or ``INTERVAL '1 day'``);
    - the unit is not supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Not a plain integer amount: leave the expression unsimplified
        # instead of crashing the optimizer.
        return None

    unit = interval.text("unit").lower()
    keyword = {"year": "years", "month": "months", "week": "weeks", "day": "days"}.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
    """Build ``CAST('<iso text>' AS DATE)``, or ``AS DATETIME`` for a datetime."""
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), type_name)
def boolean_literal(condition):
    """Map a Python truth value to the SQL literal TRUE or FALSE."""
    if condition:
        return exp.true()
    return exp.false()