Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
  11from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  12
# Final means that an expression should not be simplified.
# Used as a key in an expression's `meta` dict: `simplify` sets it on GROUP BY nodes
# (and on projections / HAVING clauses that reference a GROUP BY key), and the inner
# `_simplify` returns any node carrying it unchanged.
FINAL = "final"


class UnsupportedUnit(Exception):
    """Signals a date/interval unit that the date helpers cannot handle.

    NOTE(review): raised by helpers not visible here; `simplify_equality` and
    `simplify_datetrunc_predicate` swallow it via `@catch` and leave the expression as is.
    """

    pass
  19
  20
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    All rewrite rules run as one pass (`_simplify`), which is driven by
    `while_changing` (presumably re-running it until the tree stops changing).
    Always-true WHERE clauses and ON TRUE joins are then cleaned up.
    The input tree is modified in place (e.g. GROUP BY related nodes are tagged with
    FINAL in their `meta`).

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains a GROUP BY key anywhere inside it.
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Likewise freeze the whole HAVING clause if it references a GROUP BY key.
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # One pass of every rule over `expression`. `root` is True only for the
        # top-level call; children are visited with root=False, which lets the
        # connector rules skip nodes whose `same_parent` is true.
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Simplify the children before the post-order rules below look at them.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Re-attach the original parent so rules that inspect `.parent`
        # (e.g. simplify_parens) see the context the node will live in.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the top-level call swaps its result into the tree itself.
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
  97
  98
  99def catch(*exceptions):
 100    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 101
 102    def decorator(func):
 103        def wrapped(expression, *args, **kwargs):
 104            try:
 105                return func(expression, *args, **kwargs)
 106            except exceptions:
 107                return expression
 108
 109        return wrapped
 110
 111    return decorator
 112
 113
 114def rewrite_between(expression: exp.Expression) -> exp.Expression:
 115    """Rewrite x between y and z to x >= y AND x <= z.
 116
 117    This is done because comparison simplification is only done on lt/lte/gt/gte.
 118    """
 119    if isinstance(expression, exp.Between):
 120        return exp.and_(
 121            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 122            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 123            copy=False,
 124        )
 125    return expression
 126
 127
 128def simplify_not(expression):
 129    """
 130    Demorgan's Law
 131    NOT (x OR y) -> NOT x AND NOT y
 132    NOT (x AND y) -> NOT x OR NOT y
 133    """
 134    if isinstance(expression, exp.Not):
 135        if is_null(expression.this):
 136            return exp.null()
 137        if isinstance(expression.this, exp.Paren):
 138            condition = expression.this.unnest()
 139            if isinstance(condition, exp.And):
 140                return exp.or_(
 141                    exp.not_(condition.left, copy=False),
 142                    exp.not_(condition.right, copy=False),
 143                    copy=False,
 144                )
 145            if isinstance(condition, exp.Or):
 146                return exp.and_(
 147                    exp.not_(condition.left, copy=False),
 148                    exp.not_(condition.right, copy=False),
 149                    copy=False,
 150                )
 151            if is_null(condition):
 152                return exp.null()
 153        if always_true(expression.this):
 154            return exp.false()
 155        if is_false(expression.this):
 156            return exp.true()
 157        if isinstance(expression.this, exp.Not):
 158            # double negation
 159            # NOT NOT x -> x
 160            return expression.this.this
 161    return expression
 162
 163
 164def flatten(expression):
 165    """
 166    A AND (B AND C) -> A AND B AND C
 167    A OR (B OR C) -> A OR B OR C
 168    """
 169    if isinstance(expression, exp.Connector):
 170        for node in expression.args.values():
 171            child = node.unnest()
 172            if isinstance(child, expression.__class__):
 173                node.replace(child)
 174    return expression
 175
 176
 177def simplify_connectors(expression, root=True):
 178    def _simplify_connectors(expression, left, right):
 179        if left == right:
 180            return left
 181        if isinstance(expression, exp.And):
 182            if is_false(left) or is_false(right):
 183                return exp.false()
 184            if is_null(left) or is_null(right):
 185                return exp.null()
 186            if always_true(left) and always_true(right):
 187                return exp.true()
 188            if always_true(left):
 189                return right
 190            if always_true(right):
 191                return left
 192            return _simplify_comparison(expression, left, right)
 193        elif isinstance(expression, exp.Or):
 194            if always_true(left) or always_true(right):
 195                return exp.true()
 196            if is_false(left) and is_false(right):
 197                return exp.false()
 198            if (
 199                (is_null(left) and is_null(right))
 200                or (is_null(left) and is_false(right))
 201                or (is_false(left) and is_null(right))
 202            ):
 203                return exp.null()
 204            if is_false(left):
 205                return right
 206            if is_false(right):
 207                return left
 208            return _simplify_comparison(expression, left, right, or_=True)
 209
 210    if isinstance(expression, exp.Connector):
 211        return _flat_simplify(expression, _simplify_connectors, root)
 212    return expression
 213
 214
# Ordering comparisons, grouped by direction (upper bounds / lower bounds).
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates that the comparison-oriented rules in this module
# (e.g. `_simplify_comparison`, `simplify_equality`, `simplify_coalesce`) act on.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Comparison to use after swapping the two operands: `a < b` is `b > a`.
# EQ / NEQ / IS are absent because swapping operands does not change them.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
 232
 233
 234def _simplify_comparison(expression, left, right, or_=False):
 235    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 236        ll, lr = left.args.values()
 237        rl, rr = right.args.values()
 238
 239        largs = {ll, lr}
 240        rargs = {rl, rr}
 241
 242        matching = largs & rargs
 243        columns = {m for m in matching if isinstance(m, exp.Column)}
 244
 245        if matching and columns:
 246            try:
 247                l = first(largs - columns)
 248                r = first(rargs - columns)
 249            except StopIteration:
 250                return expression
 251
 252            # make sure the comparison is always of the form x > 1 instead of 1 < x
 253            if left.__class__ in INVERSE_COMPARISONS and l == ll:
 254                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
 255            if right.__class__ in INVERSE_COMPARISONS and r == rl:
 256                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
 257
 258            if l.is_number and r.is_number:
 259                l = float(l.name)
 260                r = float(r.name)
 261            elif l.is_string and r.is_string:
 262                l = l.name
 263                r = r.name
 264            else:
 265                return None
 266
 267            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 268                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 269                    return left if (av > bv if or_ else av <= bv) else right
 270                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 271                    return left if (av < bv if or_ else av >= bv) else right
 272
 273                # we can't ever shortcut to true because the column could be null
 274                if not or_:
 275                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 276                        if av <= bv:
 277                            return exp.false()
 278                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 279                        if av >= bv:
 280                            return exp.false()
 281                    elif isinstance(a, exp.EQ):
 282                        if isinstance(b, exp.LT):
 283                            return exp.false() if av >= bv else a
 284                        if isinstance(b, exp.LTE):
 285                            return exp.false() if av > bv else a
 286                        if isinstance(b, exp.GT):
 287                            return exp.false() if av <= bv else a
 288                        if isinstance(b, exp.GTE):
 289                            return exp.false() if av < bv else a
 290                        if isinstance(b, exp.NEQ):
 291                            return exp.false() if av == bv else a
 292    return None
 293
 294
 295def remove_complements(expression, root=True):
 296    """
 297    Removing complements.
 298
 299    A AND NOT A -> FALSE
 300    A OR NOT A -> TRUE
 301    """
 302    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 303        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 304
 305        for a, b in itertools.permutations(expression.flatten(), 2):
 306            if is_complement(a, b):
 307                return complement
 308    return expression
 309
 310
 311def uniq_sort(expression, root=True):
 312    """
 313    Uniq and sort a connector.
 314
 315    C AND A AND B AND B -> A AND B AND C
 316    """
 317    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 318        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 319        flattened = tuple(expression.flatten())
 320        deduped = {gen(e): e for e in flattened}
 321        arr = tuple(deduped.items())
 322
 323        # check if the operands are already sorted, if not sort them
 324        # A AND C AND B -> A AND B AND C
 325        for i, (sql, e) in enumerate(arr[1:]):
 326            if sql < arr[i][0]:
 327                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 328                break
 329        else:
 330            # we didn't have to sort but maybe we need to dedup
 331            if len(deduped) < len(flattened):
 332                expression = result_func(*deduped.values(), copy=False)
 333
 334    return expression
 335
 336
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Works in place: matching operands inside `expression` are replaced (with TRUE,
    FALSE or the shared term) and the same connector object is returned. The TRUE/FALSE
    leftovers are left for the other connector rules to fold away.
    Non-root connectors whose `same_parent` is true are skipped.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: nested groups inside an AND are ORs and
        # vice versa.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # NOTE(review): itertools.permutations snapshots the operands up front, so
        # pairs visited after a `replace` may refer to nodes already detached.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # `b` and `NOT b` inside group `a`: the `NOT b` operand is replaced by
                # the group's identity (TRUE inside an AND group, FALSE inside an OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's, so a is redundant. It is
                # replaced by the outer connector's identity (FALSE for OR, TRUE for AND).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # Both groups share one term and differ by a complemented term:
                    # both collapse to the shared term.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 375
 376
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Applies only to an AND that is the root (or whose `same_parent` is false) and that
    is already in DNF. `column = literal` equalities in the current scope are recorded.
    Every other occurrence of such a column is then replaced by a copy of the
    literal, except where the column is the subject of `IS NULL`. If one column
    has several such equalities, the last one recorded wins: `a = 1 AND a = 2`
    becomes `2 = 1 AND a = 2`, which literal folding then evaluates.
    Mutates `expression` in place and returns it.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        # IF nodes are passed as `prune`, presumably so that equalities which hold
        # only conditionally (inside an IF) are not collected.
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    l, r = r, l
                else:
                    continue

                # Remember which column node defined the constant so that this
                # defining occurrence itself is not rewritten below.
                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 420
 421
# Date/interval arithmetic nodes mapped to the plain arithmetic node that undoes
# them. `simplify_equality` uses this to move the interval to the other side
# (DATE_ADD(x, i) = d  ->  x = d - i), and `simplify_literals` uses it to recognize
# date arithmetic nodes.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every invertible operation handled by `simplify_equality` (`x + 1 = 3` -> `x = 3 - 1`).
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 434
 435
 436def _is_number(expression: exp.Expression) -> bool:
 437    return expression.is_number
 438
 439
 440def _is_interval(expression: exp.Expression) -> bool:
 441    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 442
 443
 444@catch(ModuleNotFoundError, UnsupportedUnit)
 445def simplify_equality(expression: exp.Expression) -> exp.Expression:
 446    """
 447    Use the subtraction and addition properties of equality to simplify expressions:
 448
 449        x + 1 = 3 becomes x = 2
 450
 451    There are two binary operations in the above expression: + and =
 452    Here's how we reference all the operands in the code below:
 453
 454          l     r
 455        x + 1 = 3
 456        a   b
 457    """
 458    if isinstance(expression, COMPARISONS):
 459        l, r = expression.left, expression.right
 460
 461        if l.__class__ in INVERSE_OPS:
 462            pass
 463        elif r.__class__ in INVERSE_OPS:
 464            l, r = r, l
 465        else:
 466            return expression
 467
 468        if r.is_number:
 469            a_predicate = _is_number
 470            b_predicate = _is_number
 471        elif _is_date_literal(r):
 472            a_predicate = _is_date_literal
 473            b_predicate = _is_interval
 474        else:
 475            return expression
 476
 477        if l.__class__ in INVERSE_DATE_OPS:
 478            l = t.cast(exp.IntervalOp, l)
 479            a = l.this
 480            b = l.interval()
 481        else:
 482            l = t.cast(exp.Binary, l)
 483            a, b = l.left, l.right
 484
 485        if not a_predicate(a) and b_predicate(b):
 486            pass
 487        elif not a_predicate(b) and b_predicate(a):
 488            a, b = b, a
 489        else:
 490            return expression
 491
 492        return expression.__class__(
 493            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 494        )
 495    return expression
 496
 497
 498def simplify_literals(expression, root=True):
 499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 500        return _flat_simplify(expression, _simplify_binary, root)
 501
 502    if isinstance(expression, exp.Neg):
 503        this = expression.this
 504        if this.is_number:
 505            value = this.name
 506            if value[0] == "-":
 507                return exp.Literal.number(value[1:])
 508            return exp.Literal.number(f"-{value}")
 509
 510    if type(expression) in INVERSE_DATE_OPS:
 511        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 512
 513    return expression
 514
 515
def _simplify_binary(expression, a, b):
    """
    Fold the binary operation `expression` applied to operands `a` and `b`.

    Used as the pairwise simplifier passed to `_flat_simplify` by `simplify_literals`,
    and called directly there for date arithmetic nodes. Handles:
      * `IS [NOT] NULL` against literals / NULL,
      * NULL propagation for other operators (NULL-safe comparisons excluded),
      * numeric arithmetic and comparisons,
      * string comparisons,
      * date literal +/- interval, and comparisons between date literals.

    Returns:
        The folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        # `IS NOT NULL` is represented with a NOT around the right-hand side.
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        # Decimal (not float) for non-integers avoids binary rounding in the result.
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        # NOTE(review): compares with Python string ordering; engine collation may differ.
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None
 582
 583
 584def simplify_parens(expression):
 585    if not isinstance(expression, exp.Paren):
 586        return expression
 587
 588    this = expression.this
 589    parent = expression.parent
 590
 591    if not isinstance(this, exp.Select) and (
 592        not isinstance(parent, (exp.Condition, exp.Binary))
 593        or isinstance(parent, exp.Paren)
 594        or not isinstance(this, exp.Binary)
 595        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 596        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 597        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 598        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 599    ):
 600        return this
 601    return expression
 602
 603
# Constant node types that are never NULL.
NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

# Constant node types, including NULL.
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for a literal or boolean, or for anything `_is_date_literal` accepts."""
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Like `_is_nonnull_constant`, but NULL also counts as a constant."""
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 622
 623
 624def simplify_coalesce(expression):
 625    # COALESCE(x) -> x
 626    if (
 627        isinstance(expression, exp.Coalesce)
 628        and (not expression.expressions or _is_nonnull_constant(expression.this))
 629        # COALESCE is also used as a Spark partitioning hint
 630        and not isinstance(expression.parent, exp.Hint)
 631    ):
 632        return expression.this
 633
 634    if not isinstance(expression, COMPARISONS):
 635        return expression
 636
 637    if isinstance(expression.left, exp.Coalesce):
 638        coalesce = expression.left
 639        other = expression.right
 640    elif isinstance(expression.right, exp.Coalesce):
 641        coalesce = expression.right
 642        other = expression.left
 643    else:
 644        return expression
 645
 646    # This transformation is valid for non-constants,
 647    # but it really only does anything if they are both constants.
 648    if not _is_constant(other):
 649        return expression
 650
 651    # Find the first constant arg
 652    for arg_index, arg in enumerate(coalesce.expressions):
 653        if _is_constant(other):
 654            break
 655    else:
 656        return expression
 657
 658    coalesce.set("expressions", coalesce.expressions[:arg_index])
 659
 660    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 661    # since we already remove COALESCE at the top of this function.
 662    coalesce = coalesce if coalesce.expressions else coalesce.this
 663
 664    # This expression is more complex than when we started, but it will get simplified further
 665    return exp.paren(
 666        exp.or_(
 667            exp.and_(
 668                coalesce.is_(exp.null()).not_(copy=False),
 669                expression.copy(),
 670                copy=False,
 671            ),
 672            exp.and_(
 673                coalesce.is_(exp.null()),
 674                type(expression)(this=arg.copy(), expression=other.copy()),
 675                copy=False,
 676            ),
 677            copy=False,
 678        )
 679    )
 680
 681
# String-concatenation node types: function-style CONCAT (CONCAT_WS is handled
# through the same path, see simplify_concat) and the `||` operator (DPipe).
CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string-literal arguments are merged into a single literal,
    joined with the separator for CONCAT_WS. For example,
    CONCAT('a', 'b', x, 'c') -> CONCAT('ab', x, 'c'). If everything collapses into
    one literal, that literal is returned. Otherwise the node is rebuilt as CONCAT
    (keeping the `safe` arg) or CONCAT_WS. CONCAT_WS is only reduced when its
    separator is a string literal.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # First CONCAT_WS argument is the separator; the rest are the values.
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {"safe": expression.args.get("safe")}

    new_args = []
    # Operands come from `expressions`, falling back to `expression.flatten()` when
    # that list is empty (e.g. a `||` chain).
    # NOTE(review): CONCAT_WS with only a separator also takes the flatten() path;
    # confirm that is intended.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Everything folded into one string literal: return it directly.
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)
 721
 722
 723def simplify_conditionals(expression):
 724    """Simplifies expressions like IF, CASE if their condition is statically known."""
 725    if isinstance(expression, exp.Case):
 726        this = expression.this
 727        for case in expression.args["ifs"]:
 728            cond = case.this
 729            if this:
 730                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 731                cond = cond.replace(this.pop().eq(cond))
 732
 733            if always_true(cond):
 734                return case.args["true"]
 735
 736            if always_false(cond):
 737                case.pop()
 738                if not expression.args["ifs"]:
 739                    return expression.args.get("default") or exp.null()
 740    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 741        if always_true(expression.this):
 742            return expression.args["true"]
 743        if always_false(expression.this):
 744            return expression.args.get("false") or exp.null()
 745
 746    return expression
 747
 748
 749DateRange = t.Tuple[datetime.date, datetime.date]
 750
 751
 752def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 753    """
 754    Get the date range for a DATE_TRUNC equality comparison:
 755
 756    Example:
 757        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 758    Returns:
 759        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 760    """
 761    floor = date_floor(date, unit)
 762
 763    if date != floor:
 764        # This will always be False, except for NULL values.
 765        return None
 766
 767    return floor, floor + interval(unit)
 768
 769
 770def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 771    """Get the logical expression for a date range"""
 772    return exp.and_(
 773        left >= date_literal(drange[0]),
 774        left < date_literal(drange[1]),
 775        copy=False,
 776    )
 777
 778
 779def _datetrunc_eq(
 780    left: exp.Expression, date: datetime.date, unit: str
 781) -> t.Optional[exp.Expression]:
 782    drange = _datetrunc_range(date, unit)
 783    if not drange:
 784        return None
 785
 786    return _datetrunc_eq_expression(left, drange)
 787
 788
 789def _datetrunc_neq(
 790    left: exp.Expression, date: datetime.date, unit: str
 791) -> t.Optional[exp.Expression]:
 792    drange = _datetrunc_range(date, unit)
 793    if not drange:
 794        return None
 795
 796    return exp.and_(
 797        left < date_literal(drange[0]),
 798        left >= date_literal(drange[1]),
 799        copy=False,
 800    )
 801
 802
 803DateTruncBinaryTransform = t.Callable[
 804    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 805]
 806DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 807    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 808    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 809    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 810    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 811    exp.EQ: _datetrunc_eq,
 812    exp.NEQ: _datetrunc_neq,
 813}
 814DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 815
 816
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    # True when `left` is a DATE_TRUNC / TIMESTAMP_TRUNC call and `right` is a date literal.
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Binary comparisons are rewritten via DATETRUNC_BINARY_COMPARISONS into predicates
    on the truncated operand `x`. `DATE_TRUNC(...) IN (d1, d2, ...)` becomes an OR of
    half-open date ranges, one per range left after `merge_ranges`. The expression is
    returned unchanged when no rewrite applies. ModuleNotFoundError and
    UnsupportedUnit raised during the rewrite are swallowed by `@catch`.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # Literal on the left: swap operands and mirror the operator
            # (d > TRUNC(x) -> TRUNC(x) < d).
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                # Dates not aligned to `unit` can never equal a truncated value,
                # so they contribute no range.
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            # NOTE(review): unlike the binary branch, the ranges are built on the
            # DATE_TRUNC node `l` itself rather than `l.this`. This stays correct since
            # the bounds are unit-aligned; presumably a later pass strips the
            # DATE_TRUNC through the binary branch above.
            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
 873
 874
 875# CROSS joins result in an empty table if the right table is empty.
 876# So we can only simplify certain types of joins to CROSS.
 877# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 878JOINS = {
 879    ("", ""),
 880    ("", "INNER"),
 881    ("RIGHT", ""),
 882    ("RIGHT", "OUTER"),
 883}
 884
 885
 886def remove_where_true(expression):
 887    for where in expression.find_all(exp.Where):
 888        if always_true(where.this):
 889            where.parent.set("where", None)
 890    for join in expression.find_all(exp.Join):
 891        if (
 892            always_true(join.args.get("on"))
 893            and not join.args.get("using")
 894            and not join.args.get("method")
 895            and (join.side, join.kind) in JOINS
 896        ):
 897            join.set("on", None)
 898            join.set("side", None)
 899            join.set("kind", "CROSS")
 900
 901
 902def always_true(expression):
 903    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 904        expression, exp.Literal
 905    )
 906
 907
 908def always_false(expression):
 909    return is_false(expression) or is_null(expression)
 910
 911
 912def is_complement(a, b):
 913    return isinstance(b, exp.Not) and b.this == a
 914
 915
 916def is_false(a: exp.Expression) -> bool:
 917    return type(a) is exp.Boolean and not a.this
 918
 919
 920def is_null(a: exp.Expression) -> bool:
 921    return type(a) is exp.Null
 922
 923
 924def eval_boolean(expression, a, b):
 925    if isinstance(expression, (exp.EQ, exp.Is)):
 926        return boolean_literal(a == b)
 927    if isinstance(expression, exp.NEQ):
 928        return boolean_literal(a != b)
 929    if isinstance(expression, exp.GT):
 930        return boolean_literal(a > b)
 931    if isinstance(expression, exp.GTE):
 932        return boolean_literal(a >= b)
 933    if isinstance(expression, exp.LT):
 934        return boolean_literal(a < b)
 935    if isinstance(expression, exp.LTE):
 936        return boolean_literal(a <= b)
 937    return None
 938
 939
 940def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 941    if isinstance(value, datetime.datetime):
 942        return value.date()
 943    if isinstance(value, datetime.date):
 944        return value
 945    try:
 946        return datetime.datetime.fromisoformat(value).date()
 947    except ValueError:
 948        return None
 949
 950
 951def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 952    if isinstance(value, datetime.datetime):
 953        return value
 954    if isinstance(value, datetime.date):
 955        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 956    try:
 957        return datetime.datetime.fromisoformat(value)
 958    except ValueError:
 959        return None
 960
 961
 962def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 963    if not value:
 964        return None
 965    if to.is_type(exp.DataType.Type.DATE):
 966        return cast_as_date(value)
 967    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 968        return cast_as_datetime(value)
 969    return None
 970
 971
 972def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 973    if isinstance(cast, exp.Cast):
 974        to = cast.to
 975    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
 976        to = exp.DataType.build(exp.DataType.Type.DATE)
 977    else:
 978        return None
 979
 980    if isinstance(cast.this, exp.Literal):
 981        value: t.Any = cast.this.name
 982    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 983        value = extract_date(cast.this)
 984    else:
 985        return None
 986    return cast_value(value, to)
 987
 988
 989def _is_date_literal(expression: exp.Expression) -> bool:
 990    return extract_date(expression) is not None
 991
 992
 993def extract_interval(expression):
 994    try:
 995        n = int(expression.name)
 996        unit = expression.text("unit").lower()
 997        return interval(unit, n)
 998    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
 999        return None
1000
1001
def date_literal(date):
    """Build `CAST('<date>' AS DATETIME)` for datetimes, or `CAST('<date>' AS DATE)` for dates."""
    literal = exp.Literal.string(date)
    if isinstance(date, datetime.datetime):
        target = exp.DataType.Type.DATETIME
    else:
        target = exp.DataType.Type.DATE
    return exp.cast(literal, target)
1009
1010
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` spanning `n` units of `unit`.

    Raises:
        ModuleNotFoundError: if python-dateutil is not installed.
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute, second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
1032
1033
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Truncate `d` to the start of its `unit` (year, quarter, month, week or day).

    Datetime inputs are first truncated to midnight, because DATE_TRUNC to any of
    these units clears the time of day. Keeping the time produced wrong bounds in
    the DATE_TRUNC predicate rewrites: for example, floor('year') of
    2021-06-01 10:00 became 2021-01-01 10:00.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10.
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1055
1056
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the start of the next `unit`; values already on a boundary are returned as-is."""
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
1064
1065
def boolean_literal(condition):
    """Return a fresh TRUE literal node if `condition` is truthy, otherwise FALSE."""
    factory = exp.true if condition else exp.false
    return factory()
1068
1069
1070def _flat_simplify(expression, simplifier, root=True):
1071    if root or not expression.same_parent:
1072        operands = []
1073        queue = deque(expression.flatten(unnest=False))
1074        size = len(queue)
1075
1076        while queue:
1077            a = queue.popleft()
1078
1079            for b in queue:
1080                result = simplifier(expression, a, b)
1081
1082                if result and result is not expression:
1083                    queue.remove(b)
1084                    queue.appendleft(result)
1085                    break
1086            else:
1087                operands.append(a)
1088
1089        if len(operands) < size:
1090            return functools.reduce(
1091                lambda a, b: expression.__class__(this=a, expression=b), operands
1092            )
1093    return expression
1094
1095
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    renderer = GEN_MAP.get(type(expression))
    if renderer is not None:
        return renderer(expression)
    # Fallback for unmapped node types: key followed by all argument renderings.
    return f"{expression.key} {gen(expression.args.values())}"
1113
1114
# Infix operators rendered as "<left> <op> <right>" by `_binary`.
_GEN_BINARY_OPERATORS = {
    exp.Add: "+",
    exp.And: "AND",
    exp.Div: "/",
    exp.Dot: ".",
    exp.EQ: "=",
    exp.GT: ">",
    exp.GTE: ">=",
    exp.ILike: "ILIKE",
    exp.Is: "IS",
    exp.Like: "LIKE",
    exp.LT: "<",
    exp.LTE: "<=",
    exp.Mod: "%",
    exp.Mul: "*",
    exp.NEQ: "<>",
    exp.Or: "OR",
    exp.Sub: "-",
}

# Prefix operators rendered as "<op> <operand>" by `_unary`.
_GEN_UNARY_OPERATORS = {
    exp.Neg: "-",
    exp.Not: "NOT",
}

# Expression type -> cheap pseudo-SQL renderer used by `gen`.
GEN_MAP = {
    # `op=op` binds each operator at definition time (avoids late-binding closures).
    **{
        kind: (lambda e, op=op: _binary(e, op))
        for kind, op in _GEN_BINARY_OPERATORS.items()
    },
    **{
        kind: (lambda e, op=op: _unary(e, op))
        for kind, op in _GEN_UNARY_OPERATORS.items()
    },
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(arg) for arg in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.Null: lambda e: "NULL",
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}
1150
1151
def _binary(e: exp.Binary, op: str) -> str:
    """Render `e` as "<left> <op> <right>" using `gen` for both operands."""
    return " ".join((gen(e.left), op, gen(e.right)))
1154
1155
def _unary(e: exp.Unary, op: str) -> str:
    """Render `e` as "<op> <operand>" using `gen` for the operand."""
    return " ".join((op, gen(e.this)))
FINAL = 'final'
class UnsupportedUnit(Exception):
    """Raised by the date helpers (e.g. `interval`, `date_floor`) for a unit they cannot handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    def _references_any(node, keys):
        return any(child in keys for child, *_ in node.walk())

    # GROUP BY keys must not be rewritten, nor may projections/HAVING that reference
    # them: in `SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1` the projection has to keep
    # matching the grouping key verbatim.
    for group in expression.find_all(exp.Group):
        select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if _references_any(projection, keys):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _references_any(having, keys):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order: rewrite this node before its children are visited.
        node = rewrite_between(expression)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)
        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order: children have already been simplified.
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
    """
    Decorator that makes a simplification rule a no-op when it fails.

    If the wrapped rule raises any of `exceptions`, its input `expression` is
    returned unchanged. `functools.wraps` preserves the rule's name, docstring and
    signature for introspection and generated docs. Without it, every decorated rule
    appeared as `wrapped(expression, *args, **kwargs)`.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of `exceptions` is raised.

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` as `x >= y AND x <= z`.

    Comparison simplification only understands <, <=, >, >=, so BETWEEN is
    expanded into that form first.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Simplify a NOT node.

    De Morgan's laws on parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    plus NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # Negate both sides and swap the connector.
            swapped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return swapped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
    """
    Inline directly nested connectors of the same type (looking through parentheses):

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR connectors following SQL three-valued logic.

    AND: FALSE absorbs and TRUE is the identity. OR: TRUE absorbs and FALSE is the identity.
    NULL AND x is only folded when x is known to be TRUE or NULL, because NULL AND FALSE is
    FALSE. Any other operand pair is handed to `_simplify_comparison`.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            if is_null(left) or is_null(right):
                # NULL AND x is NULL when x is TRUE/NULL but FALSE when x is FALSE, so it must
                # not be folded while x is unknown: NOT (NULL AND x = 1) is TRUE wherever x <> 1.
                # (NULL AND NULL and TRUE AND NULL were already handled above.)
                return None
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = tuple(expression.flatten())
    if any(is_complement(a, b) for a, b in itertools.permutations(operands, 2)):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector, keyed by their `gen` rendering.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when the operands were out of order or contained
    duplicates; otherwise the original expression is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {gen(operand): operand for operand in operands}
    entries = tuple(by_sql.items())
    keys = [sql for sql, _ in entries]

    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # Keys are unique dict keys, so sorting the pairs never compares expressions.
        return build(*(operand for _, operand in sorted(entries)), copy=False)
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)
    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place; the (mutated) input expression is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # Operands of interest are connectors of the opposite type to `expression`.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = inner_kind == exp.And

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, inner_kind):
            continue
        aa, ab = a.unnest_operands()

        # absorb: a side of `a` that contradicts `b` becomes the inner identity value
        if is_complement(b, aa):
            aa.replace(exp.true() if inner_is_and else exp.false())
            continue
        if is_complement(b, ab):
            ab.replace(exp.true() if inner_is_and else exp.false())
            continue

        # absorb: `a` is redundant when `b`'s operands are a strict subset of its own
        b_operands = set(b.flatten()) if isinstance(b, inner_kind) else {b}
        if b_operands < set(a.flatten()):
            a.replace(exp.false() if inner_is_and else exp.true())
            continue

        if not isinstance(b, inner_kind):
            continue

        # eliminate
        rhs = b.unnest_operands()
        ba, bb = rhs
        if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
            a.replace(aa)
            b.replace(aa)
        elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
            a.replace(ab)
            b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        return expression

    # column -> (id of the column node in its defining `column = literal`, literal)
    constant_mapping = {}
    for node, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        left, right = node.left, node.right
        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(left, exp.Column) and isinstance(right, exp.Literal):
            column, literal = left, right
        elif isinstance(right, exp.Column) and isinstance(left, exp.Literal):
            column, literal = right, left
        else:
            continue

        constant_mapping[column] = (id(column), literal)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        defining_id, constant = constant_mapping.get(column) or (None, None)
        # Skip unmapped columns and the defining occurrence itself.
        if defining_id is None or id(column) == defining_id:
            continue
        parent = column.parent
        # Keep `col IS NULL` intact.
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue
        column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """
    Fold literal arithmetic/comparisons in non-connector binaries and negate numeric
    literals under unary minus. Date ops listed in INVERSE_DATE_OPS are also folded
    when `_simplify_binary` can evaluate them.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression
def simplify_parens(expression):
    """Drop parentheses whose removal the surrounding context cannot misinterpret (never around SELECTs)."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - COALESCE(x) -> x, and COALESCE(<non-null constant>, ...) -> that constant
      (except when the COALESCE is a Spark partitioning hint).
    - `COALESCE(...) <cmp> <constant>` is expanded to
      `(C IS NOT NULL AND C <cmp> other) OR (C IS NULL AND k <cmp> other)`,
      where k is the first constant fallback argument and C is the COALESCE
      truncated just before k.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant fallback argument. Each `arg` must be tested here:
    # `other` is already known to be constant, so testing it would always stop at
    # the first argument, even when that argument is not constant.
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Truncate in place, so `expression.copy()` below compares the truncated COALESCE.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged (joined with the separator for CONCAT_WS).
    If only one string literal remains, it replaces the call entirely.
    """
    if not isinstance(expression, CONCATS):
        return expression
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string:
        return expression

    if isinstance(expression, exp.ConcatWs):
        separator_node, *parts = expression.expressions
        separator = separator_node.name
        rebuild = exp.ConcatWs
        extra_args = {}
    else:
        parts = expression.expressions
        separator = ""
        rebuild = exp.Concat
        extra_args = {"safe": expression.args.get("safe")}

    merged = []
    runs = itertools.groupby(parts or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(separator.join(lit.name for lit in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if rebuild is exp.ConcatWs:
        merged = [separator_node] + merged

    return rebuild(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    - A branch whose condition is always true replaces the whole expression.
    - Branches whose condition is always false (FALSE or NULL) are dropped. Once none
      remain, the ELSE value (or NULL) replaces the expression.
    - As a side effect, `CASE x WHEN v ...` is rewritten to `CASE WHEN x = v ...`.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below may remove branches from the
        # `ifs` list in place, and mutating a list while iterating over it skips items.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', 'OUTER'), ('', 'INNER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
    """
    Remove predicates that `always_true` reports as always true.

    WHERE clauses whose condition is always true are detached from their parent.
    Joins whose ON condition is always true, which have no USING or join method,
    and whose (side, kind) pair is listed in JOINS, are turned into CROSS joins.

    Args:
        expression: the expression tree; it is modified in place.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        on_condition = join_node.args.get("on")
        convertible = (
            always_true(on_condition)
            and not join_node.args.get("using")
            and not join_node.args.get("method")
            and (join_node.side, join_node.kind) in JOINS
        )
        if convertible:
            join_node.set("on", None)
            join_node.set("side", None)
            join_node.set("kind", "CROSS")
def always_true(expression):
    """
    Return whether `expression` is a condition statically known to be true.

    The boolean literal TRUE qualifies, as does any literal except a numeric zero.
    A numeric zero literal (e.g. ``WHERE 0`` or ``CASE WHEN 0``) never holds.
    Treating it as always true would make callers drop a filter that actually
    rejects every row, or take a CASE/IF branch that is never taken.

    Args:
        expression: the candidate condition; may be None.

    Returns:
        A truthy value if the condition is always true, a falsy value otherwise.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # Numeric text that Decimal cannot parse: keep the previous behavior
            # (any literal counts as true) rather than guessing.
            return True
    return True
def always_false(expression):
    """Return True when `expression` is the FALSE literal or a NULL node."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
    """Return True when `b` is a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Boolean` (not a subclass) holding false."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node on already-computed Python operands.

    Args:
        expression: the comparison node; only its type is inspected.
        a: left operand value.
        b: right operand value.

    Returns:
        A TRUE/FALSE literal for EQ, IS, NEQ, GT, GTE, LT and LTE nodes,
        otherwise None.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda left, right: left == right),
        (exp.NEQ, lambda left, right: left != right),
        (exp.GT, lambda left, right: left > right),
        (exp.GTE, lambda left, right: left >= right),
        (exp.LT, lambda left, right: left < right),
        (exp.LTE, lambda left, right: left <= right),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce `value` to a `datetime.date`.

    Datetimes are reduced to their date part and dates are returned unchanged.
    Anything else is parsed with `datetime.datetime.fromisoformat`; a value that
    raises ValueError there yields None.
    """
    if isinstance(value, datetime.date):
        # datetime subclasses date, so drop the time component when present.
        return value.date() if isinstance(value, datetime.datetime) else value
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce `value` to a `datetime.datetime`.

    Datetimes are returned unchanged and plain dates become midnight of that day.
    Anything else is parsed with `datetime.datetime.fromisoformat`; a value that
    raises ValueError there yields None.
    """
    if not isinstance(value, datetime.date):
        try:
            return datetime.datetime.fromisoformat(value)
        except ValueError:
            return None
    if isinstance(value, datetime.datetime):
        return value
    return datetime.datetime.combine(value, datetime.time())
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Interpret `value` as the temporal type described by `to`.

    Args:
        value: a date, datetime or ISO-format string; falsy values yield None.
        to: the target data type.

    Returns:
        A date (via cast_as_date) when `to` is DATE, a datetime (via
        cast_as_datetime) for other temporal types, and None when `value` is
        falsy or `to` is not temporal.
    """
    if not value:
        return None
    # DATE is checked first so it takes the date path rather than the
    # general temporal (datetime) path.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a date/time cast applied to a literal.

    Supported shapes are CAST(<literal> AS <type>) and TsOrDsToDate(<literal>)
    without a "format" argument, including nested combinations of the two.

    Args:
        cast: the expression to evaluate.

    Returns:
        The date/datetime produced by cast_value, or None when the expression
        has an unsupported shape or its value cannot be interpreted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # With no format argument, the conversion is treated as a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Evaluate the inner cast first, then re-cast its result to `to`.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """
    Convert an interval-like expression into a relativedelta using `interval`.

    `expression.name` must be an integer amount, and its "unit" arg must name a
    unit that `interval` supports (compared case-insensitively). Returns None
    when the amount is not an integer, the unit is unsupported, or dateutil
    cannot be imported.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def date_literal(date):
    """
    Build a cast of `date`'s string form to a temporal type.

    Produces CAST('<date>' AS DATETIME) for datetime values and
    CAST('<date>' AS DATE) for everything else.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE
    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
    """
    Build a dateutil relativedelta spanning `n` units.

    A quarter is expressed as three months. dateutil is imported when the
    function is called, so a missing installation raises ModuleNotFoundError
    at call time.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for unit_name, delta_field, multiplier in unit_table:
        if unit == unit_name:
            return relativedelta(**{delta_field: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Truncate `d` down to the start of the given calendar unit.

    Args:
        d: a date or datetime. For a datetime, the time of day is first reset to
            midnight. Without this reset the result would keep the original
            hour/minute/second; for example the "day" floor of 2021-01-01 12:00
            would be itself, and date_ceil would then treat it as already aligned.
        unit: one of "year", "quarter", "month", "week" or "day".

    Returns:
        A value of the same type as `d` at the start of the unit. Weeks are
        assumed to start on Monday.

    Raises:
        UnsupportedUnit: if `unit` is not a supported unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` up to a `unit` boundary.

    Returns `d` unchanged when date_floor(d, unit) equals `d`. Otherwise returns
    date_floor(d, unit) plus one `unit` (via interval).
    """
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """Return the TRUE literal when `condition` is truthy, otherwise the FALSE literal."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: t.Any) -> str:
    """Simple pseudo-SQL generator for quickly producing sortable, unique strings.

    Sorting and deduplicating SQL is a necessary step for optimization, and
    calling the real generator is expensive, so this emits a minimal form:
    "_" for None, a comma-joined rendering for iterables, str() for
    non-expressions, the GEN_MAP renderer registered for the node's exact type,
    and "<key> <args>" otherwise.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    renderer = GEN_MAP.get(type(expression))
    if renderer is not None:
        return renderer(expression)
    return f"{expression.key} {gen(expression.args.values())}"

Simple pseudo-SQL generator for quickly generating sortable and unique strings.

Sorting and deduplicating SQL is a necessary step for optimization. Calling the actual generator is expensive, so a bare-minimum SQL generator is used here.

GEN_MAP = {<class 'sqlglot.expressions.Add'>: <function <lambda>>, <class 'sqlglot.expressions.And'>: <function <lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function <lambda>>, <class 'sqlglot.expressions.Between'>: <function <lambda>>, <class 'sqlglot.expressions.Boolean'>: <function <lambda>>, <class 'sqlglot.expressions.Bracket'>: <function <lambda>>, <class 'sqlglot.expressions.Column'>: <function <lambda>>, <class 'sqlglot.expressions.DataType'>: <function <lambda>>, <class 'sqlglot.expressions.Div'>: <function <lambda>>, <class 'sqlglot.expressions.Dot'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.Identifier'>: <function <lambda>>, <class 'sqlglot.expressions.ILike'>: <function <lambda>>, <class 'sqlglot.expressions.In'>: <function <lambda>>, <class 'sqlglot.expressions.Is'>: <function <lambda>>, <class 'sqlglot.expressions.Like'>: <function <lambda>>, <class 'sqlglot.expressions.Literal'>: <function <lambda>>, <class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.Mod'>: <function <lambda>>, <class 'sqlglot.expressions.Mul'>: <function <lambda>>, <class 'sqlglot.expressions.Neg'>: <function <lambda>>, <class 'sqlglot.expressions.NEQ'>: <function <lambda>>, <class 'sqlglot.expressions.Not'>: <function <lambda>>, <class 'sqlglot.expressions.Null'>: <function <lambda>>, <class 'sqlglot.expressions.Or'>: <function <lambda>>, <class 'sqlglot.expressions.Paren'>: <function <lambda>>, <class 'sqlglot.expressions.Sub'>: <function <lambda>>, <class 'sqlglot.expressions.Subquery'>: <function <lambda>>, <class 'sqlglot.expressions.Table'>: <function <lambda>>, <class 'sqlglot.expressions.Var'>: <function <lambda>>}