Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4import typing as t
  5from collections import deque
  6from decimal import Decimal
  7
  8from sqlglot import exp
  9from sqlglot.generator import cached_generator
 10from sqlglot.helper import first, merge_ranges, while_changing
 11
 12# Final means that an expression should not be simplified
 13FINAL = "final"
 14
 15
class UnsupportedUnit(Exception):
    """Raised by `interval` and `date_floor` when given a unit they do not handle."""

    pass
 18
 19
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # SQL generator handed to uniq_sort, which uses the generated text as a sort/dedup key
    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # freeze every projection that contains one of the group by keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # the HAVING clause is frozen as a whole under the same condition
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # nodes marked FINAL above are returned untouched, children included
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # recurse into the children; they are simplified with root=False
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # the passes above may have built a new node; give it the original's parent
        # before running the remaining passes (simplify_parens reads `parent`)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # only the top-level call swaps the result into the tree itself
        if root:
            expression.replace(node)

        return node

    # NOTE(review): presumably while_changing re-applies _simplify until the
    # expression stops changing -- confirm against sqlglot.helper
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 92
 93
 94def catch(*exceptions):
 95    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 96
 97    def decorator(func):
 98        def wrapped(expression, *args, **kwargs):
 99            try:
100                return func(expression, *args, **kwargs)
101            except exceptions:
102                return expression
103
104        return wrapped
105
106    return decorator
107
108
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    Comparison simplification is only done on lt/lte/gt/gte, so BETWEEN is
    rewritten into that form. Any other expression is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # the subject appears on both sides, so each comparison gets its own copy
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
121
122
def simplify_not(expression):
    """
    Simplify a negation.

    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE
    and NOT NOT x -> x. Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # the connector flips and each side is negated
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
157
158
def flatten(expression):
    """
    Drop parentheses around an operand that is the same connector as its parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
170
171
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR operands that are boolean constants, NULL or duplicates.

    Pairs that cannot be folded this way are handed to `_simplify_comparison`.
    Expressions that are not connectors are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    def _reduce_pair(connector, lhs, rhs):
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()

            lhs_true = always_true(lhs)
            rhs_true = always_true(rhs)
            if lhs_true and rhs_true:
                return exp.true()
            if lhs_true:
                return rhs
            if rhs_true:
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_false = is_false(lhs)
            rhs_false = is_false(rhs)
            if lhs_false and rhs_false:
                return exp.false()

            lhs_null = is_null(lhs)
            rhs_null = is_null(rhs)
            if (lhs_null and (rhs_null or rhs_false)) or (lhs_false and rhs_null):
                return exp.null()

            if lhs_false:
                return rhs
            if rhs_false:
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        # any other connector type is left alone
        return None

    return _flat_simplify(expression, _reduce_pair, root)
208
209
# "less than" and "greater than" families, used by _simplify_comparison to
# detect two bounds pointing in the same or in opposite directions
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# binary predicates the comparison/equality/coalesce simplifications operate on
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# the comparison to use after swapping the two operands, e.g. 1 < x -> x > 1
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
227
228
def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons of the same column against literals.

    Args:
        expression: the connector the two operands belong to; returned as-is when
            an operand is the shared column on both of its sides.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the operands are joined by OR, False for AND.
    Returns:
        The replacement for the pair (one of the operands or a FALSE literal),
        or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # operands that appear in both comparisons; only shared columns are of interest
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                # the "other" side of each comparison, i.e. what the column is compared to
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # only number/number and string/string pairs can be ordered
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # try both orderings so every rule below only needs to be written once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # two bounds in the same direction: AND keeps the tighter, OR the looser.
                # NOTE(review): with equal values and mixed strictness (x <= 1 AND x < 1)
                # the first operand is returned -- confirm this is intended
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # bounds that cannot both hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # an equality either contradicts the other comparison or subsumes it
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
288
289
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
304
305
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # keyed by SQL text, so textually identical operands collapse into one entry
    by_sql = {generate(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(cur[0] < prev[0] for prev, cur in zip(pairs, pairs[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(pairs)), copy=False)

    # already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
330
331
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing operands with boolean literals
    or with a shared operand; the same `expression` object is returned and later
    passes are relied on to fold the literals away.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # the opposite connector: inside an AND we look for OR operands and vice versa
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # one side of `a` is NOT b: replace it with the neutral element of `a`
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's: `a` adds nothing, so it is
                # replaced with the neutral element of the outer connector
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and disagree only on a complemented one:
                    # both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
370
371
# date arithmetic functions mapped to the binary operator that undoes them;
# simplify_equality uses this to move the interval to the other side of a comparison
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# every operation simplify_equality can invert: the date functions above plus + and -
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
384
385
def _is_number(expression: exp.Expression) -> bool:
    """Operand predicate used by simplify_equality: True for numeric expressions."""
    return expression.is_number
388
389
def _is_date(expression: exp.Expression) -> bool:
    """Operand predicate used by simplify_equality: True for a cast `extract_date` can read."""
    if not isinstance(expression, exp.Cast):
        return False
    return extract_date(expression) is not None
392
393
def _is_interval(expression: exp.Expression) -> bool:
    """Operand predicate used by simplify_equality: True for an interval `extract_interval` can read."""
    if not isinstance(expression, exp.Interval):
        return False
    return extract_interval(expression) is not None
396
397
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    When the arithmetic is on the right-hand side (`3 < x + 1`) the operands are
    mirrored, so inequalities are mirrored with them (`x + 1 > 3`, then `x > 2`).
    The constant may be on either side of a `+`, but for subtraction and the date
    functions it must be the second operand: `1 - x = 3` is left untouched because
    moving the constant across would require negating `x`.

    Args:
        expression: any expression; only members of COMPARISONS are rewritten.
    Returns:
        The rewritten comparison, or `expression` itself when no rule applies.
    """
    if not isinstance(expression, COMPARISONS):
        return expression

    l, r = expression.left, expression.right
    comparison = expression.__class__

    if l.__class__ in INVERSE_OPS:
        pass
    elif r.__class__ in INVERSE_OPS:
        l, r = r, l
        # mirroring the operands flips an inequality: 3 < x + 1 is x + 1 > 3
        comparison = INVERSE_COMPARISONS.get(comparison, comparison)
    else:
        return expression

    if r.is_number:
        a_predicate = _is_number
        b_predicate = _is_number
    elif _is_date(r):
        a_predicate = _is_date
        b_predicate = _is_interval
    else:
        return expression

    if l.__class__ in INVERSE_DATE_OPS:
        a = l.this
        b = exp.Interval(
            this=l.expression.copy(),
            unit=l.unit.copy(),
        )
    else:
        a, b = l.left, l.right

    if not a_predicate(a) and b_predicate(b):
        pass
    elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
        # only addition is commutative, so only there may the constant lead
        a, b = b, a
    else:
        return expression

    return comparison(this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b))
451
452
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Non-connector binary expressions are reduced pairwise with `_simplify_binary`;
    a negated numeric literal is folded into a single literal (`- 1` -> `-1`,
    `- -1` -> `1`). Everything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # a leading minus cancels out, otherwise one is added
    negated = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(negated)
466
467
def _simplify_binary(expression, a, b):
    """
    Evaluate a binary expression whose operands `a` and `b` are constants.

    Handles IS [NOT] NULL on literals/NULL, NULL propagation, arithmetic and
    comparisons on numeric literals, comparisons on string literals and
    date +/- interval arithmetic.

    Args:
        expression: the binary expression the operands belong to.
        a: left operand.
        b: right operand.
    Returns:
        The folded expression, or None when the pair cannot be simplified.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            # a literal is never NULL, NULL always is
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # dividing by zero would raise here; leave it for the engine to evaluate
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
526
527
def simplify_parens(expression):
    """
    Strip parentheses that do not affect how the expression is evaluated.

    Returns the wrapped expression when the parentheses are redundant, otherwise
    the Paren itself. Parenthesized SELECTs are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
546
547
# expression types simplify_coalesce treats as constants
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
553
554
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A single-argument COALESCE is replaced by its argument. A comparison between
    a COALESCE with a constant argument and another constant is split into
    `(<leading args> IS NOT NULL AND <comparison>) OR (<leading args> IS NULL AND
    <constant arg> <op> <other>)`, which later passes can fold further.

    Note that the COALESCE node is truncated in place before the result is built.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # locate the COALESCE on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # drop the constant and everything after it; this mutates the node that
    # `expression` still references, so expression.copy() below sees the shorter call
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
611
612
# concatenation expressions simplify_concat operates on
CONCATS = (exp.Concat, exp.DPipe)
# variants that simplify_concat rebuilds as SafeConcat rather than Concat
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
615
616
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal.

    Returns the single remaining operand when everything collapses into one,
    otherwise a new Concat / SafeConcat over the merged operands.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if not all_strings:
            merged.extend(run)
            continue
        merged.append(exp.Literal.string("".join(literal.name for literal in run)))

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)
634
635
636DateRange = t.Tuple[datetime.date, datetime.date]
637
638
def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    period_start = date_floor(date, unit)

    if period_start == date:
        return period_start, period_start + interval(unit)

    # `date` is not on a unit boundary, so a truncated value can never equal it
    # (the comparison is always False, except for NULL values)
    return None
655
656
def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Build `left >= start AND left < end` for the half-open range `drange`."""
    at_or_after_start = left >= date_literal(drange[0])
    before_end = left < date_literal(drange[1])
    return exp.and_(at_or_after_start, before_end, copy=False)
664
665
def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check on `left`.

    Returns None when `_datetrunc_range` reports that no value can match.
    """
    drange = _datetrunc_range(date, unit)
    return _datetrunc_eq_expression(left, drange) if drange else None
674
675
def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as a check that `left` is outside the range.

    Args:
        left: the expression being truncated.
        date: the date the truncated value is compared against.
        unit: the truncation unit.
    Returns:
        `left < start OR left >= end` for the range covered by `date`, or None
        when `_datetrunc_range` yields no range (date is not on a unit boundary).
    """
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # "not in [start, end)" means before the start OR at/after the end;
    # joining the two with AND could never be satisfied
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )
688
689
# signature shared by every binary DATE_TRUNC rewrite:
# (truncated operand, comparison date, unit) -> replacement expression or None
DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
# rewrites of `DATE_TRUNC(u, l) <op> d` into a direct comparison on `l`
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # a truncated value is below d exactly when l is before the first unit boundary
    # at or after d, e.g. DATE_TRUNC('year', x) < '2021-06-01' -> x < '2022-01-01'
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# every expression type simplify_datetrunc_predicate accepts
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
702
703
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a date/timestamp truncation and `right` a cast to a temporal type."""
    if not isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)):
        return False
    if not isinstance(right, exp.Cast):
        return False
    return right.is_type(*exp.DataType.TEMPORAL_TYPES)
710
711
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Binary comparisons are rewritten through DATETRUNC_BINARY_COMPARISONS; an IN
    list of dates becomes an OR of merged range checks. The expression is returned
    unchanged when it does not match, when a date cannot be extracted from a cast,
    or when the unit is unsupported.

    Args:
        expression: any expression.
    Returns:
        The rewritten predicate, or `expression` itself.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # the truncation is on the right: mirror the operands and the operator
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)
        # the cast may wrap something that is not an ISO date literal
        if date is None:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        # `rs` is empty when there is no explicit value list; all() would then be
        # vacuously true even though `l` was never checked
        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if date is None:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
751
752
753# CROSS joins result in an empty table if the right table is empty.
754# So we can only simplify certain types of joins to CROSS.
755# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
756JOINS = {
757    ("", ""),
758    ("", "INNER"),
759    ("RIGHT", ""),
760    ("RIGHT", "OUTER"),
761}
762
763
def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn always-true joins into CROSS joins.

    The expression is modified in place. Only joins whose (side, kind) is listed
    in JOINS and that have neither USING nor a join method are converted.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args
        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
778
779
def always_true(expression):
    """Return whether `expression` is a constant that evaluates to true.

    That is a TRUE boolean or a literal, except for a numeric literal equal to
    zero, which is falsy in engines that accept numbers as conditions.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return float(expression.name) != 0
        except ValueError:
            # not parseable as a Python float; keep treating it like any other literal
            return True

    return True
784
785
def is_complement(a, b):
    """Return whether `b` is the negation of `a`, i.e. `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a
788
789
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Boolean node (subclasses excluded) holding a falsy value."""
    return type(a) is exp.Boolean and not a.this
792
793
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Null node (subclasses excluded)."""
    return type(a) is exp.Null
796
797
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal, or None when `expression` is not one of the
    supported comparison types.
    """
    # checked in order; the comparison itself only runs for the matching entry
    evaluators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for comparison_types, compare in evaluators:
        if isinstance(expression, comparison_types):
            return boolean_literal(compare())

    return None
812
813
def extract_date(cast):
    """Parse the value of a CAST(... AS DATE) / CAST(... AS DATETIME) into a Python date.

    Returns a `datetime.date` or `datetime.datetime`, or None when the target
    type is neither of the two or the value is not in ISO format.
    """
    target_type = cast.args["to"].this

    parsers = (
        (exp.DataType.Type.DATE, datetime.date.fromisoformat),
        (exp.DataType.Type.DATETIME, datetime.datetime.fromisoformat),
    )

    for data_type, parse in parsers:
        if target_type == data_type:
            # The "fromisoformat" conversion could fail if the cast is used on an identifier,
            # so in that case we can't extract the date.
            try:
                return parse(cast.name)
            except ValueError:
                return None

    return None
824
825
def extract_interval(expression):
    """Convert an Interval expression into the delta built by `interval`.

    Returns:
        The delta, or None when the interval cannot be evaluated statically: its
        quantity is not an integer literal (e.g. `INTERVAL x DAY`), its unit is
        unsupported, or dateutil is not installed.
    """
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
834
835
def date_literal(date):
    """Build `CAST('<date>' AS DATE)`, or `AS DATETIME` when `date` is a datetime."""
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), type_name)
841
842
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` of `unit`.

    Raises:
        ModuleNotFoundError: if dateutil is not installed.
        UnsupportedUnit: if `unit` is not one of the units listed below.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of the keyword make up one unit)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in spans:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
864
865
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the `unit` it falls in.

    Args:
        d: a date, or a datetime (the result is then a datetime at midnight).
        unit: one of "year", "quarter", "month", "week" or "day".
    Returns:
        The first day of the period, of the same type as `d`.
    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        # every supported unit is at least a day long, so the time of day is
        # always dropped; otherwise 12:00 would look like a unit boundary
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
887
888
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next start of a `unit`; a `d` already on a boundary is returned as is."""
    period_start = date_floor(d, unit)
    return d if period_start == d else period_start + interval(unit)
896
897
def boolean_literal(condition):
    """Return a TRUE expression when `condition` is truthy, otherwise a FALSE expression."""
    return exp.true() if condition else exp.false()
900
901
def _flat_simplify(expression, simplifier, root=True):
    """
    Flatten a chain of the same binary operator and simplify its operands pairwise.

    Args:
        expression: the binary expression at the top of the chain.
        simplifier: callable `(expression, a, b)` returning the replacement for the
            pair `a`, `b`, or a falsy value / `expression` itself when the pair
            cannot be simplified.
        root: whether `expression` is the root of the simplification; a non-root
            expression is only processed when `same_parent` is false.
    Returns:
        A new left-deep chain over the remaining operands if at least one pair
        was merged, otherwise `expression` unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            # look for a later operand that can be merged with `a`
            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # the merged result takes the place of `a` and `b` and is retried
                    # against the rest; breaking right away keeps the mutation of the
                    # deque from disturbing the iteration
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # nothing could be merged with `a`: it is final
                operands.append(a)

        # rebuild only when something was actually merged
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit name is not one the simplifier can handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression):
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group by keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same treatment for a HAVING clause that references a group by key
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes marked FINAL above are returned untouched (their children are not visited)
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # `node` may be a freshly built expression at this point; give it the original's
        # parent, which the following steps consult (e.g. simplify_parens reads `.parent`)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call swaps the result into the tree itself
        if root:
            expression.replace(node)

        return node

    # NOTE(review): while_changing presumably re-applies _simplify until the expression
    # stops changing -- confirm against sqlglot.helper
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
 95def catch(*exceptions):
 96    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 97
 98    def decorator(func):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression
104
105        return wrapped
106
107    return decorator

Decorator that skips a simplification function, returning the expression unchanged, if any of `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification is only done on lt/lte/gt/gte, so BETWEEN is
    rewritten into that form. Any other expression is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify NOT expressions.

    De Morgan's laws are applied to a parenthesized connector:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition, NOT NULL folds to NULL, NOT of an always-true operand folds to FALSE,
    NOT FALSE folds to TRUE, and a double negation NOT NOT x collapses to x.
    Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and flip the connector
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return flipped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
def flatten(expression):
    """
    Remove redundant nesting between connectors of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only swap when the unnested operand is the same kind of connector as its parent
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold the operands of an AND / OR chain using boolean and NULL rules.

    Non-connector expressions are returned unchanged; connectors are handed to
    `_flat_simplify` together with the pairwise rules defined below.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule: returns the replacement for `left <connector> right`.
        # Falls through (returning None) when `expression` is neither AND nor OR.
        if left == right:
            return left
        if isinstance(expression, exp.And):
            # FALSE is checked before NULL, so FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE is checked before NULL, so TRUE OR NULL -> TRUE
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # NULL only survives an OR when the other side is NULL or FALSE
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Other expressions, and connectors without such a pair, are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: node to process; anything that is not an `exp.Connector` is returned as is.
        generate: callable mapping an operand to the key used for deduplication and
            ordering (presumably its SQL text -- confirm against `cached_generator`).
        root: when falsy, the connector is only processed if `expression.same_parent` is falsy.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # Keyed by the generated key, so operands sharing a key collapse into one entry
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        # (enumerate counts from 0 over arr[1:], so arr[i] is the entry just before `e`)
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place through `replace` on sub-nodes, and the
    `expression` passed in is what gets returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested groups to look for
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: swap it for the neutral constant of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's: the whole group `a` is
                    # replaced by the neutral constant of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # One operand is shared and the other two are complements of each other
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_equality(expression, *args, **kwargs):
        def wrapped(expression, *args, **kwargs):
            # Wrapper produced by the `catch` decorator: run the decorated function and,
            # if it raises one of the configured `exceptions`, return `expression` as is.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Binary, non-connector expressions are delegated to `_flat_simplify` with
    `_simplify_binary`. A negated numeric literal becomes a single number literal with
    its sign flipped. Everything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # Negating an already negative number drops the sign instead of doubling it
    flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(flipped)
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Strip a Paren wrapper when one of the conditions below holds.

    Returns the wrapped expression when the parentheses are dropped, otherwise the
    original `expression` (non-Paren inputs are returned unchanged).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    # A parenthesized SELECT always keeps its parentheses
    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary operation
        not isinstance(parent, (exp.Condition, exp.Binary))
        # directly nested parentheses
        or isinstance(parent, exp.Paren)
        # the wrapped expression is not a binary operation
        or not isinstance(this, exp.Binary)
        # a predicate whose parent is not itself a predicate
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        # Add inside Add, Mul inside Mul
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # Mul inside Add / Sub
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE usage.

    Two rewrites are performed:
      - A COALESCE with no additional arguments is replaced by its only argument
        (unless its parent is a Hint).
      - A comparison between a COALESCE and a constant, where the COALESCE has a
        constant among its additional arguments, is expanded into an OR of two ANDs
        keyed on whether the leading COALESCE arguments are NULL.

    Any other expression is returned unchanged. Note that the COALESCE node inside
    `expression` is modified in place by the second rewrite.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Keep only the arguments that come before the first constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal.

    Expressions that are not in CONCATS, and ConcatWs expressions, are returned
    unchanged. If a single argument remains after merging, that argument is returned
    on its own instead of a concat node.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            merged.append(exp.Literal.string("".join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)

Reduces all groups that contain string literals by concatenating them.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
        def wrapped(expression, *args, **kwargs):
            # Wrapper produced by the `catch` decorator: run the decorated function and,
            # if it raises one of the configured `exceptions`, return `expression` as is.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', ''), ('', 'INNER'), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Remove always-true WHERE clauses and turn always-true joins into CROSS joins.

    The tree is modified in place and nothing is returned. A join is only converted
    when its ON condition is always true, it has no USING or method, and its
    (side, kind) pair is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is a true Boolean node or any Literal node."""
    if isinstance(expression, exp.Boolean):
        value = expression.this
        if value:
            return value
    # NOTE(review): every Literal node counts as true here, whatever its value
    return isinstance(expression, exp.Literal)
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly an `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly an `exp.Null` node (subclass instances do not count)."""
    return type(a) is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal expression for EQ / IS, NEQ, GT, GTE, LT and LTE
    nodes, and None for any other kind of expression.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def extract_date(cast):
def extract_date(cast):
    """Parse the value of a DATE / DATETIME cast into a Python date object.

    Returns a `datetime.date` for a cast to DATE, a `datetime.datetime` for a cast to
    DATETIME, and None when the target type is neither or the value is not ISO formatted.
    """
    target = cast.args["to"].this

    if target == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        return parse(cast.name)
    except ValueError:
        return None
def extract_interval(expression):
def extract_interval(expression):
    """Convert an interval expression into a delta object.

    Args:
        expression: node whose `name` holds the interval amount and whose "unit" text
            holds the unit name.

    Returns:
        The result of `interval(unit, n)`, or None when the interval cannot be
        extracted: the amount is not an integer, the unit is unsupported, or the
        optional dependency needed by `interval` is not installed.
    """
    # A non-integer amount (e.g. '1.5' or a non-numeric name) cannot be represented,
    # so report "no interval" instead of letting ValueError escape to the caller
    try:
        n = int(expression.name)
    except ValueError:
        return None

    unit = expression.text("unit").lower()

    try:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None
def date_literal(date):
def date_literal(date):
    """Build a cast of `date` as a string literal to DATETIME (for datetimes) or DATE."""
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), type_name)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` periods of `unit`.

    Supported units are "year", "quarter", "month", "week", "day", "hour", "minute"
    and "second"; a quarter is expressed as three months.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported names.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, per_unit = spans[unit]
    return relativedelta(**{keyword: per_unit * n})
def date_floor(d: datetime.date, unit: str) -> datetime.date:
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Truncate `d` down to the first day of the period named by `unit`.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".

    Returns:
        The first day of the year / quarter / month / week containing `d`,
        or `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is none of the values listed above.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Weeks are assumed to run from Monday (weekday 0) to Sunday (weekday 6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=d.month - (d.month - 1) % 3, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to a `unit` boundary.

    If `d` already equals its own floor for `unit` it is returned as is;
    otherwise the result is that floor advanced by one `unit` interval.
    """
    start = date_floor(d, unit)
    return d if start == d else start + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return `exp.true()` when `condition` is truthy, otherwise `exp.false()`."""
    if condition:
        return exp.true()
    return exp.false()