Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.generator import cached_generator
  11from sqlglot.helper import first, merge_ranges, while_changing
  12from sqlglot.optimizer.scope import find_all_in_scope
  13
  14# Final means that an expression should not be simplified
  15FINAL = "final"
  16
  17
  18class UnsupportedUnit(Exception):
  19    pass
  20
  21
  22def simplify(expression, constant_propagation=False):
  23    """
  24    Rewrite sqlglot AST to simplify expressions.
  25
  26    Example:
  27        >>> import sqlglot
  28        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  29        >>> simplify(expression).sql()
  30        'TRUE'
  31
  32    Args:
  33        expression (sqlglot.Expression): expression to simplify
  34        constant_propagation: whether or not the constant propagation rule should be used
  35
  36    Returns:
  37        sqlglot.Expression: simplified expression
  38    """
  39
  40    generate = cached_generator()
  41
  42    # group by expressions cannot be simplified, for example
  43    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  44    # the projection must exactly match the group by key
  45    for group in expression.find_all(exp.Group):
  46        select = group.parent
  47        groups = set(group.expressions)
  48        group.meta[FINAL] = True
  49
  50        for e in select.selects:
  51            for node, *_ in e.walk():
  52                if node in groups:
  53                    e.meta[FINAL] = True
  54                    break
  55
  56        having = select.args.get("having")
  57        if having:
  58            for node, *_ in having.walk():
  59                if node in groups:
  60                    having.meta[FINAL] = True
  61                    break
  62
  63    def _simplify(expression, root=True):
  64        if expression.meta.get(FINAL):
  65            return expression
  66
  67        # Pre-order transformations
  68        node = expression
  69        node = rewrite_between(node)
  70        node = uniq_sort(node, generate, root)
  71        node = absorb_and_eliminate(node, root)
  72        node = simplify_concat(node)
  73
  74        if constant_propagation:
  75            node = propagate_constants(node, root)
  76
  77        exp.replace_children(node, lambda e: _simplify(e, False))
  78
  79        # Post-order transformations
  80        node = simplify_not(node)
  81        node = flatten(node)
  82        node = simplify_connectors(node, root)
  83        node = remove_complements(node, root)
  84        node = simplify_coalesce(node)
  85        node.parent = expression.parent
  86        node = simplify_literals(node, root)
  87        node = simplify_equality(node)
  88        node = simplify_parens(node)
  89        node = simplify_datetrunc_predicate(node)
  90
  91        if root:
  92            expression.replace(node)
  93
  94        return node
  95
  96    expression = while_changing(expression, _simplify)
  97    remove_where_true(expression)
  98    return expression
  99
 100
 101def catch(*exceptions):
 102    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 103
 104    def decorator(func):
 105        def wrapped(expression, *args, **kwargs):
 106            try:
 107                return func(expression, *args, **kwargs)
 108            except exceptions:
 109                return expression
 110
 111        return wrapped
 112
 113    return decorator
 114
 115
 116def rewrite_between(expression: exp.Expression) -> exp.Expression:
 117    """Rewrite x between y and z to x >= y AND x <= z.
 118
 119    This is done because comparison simplification is only done on lt/lte/gt/gte.
 120    """
 121    if isinstance(expression, exp.Between):
 122        return exp.and_(
 123            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 124            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 125            copy=False,
 126        )
 127    return expression
 128
 129
 130def simplify_not(expression):
 131    """
 132    Demorgan's Law
 133    NOT (x OR y) -> NOT x AND NOT y
 134    NOT (x AND y) -> NOT x OR NOT y
 135    """
 136    if isinstance(expression, exp.Not):
 137        if is_null(expression.this):
 138            return exp.null()
 139        if isinstance(expression.this, exp.Paren):
 140            condition = expression.this.unnest()
 141            if isinstance(condition, exp.And):
 142                return exp.or_(
 143                    exp.not_(condition.left, copy=False),
 144                    exp.not_(condition.right, copy=False),
 145                    copy=False,
 146                )
 147            if isinstance(condition, exp.Or):
 148                return exp.and_(
 149                    exp.not_(condition.left, copy=False),
 150                    exp.not_(condition.right, copy=False),
 151                    copy=False,
 152                )
 153            if is_null(condition):
 154                return exp.null()
 155        if always_true(expression.this):
 156            return exp.false()
 157        if is_false(expression.this):
 158            return exp.true()
 159        if isinstance(expression.this, exp.Not):
 160            # double negation
 161            # NOT NOT x -> x
 162            return expression.this.this
 163    return expression
 164
 165
 166def flatten(expression):
 167    """
 168    A AND (B AND C) -> A AND B AND C
 169    A OR (B OR C) -> A OR B OR C
 170    """
 171    if isinstance(expression, exp.Connector):
 172        for node in expression.args.values():
 173            child = node.unnest()
 174            if isinstance(child, expression.__class__):
 175                node.replace(child)
 176    return expression
 177
 178
 179def simplify_connectors(expression, root=True):
 180    def _simplify_connectors(expression, left, right):
 181        if left == right:
 182            return left
 183        if isinstance(expression, exp.And):
 184            if is_false(left) or is_false(right):
 185                return exp.false()
 186            if is_null(left) or is_null(right):
 187                return exp.null()
 188            if always_true(left) and always_true(right):
 189                return exp.true()
 190            if always_true(left):
 191                return right
 192            if always_true(right):
 193                return left
 194            return _simplify_comparison(expression, left, right)
 195        elif isinstance(expression, exp.Or):
 196            if always_true(left) or always_true(right):
 197                return exp.true()
 198            if is_false(left) and is_false(right):
 199                return exp.false()
 200            if (
 201                (is_null(left) and is_null(right))
 202                or (is_null(left) and is_false(right))
 203                or (is_false(left) and is_null(right))
 204            ):
 205                return exp.null()
 206            if is_false(left):
 207                return right
 208            if is_false(right):
 209                return left
 210            return _simplify_comparison(expression, left, right, or_=True)
 211
 212    if isinstance(expression, exp.Connector):
 213        return _flat_simplify(expression, _simplify_connectors, root)
 214    return expression
 215
 216
 217LT_LTE = (exp.LT, exp.LTE)
 218GT_GTE = (exp.GT, exp.GTE)
 219
 220COMPARISONS = (
 221    *LT_LTE,
 222    *GT_GTE,
 223    exp.EQ,
 224    exp.NEQ,
 225    exp.Is,
 226)
 227
 228INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 229    exp.LT: exp.GT,
 230    exp.GT: exp.LT,
 231    exp.LTE: exp.GTE,
 232    exp.GTE: exp.LTE,
 233}
 234
 235
 236def _simplify_comparison(expression, left, right, or_=False):
 237    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 238        ll, lr = left.args.values()
 239        rl, rr = right.args.values()
 240
 241        largs = {ll, lr}
 242        rargs = {rl, rr}
 243
 244        matching = largs & rargs
 245        columns = {m for m in matching if isinstance(m, exp.Column)}
 246
 247        if matching and columns:
 248            try:
 249                l = first(largs - columns)
 250                r = first(rargs - columns)
 251            except StopIteration:
 252                return expression
 253
 254            # make sure the comparison is always of the form x > 1 instead of 1 < x
 255            if left.__class__ in INVERSE_COMPARISONS and l == ll:
 256                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
 257            if right.__class__ in INVERSE_COMPARISONS and r == rl:
 258                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
 259
 260            if l.is_number and r.is_number:
 261                l = float(l.name)
 262                r = float(r.name)
 263            elif l.is_string and r.is_string:
 264                l = l.name
 265                r = r.name
 266            else:
 267                return None
 268
 269            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 270                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 271                    return left if (av > bv if or_ else av <= bv) else right
 272                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 273                    return left if (av < bv if or_ else av >= bv) else right
 274
 275                # we can't ever shortcut to true because the column could be null
 276                if not or_:
 277                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 278                        if av <= bv:
 279                            return exp.false()
 280                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 281                        if av >= bv:
 282                            return exp.false()
 283                    elif isinstance(a, exp.EQ):
 284                        if isinstance(b, exp.LT):
 285                            return exp.false() if av >= bv else a
 286                        if isinstance(b, exp.LTE):
 287                            return exp.false() if av > bv else a
 288                        if isinstance(b, exp.GT):
 289                            return exp.false() if av <= bv else a
 290                        if isinstance(b, exp.GTE):
 291                            return exp.false() if av < bv else a
 292                        if isinstance(b, exp.NEQ):
 293                            return exp.false() if av == bv else a
 294    return None
 295
 296
 297def remove_complements(expression, root=True):
 298    """
 299    Removing complements.
 300
 301    A AND NOT A -> FALSE
 302    A OR NOT A -> TRUE
 303    """
 304    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 305        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 306
 307        for a, b in itertools.permutations(expression.flatten(), 2):
 308            if is_complement(a, b):
 309                return complement
 310    return expression
 311
 312
 313def uniq_sort(expression, generate, root=True):
 314    """
 315    Uniq and sort a connector.
 316
 317    C AND A AND B AND B -> A AND B AND C
 318    """
 319    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 320        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 321        flattened = tuple(expression.flatten())
 322        deduped = {generate(e): e for e in flattened}
 323        arr = tuple(deduped.items())
 324
 325        # check if the operands are already sorted, if not sort them
 326        # A AND C AND B -> A AND B AND C
 327        for i, (sql, e) in enumerate(arr[1:]):
 328            if sql < arr[i][0]:
 329                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 330                break
 331        else:
 332            # we didn't have to sort but maybe we need to dedup
 333            if len(deduped) < len(flattened):
 334                expression = result_func(*deduped.values(), copy=False)
 335
 336    return expression
 337
 338
 339def absorb_and_eliminate(expression, root=True):
 340    """
 341    absorption:
 342        A AND (A OR B) -> A
 343        A OR (A AND B) -> A
 344        A AND (NOT A OR B) -> A AND B
 345        A OR (NOT A AND B) -> A OR B
 346    elimination:
 347        (A AND B) OR (A AND NOT B) -> A
 348        (A OR B) AND (A OR NOT B) -> A
 349    """
 350    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 351        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 352
 353        for a, b in itertools.permutations(expression.flatten(), 2):
 354            if isinstance(a, kind):
 355                aa, ab = a.unnest_operands()
 356
 357                # absorb
 358                if is_complement(b, aa):
 359                    aa.replace(exp.true() if kind == exp.And else exp.false())
 360                elif is_complement(b, ab):
 361                    ab.replace(exp.true() if kind == exp.And else exp.false())
 362                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
 363                    a.replace(exp.false() if kind == exp.And else exp.true())
 364                elif isinstance(b, kind):
 365                    # eliminate
 366                    rhs = b.unnest_operands()
 367                    ba, bb = rhs
 368
 369                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
 370                        a.replace(aa)
 371                        b.replace(aa)
 372                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
 373                        a.replace(ab)
 374                        b.replace(ab)
 375
 376    return expression
 377
 378
 379def propagate_constants(expression, root=True):
 380    """
 381    Propagate constants for conjunctions in DNF:
 382
 383    SELECT * FROM t WHERE a = b AND b = 5 becomes
 384    SELECT * FROM t WHERE a = 5 AND b = 5
 385
 386    Reference: https://www.sqlite.org/optoverview.html
 387    """
 388
 389    if (
 390        isinstance(expression, exp.And)
 391        and (root or not expression.same_parent)
 392        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 393    ):
 394        constant_mapping = {}
 395        for eq in find_all_in_scope(expression, exp.EQ):
 396            l, r = eq.left, eq.right
 397
 398            # TODO: create a helper that can be used to detect nested literal expressions such
 399            # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 400            if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 401                pass
 402            elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
 403                l, r = r, l
 404            else:
 405                continue
 406
 407            constant_mapping[l] = (id(l), r)
 408
 409        if constant_mapping:
 410            for column in find_all_in_scope(expression, exp.Column):
 411                parent = column.parent
 412                column_id, constant = constant_mapping.get(column) or (None, None)
 413                if (
 414                    column_id is not None
 415                    and id(column) != column_id
 416                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 417                ):
 418                    column.replace(constant.copy())
 419
 420    return expression
 421
 422
 423INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 424    exp.DateAdd: exp.Sub,
 425    exp.DateSub: exp.Add,
 426    exp.DatetimeAdd: exp.Sub,
 427    exp.DatetimeSub: exp.Add,
 428}
 429
 430INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 431    **INVERSE_DATE_OPS,
 432    exp.Add: exp.Sub,
 433    exp.Sub: exp.Add,
 434}
 435
 436
 437def _is_number(expression: exp.Expression) -> bool:
 438    return expression.is_number
 439
 440
 441def _is_interval(expression: exp.Expression) -> bool:
 442    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 443
 444
 445@catch(ModuleNotFoundError, UnsupportedUnit)
 446def simplify_equality(expression: exp.Expression) -> exp.Expression:
 447    """
 448    Use the subtraction and addition properties of equality to simplify expressions:
 449
 450        x + 1 = 3 becomes x = 2
 451
 452    There are two binary operations in the above expression: + and =
 453    Here's how we reference all the operands in the code below:
 454
 455          l     r
 456        x + 1 = 3
 457        a   b
 458    """
 459    if isinstance(expression, COMPARISONS):
 460        l, r = expression.left, expression.right
 461
 462        if l.__class__ in INVERSE_OPS:
 463            pass
 464        elif r.__class__ in INVERSE_OPS:
 465            l, r = r, l
 466        else:
 467            return expression
 468
 469        if r.is_number:
 470            a_predicate = _is_number
 471            b_predicate = _is_number
 472        elif _is_date_literal(r):
 473            a_predicate = _is_date_literal
 474            b_predicate = _is_interval
 475        else:
 476            return expression
 477
 478        if l.__class__ in INVERSE_DATE_OPS:
 479            a = l.this
 480            b = l.interval()
 481        else:
 482            a, b = l.left, l.right
 483
 484        if not a_predicate(a) and b_predicate(b):
 485            pass
 486        elif not a_predicate(b) and b_predicate(a):
 487            a, b = b, a
 488        else:
 489            return expression
 490
 491        return expression.__class__(
 492            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 493        )
 494    return expression
 495
 496
 497def simplify_literals(expression, root=True):
 498    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 499        return _flat_simplify(expression, _simplify_binary, root)
 500
 501    if isinstance(expression, exp.Neg):
 502        this = expression.this
 503        if this.is_number:
 504            value = this.name
 505            if value[0] == "-":
 506                return exp.Literal.number(value[1:])
 507            return exp.Literal.number(f"-{value}")
 508
 509    return expression
 510
 511
 512def _simplify_binary(expression, a, b):
 513    if isinstance(expression, exp.Is):
 514        if isinstance(b, exp.Not):
 515            c = b.this
 516            not_ = True
 517        else:
 518            c = b
 519            not_ = False
 520
 521        if is_null(c):
 522            if isinstance(a, exp.Literal):
 523                return exp.true() if not_ else exp.false()
 524            if is_null(a):
 525                return exp.false() if not_ else exp.true()
 526    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 527        return None
 528    elif is_null(a) or is_null(b):
 529        return exp.null()
 530
 531    if a.is_number and b.is_number:
 532        a = int(a.name) if a.is_int else Decimal(a.name)
 533        b = int(b.name) if b.is_int else Decimal(b.name)
 534
 535        if isinstance(expression, exp.Add):
 536            return exp.Literal.number(a + b)
 537        if isinstance(expression, exp.Sub):
 538            return exp.Literal.number(a - b)
 539        if isinstance(expression, exp.Mul):
 540            return exp.Literal.number(a * b)
 541        if isinstance(expression, exp.Div):
 542            # engines have differing int div behavior so intdiv is not safe
 543            if isinstance(a, int) and isinstance(b, int):
 544                return None
 545            return exp.Literal.number(a / b)
 546
 547        boolean = eval_boolean(expression, a, b)
 548
 549        if boolean:
 550            return boolean
 551    elif a.is_string and b.is_string:
 552        boolean = eval_boolean(expression, a.this, b.this)
 553
 554        if boolean:
 555            return boolean
 556    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 557        a, b = extract_date(a), extract_interval(b)
 558        if a and b:
 559            if isinstance(expression, exp.Add):
 560                return date_literal(a + b)
 561            if isinstance(expression, exp.Sub):
 562                return date_literal(a - b)
 563    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 564        a, b = extract_interval(a), extract_date(b)
 565        # you cannot subtract a date from an interval
 566        if a and b and isinstance(expression, exp.Add):
 567            return date_literal(a + b)
 568
 569    return None
 570
 571
 572def simplify_parens(expression):
 573    if not isinstance(expression, exp.Paren):
 574        return expression
 575
 576    this = expression.this
 577    parent = expression.parent
 578
 579    if not isinstance(this, exp.Select) and (
 580        not isinstance(parent, (exp.Condition, exp.Binary))
 581        or isinstance(parent, exp.Paren)
 582        or not isinstance(this, exp.Binary)
 583        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 584        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 585        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 586        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 587    ):
 588        return this
 589    return expression
 590
 591
 592CONSTANTS = (
 593    exp.Literal,
 594    exp.Boolean,
 595    exp.Null,
 596)
 597
 598
 599def simplify_coalesce(expression):
 600    # COALESCE(x) -> x
 601    if (
 602        isinstance(expression, exp.Coalesce)
 603        and not expression.expressions
 604        # COALESCE is also used as a Spark partitioning hint
 605        and not isinstance(expression.parent, exp.Hint)
 606    ):
 607        return expression.this
 608
 609    if not isinstance(expression, COMPARISONS):
 610        return expression
 611
 612    if isinstance(expression.left, exp.Coalesce):
 613        coalesce = expression.left
 614        other = expression.right
 615    elif isinstance(expression.right, exp.Coalesce):
 616        coalesce = expression.right
 617        other = expression.left
 618    else:
 619        return expression
 620
 621    # This transformation is valid for non-constants,
 622    # but it really only does anything if they are both constants.
 623    if not isinstance(other, CONSTANTS):
 624        return expression
 625
 626    # Find the first constant arg
 627    for arg_index, arg in enumerate(coalesce.expressions):
 628        if isinstance(arg, CONSTANTS):
 629            break
 630    else:
 631        return expression
 632
 633    coalesce.set("expressions", coalesce.expressions[:arg_index])
 634
 635    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 636    # since we already remove COALESCE at the top of this function.
 637    coalesce = coalesce if coalesce.expressions else coalesce.this
 638
 639    # This expression is more complex than when we started, but it will get simplified further
 640    return exp.paren(
 641        exp.or_(
 642            exp.and_(
 643                coalesce.is_(exp.null()).not_(copy=False),
 644                expression.copy(),
 645                copy=False,
 646            ),
 647            exp.and_(
 648                coalesce.is_(exp.null()),
 649                type(expression)(this=arg.copy(), expression=other.copy()),
 650                copy=False,
 651            ),
 652            copy=False,
 653        )
 654    )
 655
 656
 657CONCATS = (exp.Concat, exp.DPipe)
 658SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 659
 660
 661def simplify_concat(expression):
 662    """Reduces all groups that contain string literals by concatenating them."""
 663    if not isinstance(expression, CONCATS) or (
 664        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 665        isinstance(expression, exp.ConcatWs)
 666        and not expression.expressions[0].is_string
 667    ):
 668        return expression
 669
 670    if isinstance(expression, exp.ConcatWs):
 671        sep_expr, *expressions = expression.expressions
 672        sep = sep_expr.name
 673        concat_type = exp.ConcatWs
 674    else:
 675        expressions = expression.expressions
 676        sep = ""
 677        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
 678
 679    new_args = []
 680    for is_string_group, group in itertools.groupby(
 681        expressions or expression.flatten(), lambda e: e.is_string
 682    ):
 683        if is_string_group:
 684            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 685        else:
 686            new_args.extend(group)
 687
 688    if len(new_args) == 1 and new_args[0].is_string:
 689        return new_args[0]
 690
 691    if concat_type is exp.ConcatWs:
 692        new_args = [sep_expr] + new_args
 693
 694    return concat_type(expressions=new_args)
 695
 696
 697DateRange = t.Tuple[datetime.date, datetime.date]
 698
 699
 700def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 701    """
 702    Get the date range for a DATE_TRUNC equality comparison:
 703
 704    Example:
 705        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 706    Returns:
 707        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 708    """
 709    floor = date_floor(date, unit)
 710
 711    if date != floor:
 712        # This will always be False, except for NULL values.
 713        return None
 714
 715    return floor, floor + interval(unit)
 716
 717
 718def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 719    """Get the logical expression for a date range"""
 720    return exp.and_(
 721        left >= date_literal(drange[0]),
 722        left < date_literal(drange[1]),
 723        copy=False,
 724    )
 725
 726
 727def _datetrunc_eq(
 728    left: exp.Expression, date: datetime.date, unit: str
 729) -> t.Optional[exp.Expression]:
 730    drange = _datetrunc_range(date, unit)
 731    if not drange:
 732        return None
 733
 734    return _datetrunc_eq_expression(left, drange)
 735
 736
 737def _datetrunc_neq(
 738    left: exp.Expression, date: datetime.date, unit: str
 739) -> t.Optional[exp.Expression]:
 740    drange = _datetrunc_range(date, unit)
 741    if not drange:
 742        return None
 743
 744    return exp.and_(
 745        left < date_literal(drange[0]),
 746        left >= date_literal(drange[1]),
 747        copy=False,
 748    )
 749
 750
 751DateTruncBinaryTransform = t.Callable[
 752    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 753]
 754DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 755    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 756    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 757    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 758    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 759    exp.EQ: _datetrunc_eq,
 760    exp.NEQ: _datetrunc_neq,
 761}
 762DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 763
 764
 765def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 766    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 767
 768
 769@catch(ModuleNotFoundError, UnsupportedUnit)
 770def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
 771    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 772    comparison = expression.__class__
 773
 774    if comparison not in DATETRUNC_COMPARISONS:
 775        return expression
 776
 777    if isinstance(expression, exp.Binary):
 778        l, r = expression.left, expression.right
 779
 780        if _is_datetrunc_predicate(l, r):
 781            pass
 782        elif _is_datetrunc_predicate(r, l):
 783            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
 784            l, r = r, l
 785        else:
 786            return expression
 787
 788        unit = l.unit.name.lower()
 789        date = extract_date(r)
 790
 791        if not date:
 792            return expression
 793
 794        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
 795    elif isinstance(expression, exp.In):
 796        l = expression.this
 797        rs = expression.expressions
 798
 799        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 800            unit = l.unit.name.lower()
 801
 802            ranges = []
 803            for r in rs:
 804                date = extract_date(r)
 805                if not date:
 806                    return expression
 807                drange = _datetrunc_range(date, unit)
 808                if drange:
 809                    ranges.append(drange)
 810
 811            if not ranges:
 812                return expression
 813
 814            ranges = merge_ranges(ranges)
 815
 816            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 817
 818    return expression
 819
 820
 821# CROSS joins result in an empty table if the right table is empty.
 822# So we can only simplify certain types of joins to CROSS.
 823# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 824JOINS = {
 825    ("", ""),
 826    ("", "INNER"),
 827    ("RIGHT", ""),
 828    ("RIGHT", "OUTER"),
 829}
 830
 831
 832def remove_where_true(expression):
 833    for where in expression.find_all(exp.Where):
 834        if always_true(where.this):
 835            where.parent.set("where", None)
 836    for join in expression.find_all(exp.Join):
 837        if (
 838            always_true(join.args.get("on"))
 839            and not join.args.get("using")
 840            and not join.args.get("method")
 841            and (join.side, join.kind) in JOINS
 842        ):
 843            join.set("on", None)
 844            join.set("side", None)
 845            join.set("kind", "CROSS")
 846
 847
 848def always_true(expression):
 849    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 850        expression, exp.Literal
 851    )
 852
 853
 854def is_complement(a, b):
 855    return isinstance(b, exp.Not) and b.this == a
 856
 857
 858def is_false(a: exp.Expression) -> bool:
 859    return type(a) is exp.Boolean and not a.this
 860
 861
 862def is_null(a: exp.Expression) -> bool:
 863    return type(a) is exp.Null
 864
 865
 866def eval_boolean(expression, a, b):
 867    if isinstance(expression, (exp.EQ, exp.Is)):
 868        return boolean_literal(a == b)
 869    if isinstance(expression, exp.NEQ):
 870        return boolean_literal(a != b)
 871    if isinstance(expression, exp.GT):
 872        return boolean_literal(a > b)
 873    if isinstance(expression, exp.GTE):
 874        return boolean_literal(a >= b)
 875    if isinstance(expression, exp.LT):
 876        return boolean_literal(a < b)
 877    if isinstance(expression, exp.LTE):
 878        return boolean_literal(a <= b)
 879    return None
 880
 881
 882def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 883    if isinstance(value, datetime.datetime):
 884        return value.date()
 885    if isinstance(value, datetime.date):
 886        return value
 887    try:
 888        return datetime.datetime.fromisoformat(value).date()
 889    except ValueError:
 890        return None
 891
 892
 893def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 894    if isinstance(value, datetime.datetime):
 895        return value
 896    if isinstance(value, datetime.date):
 897        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 898    try:
 899        return datetime.datetime.fromisoformat(value)
 900    except ValueError:
 901        return None
 902
 903
 904def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 905    if not value:
 906        return None
 907    if to.is_type(exp.DataType.Type.DATE):
 908        return cast_as_date(value)
 909    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 910        return cast_as_datetime(value)
 911    return None
 912
 913
 914def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 915    if isinstance(cast, exp.Cast):
 916        to = cast.to
 917    elif isinstance(cast, exp.TsOrDsToDate):
 918        to = exp.DataType.build(exp.DataType.Type.DATE)
 919    else:
 920        return None
 921
 922    if isinstance(cast.this, exp.Literal):
 923        value: t.Any = cast.this.name
 924    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 925        value = extract_date(cast.this)
 926    else:
 927        return None
 928    return cast_value(value, to)
 929
 930
 931def _is_date_literal(expression: exp.Expression) -> bool:
 932    return extract_date(expression) is not None
 933
 934
 935def extract_interval(expression):
 936    n = int(expression.name)
 937    unit = expression.text("unit").lower()
 938
 939    try:
 940        return interval(unit, n)
 941    except (UnsupportedUnit, ModuleNotFoundError):
 942        return None
 943
 944
 945def date_literal(date):
 946    return exp.cast(
 947        exp.Literal.string(date),
 948        exp.DataType.Type.DATETIME
 949        if isinstance(date, datetime.datetime)
 950        else exp.DataType.Type.DATE,
 951    )
 952
 953
 954def interval(unit: str, n: int = 1):
 955    from dateutil.relativedelta import relativedelta
 956
 957    if unit == "year":
 958        return relativedelta(years=1 * n)
 959    if unit == "quarter":
 960        return relativedelta(months=3 * n)
 961    if unit == "month":
 962        return relativedelta(months=1 * n)
 963    if unit == "week":
 964        return relativedelta(weeks=1 * n)
 965    if unit == "day":
 966        return relativedelta(days=1 * n)
 967    if unit == "hour":
 968        return relativedelta(hours=1 * n)
 969    if unit == "minute":
 970        return relativedelta(minutes=1 * n)
 971    if unit == "second":
 972        return relativedelta(seconds=1 * n)
 973
 974    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 975
 976
 977def date_floor(d: datetime.date, unit: str) -> datetime.date:
 978    if unit == "year":
 979        return d.replace(month=1, day=1)
 980    if unit == "quarter":
 981        if d.month <= 3:
 982            return d.replace(month=1, day=1)
 983        elif d.month <= 6:
 984            return d.replace(month=4, day=1)
 985        elif d.month <= 9:
 986            return d.replace(month=7, day=1)
 987        else:
 988            return d.replace(month=10, day=1)
 989    if unit == "month":
 990        return d.replace(month=d.month, day=1)
 991    if unit == "week":
 992        # Assuming week starts on Monday (0) and ends on Sunday (6)
 993        return d - datetime.timedelta(days=d.weekday())
 994    if unit == "day":
 995        return d
 996
 997    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 998
 999
1000def date_ceil(d: datetime.date, unit: str) -> datetime.date:
1001    floor = date_floor(d, unit)
1002
1003    if floor == d:
1004        return d
1005
1006    return floor + interval(unit)
1007
1008
1009def boolean_literal(condition):
1010    return exp.true() if condition else exp.false()
1011
1012
1013def _flat_simplify(expression, simplifier, root=True):
1014    if root or not expression.same_parent:
1015        operands = []
1016        queue = deque(expression.flatten(unnest=False))
1017        size = len(queue)
1018
1019        while queue:
1020            a = queue.popleft()
1021
1022            for b in queue:
1023                result = simplifier(expression, a, b)
1024
1025                if result and result is not expression:
1026                    queue.remove(b)
1027                    queue.appendleft(result)
1028                    break
1029            else:
1030                operands.append(a)
1031
1032        if len(operands) < size:
1033            return functools.reduce(
1034                lambda a, b: expression.__class__(this=a, expression=b), operands
1035            )
1036    return expression
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
19class UnsupportedUnit(Exception):
20    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
23def simplify(expression, constant_propagation=False):
24    """
25    Rewrite sqlglot AST to simplify expressions.
26
27    Example:
28        >>> import sqlglot
29        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
30        >>> simplify(expression).sql()
31        'TRUE'
32
33    Args:
34        expression (sqlglot.Expression): expression to simplify
35        constant_propagation: whether or not the constant propagation rule should be used
36
37    Returns:
38        sqlglot.Expression: simplified expression
39    """
40
41    generate = cached_generator()
42
43    # group by expressions cannot be simplified, for example
44    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
45    # the projection must exactly match the group by key
46    for group in expression.find_all(exp.Group):
47        select = group.parent
48        groups = set(group.expressions)
49        group.meta[FINAL] = True
50
51        for e in select.selects:
52            for node, *_ in e.walk():
53                if node in groups:
54                    e.meta[FINAL] = True
55                    break
56
57        having = select.args.get("having")
58        if having:
59            for node, *_ in having.walk():
60                if node in groups:
61                    having.meta[FINAL] = True
62                    break
63
64    def _simplify(expression, root=True):
65        if expression.meta.get(FINAL):
66            return expression
67
68        # Pre-order transformations
69        node = expression
70        node = rewrite_between(node)
71        node = uniq_sort(node, generate, root)
72        node = absorb_and_eliminate(node, root)
73        node = simplify_concat(node)
74
75        if constant_propagation:
76            node = propagate_constants(node, root)
77
78        exp.replace_children(node, lambda e: _simplify(e, False))
79
80        # Post-order transformations
81        node = simplify_not(node)
82        node = flatten(node)
83        node = simplify_connectors(node, root)
84        node = remove_complements(node, root)
85        node = simplify_coalesce(node)
86        node.parent = expression.parent
87        node = simplify_literals(node, root)
88        node = simplify_equality(node)
89        node = simplify_parens(node)
90        node = simplify_datetrunc_predicate(node)
91
92        if root:
93            expression.replace(node)
94
95        return node
96
97    expression = while_changing(expression, _simplify)
98    remove_where_true(expression)
99    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
102def catch(*exceptions):
103    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
104
105    def decorator(func):
106        def wrapped(expression, *args, **kwargs):
107            try:
108                return func(expression, *args, **kwargs)
109            except exceptions:
110                return expression
111
112        return wrapped
113
114    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
117def rewrite_between(expression: exp.Expression) -> exp.Expression:
118    """Rewrite x between y and z to x >= y AND x <= z.
119
120    This is done because comparison simplification is only done on lt/lte/gt/gte.
121    """
122    if isinstance(expression, exp.Between):
123        return exp.and_(
124            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
125            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
126            copy=False,
127        )
128    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
131def simplify_not(expression):
132    """
133    Demorgan's Law
134    NOT (x OR y) -> NOT x AND NOT y
135    NOT (x AND y) -> NOT x OR NOT y
136    """
137    if isinstance(expression, exp.Not):
138        if is_null(expression.this):
139            return exp.null()
140        if isinstance(expression.this, exp.Paren):
141            condition = expression.this.unnest()
142            if isinstance(condition, exp.And):
143                return exp.or_(
144                    exp.not_(condition.left, copy=False),
145                    exp.not_(condition.right, copy=False),
146                    copy=False,
147                )
148            if isinstance(condition, exp.Or):
149                return exp.and_(
150                    exp.not_(condition.left, copy=False),
151                    exp.not_(condition.right, copy=False),
152                    copy=False,
153                )
154            if is_null(condition):
155                return exp.null()
156        if always_true(expression.this):
157            return exp.false()
158        if is_false(expression.this):
159            return exp.true()
160        if isinstance(expression.this, exp.Not):
161            # double negation
162            # NOT NOT x -> x
163            return expression.this.this
164    return expression

Demorgan's Law NOT (x OR y) -> NOT x AND NOT y NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
167def flatten(expression):
168    """
169    A AND (B AND C) -> A AND B AND C
170    A OR (B OR C) -> A OR B OR C
171    """
172    if isinstance(expression, exp.Connector):
173        for node in expression.args.values():
174            child = node.unnest()
175            if isinstance(child, expression.__class__):
176                node.replace(child)
177    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
180def simplify_connectors(expression, root=True):
181    def _simplify_connectors(expression, left, right):
182        if left == right:
183            return left
184        if isinstance(expression, exp.And):
185            if is_false(left) or is_false(right):
186                return exp.false()
187            if is_null(left) or is_null(right):
188                return exp.null()
189            if always_true(left) and always_true(right):
190                return exp.true()
191            if always_true(left):
192                return right
193            if always_true(right):
194                return left
195            return _simplify_comparison(expression, left, right)
196        elif isinstance(expression, exp.Or):
197            if always_true(left) or always_true(right):
198                return exp.true()
199            if is_false(left) and is_false(right):
200                return exp.false()
201            if (
202                (is_null(left) and is_null(right))
203                or (is_null(left) and is_false(right))
204                or (is_false(left) and is_null(right))
205            ):
206                return exp.null()
207            if is_false(left):
208                return right
209            if is_false(right):
210                return left
211            return _simplify_comparison(expression, left, right, or_=True)
212
213    if isinstance(expression, exp.Connector):
214        return _flat_simplify(expression, _simplify_connectors, root)
215    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
298def remove_complements(expression, root=True):
299    """
300    Removing complements.
301
302    A AND NOT A -> FALSE
303    A OR NOT A -> TRUE
304    """
305    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
306        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
307
308        for a, b in itertools.permutations(expression.flatten(), 2):
309            if is_complement(a, b):
310                return complement
311    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
314def uniq_sort(expression, generate, root=True):
315    """
316    Uniq and sort a connector.
317
318    C AND A AND B AND B -> A AND B AND C
319    """
320    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
321        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
322        flattened = tuple(expression.flatten())
323        deduped = {generate(e): e for e in flattened}
324        arr = tuple(deduped.items())
325
326        # check if the operands are already sorted, if not sort them
327        # A AND C AND B -> A AND B AND C
328        for i, (sql, e) in enumerate(arr[1:]):
329            if sql < arr[i][0]:
330                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
331                break
332        else:
333            # we didn't have to sort but maybe we need to dedup
334            if len(deduped) < len(flattened):
335                expression = result_func(*deduped.values(), copy=False)
336
337    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
340def absorb_and_eliminate(expression, root=True):
341    """
342    absorption:
343        A AND (A OR B) -> A
344        A OR (A AND B) -> A
345        A AND (NOT A OR B) -> A AND B
346        A OR (NOT A AND B) -> A OR B
347    elimination:
348        (A AND B) OR (A AND NOT B) -> A
349        (A OR B) AND (A OR NOT B) -> A
350    """
351    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
352        kind = exp.Or if isinstance(expression, exp.And) else exp.And
353
354        for a, b in itertools.permutations(expression.flatten(), 2):
355            if isinstance(a, kind):
356                aa, ab = a.unnest_operands()
357
358                # absorb
359                if is_complement(b, aa):
360                    aa.replace(exp.true() if kind == exp.And else exp.false())
361                elif is_complement(b, ab):
362                    ab.replace(exp.true() if kind == exp.And else exp.false())
363                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
364                    a.replace(exp.false() if kind == exp.And else exp.true())
365                elif isinstance(b, kind):
366                    # eliminate
367                    rhs = b.unnest_operands()
368                    ba, bb = rhs
369
370                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
371                        a.replace(aa)
372                        b.replace(aa)
373                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
374                        a.replace(ab)
375                        b.replace(ab)
376
377    return expression

absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
380def propagate_constants(expression, root=True):
381    """
382    Propagate constants for conjunctions in DNF:
383
384    SELECT * FROM t WHERE a = b AND b = 5 becomes
385    SELECT * FROM t WHERE a = 5 AND b = 5
386
387    Reference: https://www.sqlite.org/optoverview.html
388    """
389
390    if (
391        isinstance(expression, exp.And)
392        and (root or not expression.same_parent)
393        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
394    ):
395        constant_mapping = {}
396        for eq in find_all_in_scope(expression, exp.EQ):
397            l, r = eq.left, eq.right
398
399            # TODO: create a helper that can be used to detect nested literal expressions such
400            # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
401            if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
402                pass
403            elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
404                l, r = r, l
405            else:
406                continue
407
408            constant_mapping[l] = (id(l), r)
409
410        if constant_mapping:
411            for column in find_all_in_scope(expression, exp.Column):
412                parent = column.parent
413                column_id, constant = constant_mapping.get(column) or (None, None)
414                if (
415                    column_id is not None
416                    and id(column) != column_id
417                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
418                ):
419                    column.replace(constant.copy())
420
421    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
106        def wrapped(expression, *args, **kwargs):
107            try:
108                return func(expression, *args, **kwargs)
109            except exceptions:
110                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
498def simplify_literals(expression, root=True):
499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
500        return _flat_simplify(expression, _simplify_binary, root)
501
502    if isinstance(expression, exp.Neg):
503        this = expression.this
504        if this.is_number:
505            value = this.name
506            if value[0] == "-":
507                return exp.Literal.number(value[1:])
508            return exp.Literal.number(f"-{value}")
509
510    return expression
def simplify_parens(expression):
573def simplify_parens(expression):
574    if not isinstance(expression, exp.Paren):
575        return expression
576
577    this = expression.this
578    parent = expression.parent
579
580    if not isinstance(this, exp.Select) and (
581        not isinstance(parent, (exp.Condition, exp.Binary))
582        or isinstance(parent, exp.Paren)
583        or not isinstance(this, exp.Binary)
584        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
585        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
586        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
587        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
588    ):
589        return this
590    return expression
def simplify_coalesce(expression):
600def simplify_coalesce(expression):
601    # COALESCE(x) -> x
602    if (
603        isinstance(expression, exp.Coalesce)
604        and not expression.expressions
605        # COALESCE is also used as a Spark partitioning hint
606        and not isinstance(expression.parent, exp.Hint)
607    ):
608        return expression.this
609
610    if not isinstance(expression, COMPARISONS):
611        return expression
612
613    if isinstance(expression.left, exp.Coalesce):
614        coalesce = expression.left
615        other = expression.right
616    elif isinstance(expression.right, exp.Coalesce):
617        coalesce = expression.right
618        other = expression.left
619    else:
620        return expression
621
622    # This transformation is valid for non-constants,
623    # but it really only does anything if they are both constants.
624    if not isinstance(other, CONSTANTS):
625        return expression
626
627    # Find the first constant arg
628    for arg_index, arg in enumerate(coalesce.expressions):
629        if isinstance(arg, CONSTANTS):
630            break
631    else:
632        return expression
633
634    coalesce.set("expressions", coalesce.expressions[:arg_index])
635
636    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
637    # since we already remove COALESCE at the top of this function.
638    coalesce = coalesce if coalesce.expressions else coalesce.this
639
640    # This expression is more complex than when we started, but it will get simplified further
641    return exp.paren(
642        exp.or_(
643            exp.and_(
644                coalesce.is_(exp.null()).not_(copy=False),
645                expression.copy(),
646                copy=False,
647            ),
648            exp.and_(
649                coalesce.is_(exp.null()),
650                type(expression)(this=arg.copy(), expression=other.copy()),
651                copy=False,
652            ),
653            copy=False,
654        )
655    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
662def simplify_concat(expression):
663    """Reduces all groups that contain string literals by concatenating them."""
664    if not isinstance(expression, CONCATS) or (
665        # We can't reduce a CONCAT_WS call if we don't statically know the separator
666        isinstance(expression, exp.ConcatWs)
667        and not expression.expressions[0].is_string
668    ):
669        return expression
670
671    if isinstance(expression, exp.ConcatWs):
672        sep_expr, *expressions = expression.expressions
673        sep = sep_expr.name
674        concat_type = exp.ConcatWs
675    else:
676        expressions = expression.expressions
677        sep = ""
678        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
679
680    new_args = []
681    for is_string_group, group in itertools.groupby(
682        expressions or expression.flatten(), lambda e: e.is_string
683    ):
684        if is_string_group:
685            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
686        else:
687            new_args.extend(group)
688
689    if len(new_args) == 1 and new_args[0].is_string:
690        return new_args[0]
691
692    if concat_type is exp.ConcatWs:
693        new_args = [sep_expr] + new_args
694
695    return concat_type(expressions=new_args)

Reduces all groups that contain string literals by concatenating them.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
106        def wrapped(expression, *args, **kwargs):
107            try:
108                return func(expression, *args, **kwargs)
109            except exceptions:
110                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('', ''), ('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER')}
def remove_where_true(expression):
833def remove_where_true(expression):
834    for where in expression.find_all(exp.Where):
835        if always_true(where.this):
836            where.parent.set("where", None)
837    for join in expression.find_all(exp.Join):
838        if (
839            always_true(join.args.get("on"))
840            and not join.args.get("using")
841            and not join.args.get("method")
842            and (join.side, join.kind) in JOINS
843        ):
844            join.set("on", None)
845            join.set("side", None)
846            join.set("kind", "CROSS")
def always_true(expression):
849def always_true(expression):
850    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
851        expression, exp.Literal
852    )
def is_complement(a, b):
855def is_complement(a, b):
856    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
859def is_false(a: exp.Expression) -> bool:
860    return type(a) is exp.Boolean and not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
863def is_null(a: exp.Expression) -> bool:
864    return type(a) is exp.Null
def eval_boolean(expression, a, b):
867def eval_boolean(expression, a, b):
868    if isinstance(expression, (exp.EQ, exp.Is)):
869        return boolean_literal(a == b)
870    if isinstance(expression, exp.NEQ):
871        return boolean_literal(a != b)
872    if isinstance(expression, exp.GT):
873        return boolean_literal(a > b)
874    if isinstance(expression, exp.GTE):
875        return boolean_literal(a >= b)
876    if isinstance(expression, exp.LT):
877        return boolean_literal(a < b)
878    if isinstance(expression, exp.LTE):
879        return boolean_literal(a <= b)
880    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
883def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
884    if isinstance(value, datetime.datetime):
885        return value.date()
886    if isinstance(value, datetime.date):
887        return value
888    try:
889        return datetime.datetime.fromisoformat(value).date()
890    except ValueError:
891        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
894def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
895    if isinstance(value, datetime.datetime):
896        return value
897    if isinstance(value, datetime.date):
898        return datetime.datetime(year=value.year, month=value.month, day=value.day)
899    try:
900        return datetime.datetime.fromisoformat(value)
901    except ValueError:
902        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
905def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
906    if not value:
907        return None
908    if to.is_type(exp.DataType.Type.DATE):
909        return cast_as_date(value)
910    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
911        return cast_as_datetime(value)
912    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
915def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
916    if isinstance(cast, exp.Cast):
917        to = cast.to
918    elif isinstance(cast, exp.TsOrDsToDate):
919        to = exp.DataType.build(exp.DataType.Type.DATE)
920    else:
921        return None
922
923    if isinstance(cast.this, exp.Literal):
924        value: t.Any = cast.this.name
925    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
926        value = extract_date(cast.this)
927    else:
928        return None
929    return cast_value(value, to)
def extract_interval(expression):
936def extract_interval(expression):
937    n = int(expression.name)
938    unit = expression.text("unit").lower()
939
940    try:
941        return interval(unit, n)
942    except (UnsupportedUnit, ModuleNotFoundError):
943        return None
def date_literal(date):
946def date_literal(date):
947    return exp.cast(
948        exp.Literal.string(date),
949        exp.DataType.Type.DATETIME
950        if isinstance(date, datetime.datetime)
951        else exp.DataType.Type.DATE,
952    )
def interval(unit: str, n: int = 1):
955def interval(unit: str, n: int = 1):
956    from dateutil.relativedelta import relativedelta
957
958    if unit == "year":
959        return relativedelta(years=1 * n)
960    if unit == "quarter":
961        return relativedelta(months=3 * n)
962    if unit == "month":
963        return relativedelta(months=1 * n)
964    if unit == "week":
965        return relativedelta(weeks=1 * n)
966    if unit == "day":
967        return relativedelta(days=1 * n)
968    if unit == "hour":
969        return relativedelta(hours=1 * n)
970    if unit == "minute":
971        return relativedelta(minutes=1 * n)
972    if unit == "second":
973        return relativedelta(seconds=1 * n)
974
975    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
978def date_floor(d: datetime.date, unit: str) -> datetime.date:
979    if unit == "year":
980        return d.replace(month=1, day=1)
981    if unit == "quarter":
982        if d.month <= 3:
983            return d.replace(month=1, day=1)
984        elif d.month <= 6:
985            return d.replace(month=4, day=1)
986        elif d.month <= 9:
987            return d.replace(month=7, day=1)
988        else:
989            return d.replace(month=10, day=1)
990    if unit == "month":
991        return d.replace(month=d.month, day=1)
992    if unit == "week":
993        # Assuming week starts on Monday (0) and ends on Sunday (6)
994        return d - datetime.timedelta(days=d.weekday())
995    if unit == "day":
996        return d
997
998    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
1001def date_ceil(d: datetime.date, unit: str) -> datetime.date:
1002    floor = date_floor(d, unit)
1003
1004    if floor == d:
1005        return d
1006
1007    return floor + interval(unit)
def boolean_literal(condition):
1010def boolean_literal(condition):
1011    return exp.true() if condition else exp.false()