Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
  11from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  12
# Key stored in an expression's `meta` dict to mark it as final, i.e. a node that
# `simplify` must return untouched instead of rewriting.
FINAL = "final"
  15
  16
  17class UnsupportedUnit(Exception):
  18    pass
  19
  20
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite rules are applied repeatedly (via `while_changing`) until the tree
    stops changing, after which trivially true WHERE / JOIN conditions are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group by keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for the HAVING clause, which may also reference the group by keys
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes marked final above are returned as-is (children included)
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a new node; give it the original parent
        # before running the remaining rules (simplify_parens reads `parent`)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root swaps itself into the tree; children are put in place by
        # replace_children in the caller's frame
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
  97
  98
  99def catch(*exceptions):
 100    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 101
 102    def decorator(func):
 103        def wrapped(expression, *args, **kwargs):
 104            try:
 105                return func(expression, *args, **kwargs)
 106            except exceptions:
 107                return expression
 108
 109        return wrapped
 110
 111    return decorator
 112
 113
 114def rewrite_between(expression: exp.Expression) -> exp.Expression:
 115    """Rewrite x between y and z to x >= y AND x <= z.
 116
 117    This is done because comparison simplification is only done on lt/lte/gt/gte.
 118    """
 119    if isinstance(expression, exp.Between):
 120        return exp.and_(
 121            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 122            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 123            copy=False,
 124        )
 125    return expression
 126
 127
 128def simplify_not(expression):
 129    """
 130    Demorgan's Law
 131    NOT (x OR y) -> NOT x AND NOT y
 132    NOT (x AND y) -> NOT x OR NOT y
 133    """
 134    if isinstance(expression, exp.Not):
 135        if is_null(expression.this):
 136            return exp.null()
 137        if isinstance(expression.this, exp.Paren):
 138            condition = expression.this.unnest()
 139            if isinstance(condition, exp.And):
 140                return exp.or_(
 141                    exp.not_(condition.left, copy=False),
 142                    exp.not_(condition.right, copy=False),
 143                    copy=False,
 144                )
 145            if isinstance(condition, exp.Or):
 146                return exp.and_(
 147                    exp.not_(condition.left, copy=False),
 148                    exp.not_(condition.right, copy=False),
 149                    copy=False,
 150                )
 151            if is_null(condition):
 152                return exp.null()
 153        if always_true(expression.this):
 154            return exp.false()
 155        if is_false(expression.this):
 156            return exp.true()
 157        if isinstance(expression.this, exp.Not):
 158            # double negation
 159            # NOT NOT x -> x
 160            return expression.this.this
 161    return expression
 162
 163
 164def flatten(expression):
 165    """
 166    A AND (B AND C) -> A AND B AND C
 167    A OR (B OR C) -> A OR B OR C
 168    """
 169    if isinstance(expression, exp.Connector):
 170        for node in expression.args.values():
 171            child = node.unnest()
 172            if isinstance(child, expression.__class__):
 173                node.replace(child)
 174    return expression
 175
 176
 177def simplify_connectors(expression, root=True):
 178    def _simplify_connectors(expression, left, right):
 179        if left == right:
 180            return left
 181        if isinstance(expression, exp.And):
 182            if is_false(left) or is_false(right):
 183                return exp.false()
 184            if is_null(left) or is_null(right):
 185                return exp.null()
 186            if always_true(left) and always_true(right):
 187                return exp.true()
 188            if always_true(left):
 189                return right
 190            if always_true(right):
 191                return left
 192            return _simplify_comparison(expression, left, right)
 193        elif isinstance(expression, exp.Or):
 194            if always_true(left) or always_true(right):
 195                return exp.true()
 196            if is_false(left) and is_false(right):
 197                return exp.false()
 198            if (
 199                (is_null(left) and is_null(right))
 200                or (is_null(left) and is_false(right))
 201                or (is_false(left) and is_null(right))
 202            ):
 203                return exp.null()
 204            if is_false(left):
 205                return right
 206            if is_false(right):
 207                return left
 208            return _simplify_comparison(expression, left, right, or_=True)
 209
 210    if isinstance(expression, exp.Connector):
 211        return _flat_simplify(expression, _simplify_connectors, root)
 212    return expression
 213
 214
# Inequalities grouped by direction; `_simplify_comparison` uses these groups to
# detect two bounds that point the same way.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Predicate classes that the comparison-based rules in this module act on.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality to its mirror image, used when the two operands of a
# comparison are swapped (e.g. 1 < x -> x > 1).
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
 232
 233
def _simplify_comparison(expression, left, right, or_=False):
    """Try to merge two comparisons of the same column against literals.

    Args:
        expression: the connector that joins `left` and `right`.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for an AND.

    Returns:
        The surviving comparison, FALSE when an AND is contradictory, `expression`
        when the operands share a column but nothing else remains to compare, or
        None when no rule applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear in both comparisons; only shared columns are of interest
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                # The non-column side of each comparison
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # Only number/number and string/string pairs can be ordered here
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Try both orderings so each rule only has to be written once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two upper bounds: AND keeps the tighter one, OR the looser one
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                # Two lower bounds: same idea, mirrored
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other bound or subsumes it
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 293
 294
 295def remove_complements(expression, root=True):
 296    """
 297    Removing complements.
 298
 299    A AND NOT A -> FALSE
 300    A OR NOT A -> TRUE
 301    """
 302    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 303        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 304
 305        for a, b in itertools.permutations(expression.flatten(), 2):
 306            if is_complement(a, b):
 307                return complement
 308    return expression
 309
 310
 311def uniq_sort(expression, root=True):
 312    """
 313    Uniq and sort a connector.
 314
 315    C AND A AND B AND B -> A AND B AND C
 316    """
 317    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 318        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 319        flattened = tuple(expression.flatten())
 320        deduped = {gen(e): e for e in flattened}
 321        arr = tuple(deduped.items())
 322
 323        # check if the operands are already sorted, if not sort them
 324        # A AND C AND B -> A AND B AND C
 325        for i, (sql, e) in enumerate(arr[1:]):
 326            if sql < arr[i][0]:
 327                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 328                break
 329        else:
 330            # we didn't have to sort but maybe we need to dedup
 331            if len(deduped) < len(flattened):
 332                expression = result_func(*deduped.values(), copy=False)
 333
 334    return expression
 335
 336
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing operands with TRUE / FALSE (or
    with the surviving operand); the later connector rules then fold those
    constants away. The (possibly mutated) `expression` is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connectors of interest are of the opposite kind
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: replace it with the identity element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so a is redundant
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 375
 376
 377def propagate_constants(expression, root=True):
 378    """
 379    Propagate constants for conjunctions in DNF:
 380
 381    SELECT * FROM t WHERE a = b AND b = 5 becomes
 382    SELECT * FROM t WHERE a = 5 AND b = 5
 383
 384    Reference: https://www.sqlite.org/optoverview.html
 385    """
 386
 387    if (
 388        isinstance(expression, exp.And)
 389        and (root or not expression.same_parent)
 390        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 391    ):
 392        constant_mapping = {}
 393        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
 394            if isinstance(expr, exp.EQ):
 395                l, r = expr.left, expr.right
 396
 397                # TODO: create a helper that can be used to detect nested literal expressions such
 398                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 399                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 400                    pass
 401                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
 402                    l, r = r, l
 403                else:
 404                    continue
 405
 406                constant_mapping[l] = (id(l), r)
 407
 408        if constant_mapping:
 409            for column in find_all_in_scope(expression, exp.Column):
 410                parent = column.parent
 411                column_id, constant = constant_mapping.get(column) or (None, None)
 412                if (
 413                    column_id is not None
 414                    and id(column) != column_id
 415                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 416                ):
 417                    column.replace(constant.copy())
 418
 419    return expression
 420
 421
# Date function -> arithmetic operator that undoes it when the function is moved
# to the other side of a comparison.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Same as above, extended with plain addition and subtraction.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 434
 435
def _is_number(expression: exp.Expression) -> bool:
    """Return the expression's `is_number` flag (function form, usable as a predicate)."""
    return expression.is_number
 438
 439
 440def _is_interval(expression: exp.Expression) -> bool:
 441    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 442
 443
 444@catch(ModuleNotFoundError, UnsupportedUnit)
 445def simplify_equality(expression: exp.Expression) -> exp.Expression:
 446    """
 447    Use the subtraction and addition properties of equality to simplify expressions:
 448
 449        x + 1 = 3 becomes x = 2
 450
 451    There are two binary operations in the above expression: + and =
 452    Here's how we reference all the operands in the code below:
 453
 454          l     r
 455        x + 1 = 3
 456        a   b
 457    """
 458    if isinstance(expression, COMPARISONS):
 459        l, r = expression.left, expression.right
 460
 461        if l.__class__ in INVERSE_OPS:
 462            pass
 463        elif r.__class__ in INVERSE_OPS:
 464            l, r = r, l
 465        else:
 466            return expression
 467
 468        if r.is_number:
 469            a_predicate = _is_number
 470            b_predicate = _is_number
 471        elif _is_date_literal(r):
 472            a_predicate = _is_date_literal
 473            b_predicate = _is_interval
 474        else:
 475            return expression
 476
 477        if l.__class__ in INVERSE_DATE_OPS:
 478            l = t.cast(exp.IntervalOp, l)
 479            a = l.this
 480            b = l.interval()
 481        else:
 482            l = t.cast(exp.Binary, l)
 483            a, b = l.left, l.right
 484
 485        if not a_predicate(a) and b_predicate(b):
 486            pass
 487        elif not a_predicate(b) and b_predicate(a):
 488            a, b = b, a
 489        else:
 490            return expression
 491
 492        return expression.__class__(
 493            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 494        )
 495    return expression
 496
 497
 498def simplify_literals(expression, root=True):
 499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 500        return _flat_simplify(expression, _simplify_binary, root)
 501
 502    if isinstance(expression, exp.Neg):
 503        this = expression.this
 504        if this.is_number:
 505            value = this.name
 506            if value[0] == "-":
 507                return exp.Literal.number(value[1:])
 508            return exp.Literal.number(f"-{value}")
 509
 510    if type(expression) in INVERSE_DATE_OPS:
 511        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 512
 513    return expression
 514
 515
def _simplify_binary(expression, a, b):
    """Evaluate a binary operation whose operands `a` and `b` are constants.

    Args:
        expression: the operation being evaluated (arithmetic, comparison, IS, date op).
        a: left operand.
        b: right operand.

    Returns:
        A literal / boolean / NULL expression with the result, or None when the
        operation cannot be folded.
    """
    if isinstance(expression, exp.Is):
        # Split `x IS NOT y` into the negation flag and y
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            # <literal> IS NULL -> FALSE, <literal> IS NOT NULL -> TRUE
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            # NULL IS NULL -> TRUE, NULL IS NOT NULL -> FALSE
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        # Any other binary operation with a NULL operand yields NULL
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        # Both extractions may come back empty, in which case nothing is folded
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None
 582
 583
 584def simplify_parens(expression):
 585    if not isinstance(expression, exp.Paren):
 586        return expression
 587
 588    this = expression.this
 589    parent = expression.parent
 590
 591    if not isinstance(this, exp.Select) and (
 592        not isinstance(parent, (exp.Condition, exp.Binary))
 593        or isinstance(parent, exp.Paren)
 594        or not isinstance(this, exp.Binary)
 595        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 596        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 597        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 598        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 599    ):
 600        return this
 601    return expression
 602
 603
# Constant node types that can never be NULL.
NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

# All constant node types, NULL included.
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
 614
 615
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for a literal, a boolean or a date literal (NULL is excluded)."""
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
 618
 619
def _is_constant(expression: exp.Expression) -> bool:
    """Return True for a literal, a boolean, NULL or a date literal."""
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 622
 623
 624def simplify_coalesce(expression):
 625    # COALESCE(x) -> x
 626    if (
 627        isinstance(expression, exp.Coalesce)
 628        and (not expression.expressions or _is_nonnull_constant(expression.this))
 629        # COALESCE is also used as a Spark partitioning hint
 630        and not isinstance(expression.parent, exp.Hint)
 631    ):
 632        return expression.this
 633
 634    if not isinstance(expression, COMPARISONS):
 635        return expression
 636
 637    if isinstance(expression.left, exp.Coalesce):
 638        coalesce = expression.left
 639        other = expression.right
 640    elif isinstance(expression.right, exp.Coalesce):
 641        coalesce = expression.right
 642        other = expression.left
 643    else:
 644        return expression
 645
 646    # This transformation is valid for non-constants,
 647    # but it really only does anything if they are both constants.
 648    if not _is_constant(other):
 649        return expression
 650
 651    # Find the first constant arg
 652    for arg_index, arg in enumerate(coalesce.expressions):
 653        if _is_constant(other):
 654            break
 655    else:
 656        return expression
 657
 658    coalesce.set("expressions", coalesce.expressions[:arg_index])
 659
 660    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 661    # since we already remove COALESCE at the top of this function.
 662    coalesce = coalesce if coalesce.expressions else coalesce.this
 663
 664    # This expression is more complex than when we started, but it will get simplified further
 665    return exp.paren(
 666        exp.or_(
 667            exp.and_(
 668                coalesce.is_(exp.null()).not_(copy=False),
 669                expression.copy(),
 670                copy=False,
 671            ),
 672            exp.and_(
 673                coalesce.is_(exp.null()),
 674                type(expression)(this=arg.copy(), expression=other.copy()),
 675                copy=False,
 676            ),
 677            copy=False,
 678        )
 679    )
 680
 681
# Node types that `simplify_concat` treats as string concatenation.
CONCATS = (exp.Concat, exp.DPipe)
 683
 684
 685def simplify_concat(expression):
 686    """Reduces all groups that contain string literals by concatenating them."""
 687    if not isinstance(expression, CONCATS) or (
 688        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 689        isinstance(expression, exp.ConcatWs)
 690        and not expression.expressions[0].is_string
 691    ):
 692        return expression
 693
 694    if isinstance(expression, exp.ConcatWs):
 695        sep_expr, *expressions = expression.expressions
 696        sep = sep_expr.name
 697        concat_type = exp.ConcatWs
 698        args = {}
 699    else:
 700        expressions = expression.expressions
 701        sep = ""
 702        concat_type = exp.Concat
 703        args = {
 704            "safe": expression.args.get("safe"),
 705            "coalesce": expression.args.get("coalesce"),
 706        }
 707
 708    new_args = []
 709    for is_string_group, group in itertools.groupby(
 710        expressions or expression.flatten(), lambda e: e.is_string
 711    ):
 712        if is_string_group:
 713            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 714        else:
 715            new_args.extend(group)
 716
 717    if len(new_args) == 1 and new_args[0].is_string:
 718        return new_args[0]
 719
 720    if concat_type is exp.ConcatWs:
 721        new_args = [sep_expr] + new_args
 722
 723    return concat_type(expressions=new_args, **args)
 724
 725
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    Returns the selected branch when a condition is always true, drops CASE
    branches whose condition is always false (falling back to the default, or
    NULL, once none are left), and otherwise returns `expression`.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                # NOTE(review): this returns as soon as any branch is always true,
                # without checking whether earlier non-constant branches remain
                return case.args["true"]

            if always_false(cond):
                # This branch can never be taken
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # Standalone IF; the IF nodes of a CASE are handled by the branch above
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 750
 751
# (start, end) pair of dates; `_datetrunc_range` documents it as the half-open range [min, max).
DateRange = t.Tuple[datetime.date, datetime.date]
 753
 754
 755def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 756    """
 757    Get the date range for a DATE_TRUNC equality comparison:
 758
 759    Example:
 760        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 761    Returns:
 762        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 763    """
 764    floor = date_floor(date, unit)
 765
 766    if date != floor:
 767        # This will always be False, except for NULL values.
 768        return None
 769
 770    return floor, floor + interval(unit)
 771
 772
 773def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 774    """Get the logical expression for a date range"""
 775    return exp.and_(
 776        left >= date_literal(drange[0]),
 777        left < date_literal(drange[1]),
 778        copy=False,
 779    )
 780
 781
 782def _datetrunc_eq(
 783    left: exp.Expression, date: datetime.date, unit: str
 784) -> t.Optional[exp.Expression]:
 785    drange = _datetrunc_range(date, unit)
 786    if not drange:
 787        return None
 788
 789    return _datetrunc_eq_expression(left, drange)
 790
 791
 792def _datetrunc_neq(
 793    left: exp.Expression, date: datetime.date, unit: str
 794) -> t.Optional[exp.Expression]:
 795    drange = _datetrunc_range(date, unit)
 796    if not drange:
 797        return None
 798
 799    return exp.and_(
 800        left < date_literal(drange[0]),
 801        left >= date_literal(drange[1]),
 802        copy=False,
 803    )
 804
 805
 806DateTruncBinaryTransform = t.Callable[
 807    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 808]
 809DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 810    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 811    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 812    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 813    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 814    exp.EQ: _datetrunc_eq,
 815    exp.NEQ: _datetrunc_neq,
 816}
 817DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 818
 819
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True when `left` is a DATE_TRUNC / TIMESTAMP_TRUNC node and `right` a date literal."""
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 822
 823
 824@catch(ModuleNotFoundError, UnsupportedUnit)
 825def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
 826    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 827    comparison = expression.__class__
 828
 829    if comparison not in DATETRUNC_COMPARISONS:
 830        return expression
 831
 832    if isinstance(expression, exp.Binary):
 833        l, r = expression.left, expression.right
 834
 835        if _is_datetrunc_predicate(l, r):
 836            pass
 837        elif _is_datetrunc_predicate(r, l):
 838            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
 839            l, r = r, l
 840        else:
 841            return expression
 842
 843        l = t.cast(exp.DateTrunc, l)
 844        unit = l.unit.name.lower()
 845        date = extract_date(r)
 846
 847        if not date:
 848            return expression
 849
 850        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
 851    elif isinstance(expression, exp.In):
 852        l = expression.this
 853        rs = expression.expressions
 854
 855        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 856            l = t.cast(exp.DateTrunc, l)
 857            unit = l.unit.name.lower()
 858
 859            ranges = []
 860            for r in rs:
 861                date = extract_date(r)
 862                if not date:
 863                    return expression
 864                drange = _datetrunc_range(date, unit)
 865                if drange:
 866                    ranges.append(drange)
 867
 868            if not ranges:
 869                return expression
 870
 871            ranges = merge_ranges(ranges)
 872
 873            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 874
 875    return expression
 876
 877
 878# CROSS joins result in an empty table if the right table is empty.
 879# So we can only simplify certain types of joins to CROSS.
 880# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 881JOINS = {
 882    ("", ""),
 883    ("", "INNER"),
 884    ("RIGHT", ""),
 885    ("RIGHT", "OUTER"),
 886}
 887
 888
 889def remove_where_true(expression):
 890    for where in expression.find_all(exp.Where):
 891        if always_true(where.this):
 892            where.parent.set("where", None)
 893    for join in expression.find_all(exp.Join):
 894        if (
 895            always_true(join.args.get("on"))
 896            and not join.args.get("using")
 897            and not join.args.get("method")
 898            and (join.side, join.kind) in JOINS
 899        ):
 900            join.set("on", None)
 901            join.set("side", None)
 902            join.set("kind", "CROSS")
 903
 904
 905def always_true(expression):
 906    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 907        expression, exp.Literal
 908    )
 909
 910
 911def always_false(expression):
 912    return is_false(expression) or is_null(expression)
 913
 914
 915def is_complement(a, b):
 916    return isinstance(b, exp.Not) and b.this == a
 917
 918
 919def is_false(a: exp.Expression) -> bool:
 920    return type(a) is exp.Boolean and not a.this
 921
 922
 923def is_null(a: exp.Expression) -> bool:
 924    return type(a) is exp.Null
 925
 926
 927def eval_boolean(expression, a, b):
 928    if isinstance(expression, (exp.EQ, exp.Is)):
 929        return boolean_literal(a == b)
 930    if isinstance(expression, exp.NEQ):
 931        return boolean_literal(a != b)
 932    if isinstance(expression, exp.GT):
 933        return boolean_literal(a > b)
 934    if isinstance(expression, exp.GTE):
 935        return boolean_literal(a >= b)
 936    if isinstance(expression, exp.LT):
 937        return boolean_literal(a < b)
 938    if isinstance(expression, exp.LTE):
 939        return boolean_literal(a <= b)
 940    return None
 941
 942
 943def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 944    if isinstance(value, datetime.datetime):
 945        return value.date()
 946    if isinstance(value, datetime.date):
 947        return value
 948    try:
 949        return datetime.datetime.fromisoformat(value).date()
 950    except ValueError:
 951        return None
 952
 953
 954def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 955    if isinstance(value, datetime.datetime):
 956        return value
 957    if isinstance(value, datetime.date):
 958        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 959    try:
 960        return datetime.datetime.fromisoformat(value)
 961    except ValueError:
 962        return None
 963
 964
 965def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 966    if not value:
 967        return None
 968    if to.is_type(exp.DataType.Type.DATE):
 969        return cast_as_date(value)
 970    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 971        return cast_as_datetime(value)
 972    return None
 973
 974
 975def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 976    if isinstance(cast, exp.Cast):
 977        to = cast.to
 978    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
 979        to = exp.DataType.build(exp.DataType.Type.DATE)
 980    else:
 981        return None
 982
 983    if isinstance(cast.this, exp.Literal):
 984        value: t.Any = cast.this.name
 985    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 986        value = extract_date(cast.this)
 987    else:
 988        return None
 989    return cast_value(value, to)
 990
 991
 992def _is_date_literal(expression: exp.Expression) -> bool:
 993    return extract_date(expression) is not None
 994
 995
 996def extract_interval(expression):
 997    try:
 998        n = int(expression.name)
 999        unit = expression.text("unit").lower()
1000        return interval(unit, n)
1001    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
1002        return None
1003
1004
def date_literal(date):
    """Wrap `date` in a CAST of its string literal: DATETIME for datetimes, DATE otherwise."""
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1012
1013
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times the given `unit`.

    Supported units: year, quarter, month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    # Imported on first use, so a failed import surfaces here rather than at module load.
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
1035
1036
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Return the first day of the `unit` period that contains `d`.

    Supported units: year, quarter, month, week (weeks start on Monday) and day.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    elif unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    elif unit == "month":
        return d.replace(day=1)
    elif unit == "week":
        # weekday() is 0 on Monday and 6 on Sunday.
        days_since_monday = d.weekday()
        return d - datetime.timedelta(days=days_since_monday)
    elif unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1058
1059
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Return `d` if it already starts a `unit` period, else the start of the next period."""
    lower = date_floor(d, unit)
    return d if lower == d else lower + interval(unit)
1067
1068
def boolean_literal(condition):
    """Return a TRUE node for a truthy `condition`, otherwise a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
1071
1072
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of `expression`.

    Each operand is offered to `simplifier(expression, a, b)` together with every
    operand still pending. A truthy result other than `expression` replaces both
    operands and is processed again. If fewer operands remain than there were at
    the start, they are folded back into nodes of `expression`'s class; otherwise
    `expression` is returned untouched.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    initial_count = len(pending)
    settled = []

    while pending:
        head = pending.popleft()

        for other in pending:
            merged = simplifier(expression, head, other)

            if merged and merged is not expression:
                # Leave the loop right after mutating the deque being iterated.
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            # Nothing combined with `head`, so it is final.
            settled.append(head)

    if len(settled) >= initial_count:
        return expression

    node_type = expression.__class__
    return functools.reduce(lambda a, b: node_type(this=a, expression=b), settled)
1097
1098
def gen(expression: t.Any) -> str:
    """Minimal pseudo-SQL generator producing sortable, unique strings.

    Optimization needs to sort and dedupe SQL, and the real generator is expensive
    to call, so this bare-bones substitute is used instead.

    None renders as "_", iterables as their comma-joined items, and non-expression
    values via `str`. Expressions use their GEN_MAP handler when one exists and
    otherwise fall back to their key followed by their rendered args.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(gen(item) for item in expression)
    if not isinstance(expression, exp.Expression):
        return str(expression)

    handler = GEN_MAP.get(type(expression))
    if handler is None:
        return f"{expression.key} {gen(expression.args.values())}"
    return handler(expression)
1116
1117
# Maps an expression type to the function `gen` uses to render it. Types without an
# entry fall back to gen's generic "<key> <args>" rendering.
# The entries are lambdas on purpose: `_binary` and `_unary` are defined below this
# table, so they must be looked up at call time rather than when the dict is built.
GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    # The generator's loop variable shadows the lambda parameter; `e.expressions` is
    # still read from the outer `e`.
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    # Renders every arg after the first one.
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    # Renders every arg after the first one inside the parentheses.
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}
1153
1154
def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary node as `<left> <op> <right>`, using `gen` for both operands."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return f"{lhs} {op} {rhs}"
1157
1158
def _unary(e: exp.Unary, op: str) -> str:
    """Render a unary node as `<op> <operand>`, using `gen` for the operand."""
    operand = gen(e.this)
    return f"{op} {operand}"
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by this module's date helpers when they are given a unit they do not handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # GROUP BY keys must stay untouched: in
    #   select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection has to match the group by key exactly. The GROUP BY itself, and
    # every projection / HAVING clause containing one of its keys, is marked FINAL.
    for group in expression.find_all(exp.Group):
        select = group.parent
        group_keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if any(node in group_keys for node, *_ in projection.walk()):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and any(node in group_keys for node, *_ in having.walk()):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = rewrite_between(expression)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        for transform in (simplify_concat, simplify_conditionals):
            node = transform(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order transformations
        for transform in (simplify_not, flatten):
            node = transform(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        for transform in (simplify_equality, simplify_parens, simplify_datetrunc_predicate):
            node = transform(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    When the decorated function raises one of `exceptions`, the wrapper returns
    its first argument (the expression) unchanged instead.
    """

    def decorator(func):
        # Keep the decorated function's name and docstring so that introspection
        # and generated docs show the real simplifier rather than `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that skips a simplification function, returning the expression unchanged, if any of the given `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only handles lt/lte/gt/gte, so BETWEEN is expanded
    into those first. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Each comparison gets its own copy of the tested operand.
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify NOT expressions.

    De Morgan's law is applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition, NOT NULL becomes NULL, NOT of an always-true operand becomes FALSE,
    NOT FALSE becomes TRUE and NOT NOT x becomes x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, (exp.And, exp.Or)):
            # The connector flips: AND becomes OR and vice versa.
            combine = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return combine(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )

        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around operands that are the same kind of connector.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression

Flattens nested connectors of the same kind: A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Simplify AND/OR connectors whose operands are equal, constant or comparable.

    Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Combine one pair of operands; the result is handed back to _flat_simplify.
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true, right_true = always_true(left), always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            left_null, right_null = is_null(left), is_null(right)

            if left_false and right_false:
                return exp.false()
            # NULL OR NULL, NULL OR FALSE and FALSE OR NULL all give NULL.
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Dedupe and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
    flattened = tuple(expression.flatten())

    # Keyed by generated text, so operands that render identically collapse into one.
    deduped = {gen(operand): operand for operand in flattened}
    arr = tuple(deduped.items())

    # Rebuild only when needed: A AND C AND B -> A AND B AND C
    out_of_order = any(current[0] < previous[0] for previous, current in zip(arr, arr[1:]))

    if out_of_order:
        expression = result_func(*(operand for _, operand in sorted(arr)), copy=False)
    elif len(deduped) < len(flattened):
        # Already sorted, but duplicates still have to be dropped.
        expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # `kind` is the opposite connector, i.e. the type of the nested operands examined.
    # `neutral` builds the constant that replaces a complemented inner operand and
    # `absorbed` the constant that replaces a whole absorbed operand. Both are
    # called each time so that every replacement gets a fresh node.
    if isinstance(expression, exp.And):
        kind, neutral, absorbed = exp.Or, exp.false, exp.true
    else:
        kind, neutral, absorbed = exp.And, exp.true, exp.false

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(neutral())
        elif is_complement(b, ab):
            ab.replace(neutral())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(absorbed())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    eligible = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not eligible:
        return expression

    # column -> (id of the column node inside the defining equality, literal)
    constant_mapping = {}
    for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
        if not isinstance(expr, exp.EQ):
            continue

        l, r = expr.left, expr.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
            column, literal = l, r
        elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
            column, literal = r, l
        else:
            continue

        constant_mapping[column] = (id(column), literal)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        column_id, constant = constant_mapping.get(column) or (None, None)

        # Skip unmapped columns and the occurrence that defines the constant itself.
        if column_id is None or id(column) == column_id:
            continue

        # `col IS NULL` is left alone.
        parent = column.parent
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binaries, negated numbers and date operations."""
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating a literal that already starts with "-" just drops the sign.
            if digits[0] == "-":
                return exp.Literal.number(digits[1:])
            return exp.Literal.number(f"-{digits}")

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Return the content of a Paren node when the parentheses are redundant.

    Parentheses around a Select are always kept. Non-Paren expressions are
    returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if isinstance(this, exp.Select):
        return expression

    redundant = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )
    return this if redundant else expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    * `COALESCE(x)`, or a COALESCE whose first argument is a non-null constant,
      is replaced by that first argument.
    * In a comparison of a COALESCE with a constant, the COALESCE is cut at its
      first constant fallback argument and the comparison is split into two cases:
      the remaining COALESCE is non-null, or it is null and the constant fallback
      is compared instead.

    Any other expression is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg. The test has to be on `arg`: testing `other`
    # (already known to be constant) would always stop at index 0 and discard
    # every fallback after the first one.
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *expressions = expression.expressions
        sep, concat_type, args = sep_expr.name, exp.ConcatWs, {}
    else:
        expressions = expression.expressions
        sep, concat_type = "", exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Fall back to the flattened operands when there is no `expressions` list.
    operands = expressions or expression.flatten()

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if not is_string_group:
            new_args.extend(group)
            continue
        # A run of adjacent string literals becomes a single literal.
        joined = sep.join(string.name for string in group)
        new_args.append(exp.Literal.string(joined))

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if is_concat_ws:
        new_args.insert(0, sep_expr)

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed. A branch whose
    condition is always true replaces the whole CASE only when every branch before
    it was removed. If an earlier branch with an undecidable condition remains, it
    may match first at runtime, so the CASE is kept.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Becomes True once a branch that cannot be decided statically has been kept.
        undecided_before = False

        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided_before:
                    return case.args["true"]
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided_before = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Drop always-true filters from `expression`, mutating it in place.

    WHERE clauses whose condition is always true are removed from their parent.
    Joins whose ON condition is always true, that have no USING clause and no
    join method, and whose (side, kind) pair is listed in JOINS are rewritten
    as CROSS joins without an ON condition or side.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        if not always_true(join_node.args.get("on")):
            continue
        if join_node.args.get("using") or join_node.args.get("method"):
            continue
        if (join_node.side, join_node.kind) not in JOINS:
            continue

        # The condition filters nothing, so the join degenerates to a cross join.
        for arg_name, arg_value in (("on", None), ("side", None), ("kind", "CROSS")):
            join_node.set(arg_name, arg_value)
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is statically known to be true.

    A boolean TRUE qualifies, and so does any literal except a numeric zero.
    A numeric zero literal (e.g. ``WHERE 0``) is excluded because engines that
    accept numbers as conditions treat zero as false, so simplifying it as if
    it were true would change query results.

    Args:
        expression: the node to inspect (may be None or any expression).

    Returns:
        A truthy value when the expression is always true, a falsy one otherwise.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this
    if not isinstance(expression, exp.Literal):
        return False
    if not expression.is_number:
        return True

    try:
        return Decimal(expression.name) != 0
    except ArithmeticError:
        # decimal.InvalidOperation: the numeric text is not parseable as a
        # decimal, so keep the previous behavior of treating it as true.
        return True
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is the boolean FALSE or a NULL node."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True only for a node that is exactly a Boolean with a falsy value."""
    # Exact type match on purpose: subclasses of Boolean do not count.
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True only for a node whose type is exactly Null."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison described by `expression` on the values `a` and `b`.

    Args:
        expression: the comparison node (EQ, Is, NEQ, GT, GTE, LT or LTE).
        a: left-hand Python value.
        b: right-hand Python value.

    Returns:
        The boolean literal for the comparison result, or None when the
        expression type is not one of the supported comparisons.
    """
    # Comparisons are wrapped in lambdas so only the matching one is evaluated.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a date, or return None if it cannot be converted.

    Args:
        value: a datetime (its date part is returned), a date (returned as is),
            or an ISO-format string accepted by `datetime.datetime.fromisoformat`.

    Returns:
        The resulting date, or None when `value` is not a valid ISO string or
        is of an unsupported type (e.g. None or a number).
    """
    # datetime must be checked first because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not a valid ISO string.
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime, or return None if it cannot be converted.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of
            that day), or an ISO-format string accepted by
            `datetime.datetime.fromisoformat`.

    Returns:
        The resulting datetime, or None when `value` is not a valid ISO string
        or is of an unsupported type (e.g. None or a number).
    """
    # datetime must be checked first because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not a valid ISO string.
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python temporal type matching the data type `to`.

    Args:
        value: the raw value to convert; falsy values are not converted.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal
        type, and None when `value` is falsy, `to` is not temporal, or the
        conversion fails.
    """
    if not value:
        return None

    # DATE is tested before the broader temporal check so it yields a date
    # rather than a datetime.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)

    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal into a Python date or datetime.

    Supported inputs are a Cast node and a TsOrDsToDate node without a format
    argument (treated as a cast to DATE). The operand must be a literal or,
    recursively, another such cast.

    Args:
        cast: the expression to evaluate.

    Returns:
        The date or datetime the cast evaluates to, or None if the expression
        has an unsupported shape or the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner cast first, then apply the outer type.
        value = extract_date(operand)
    else:
        return None

    return cast_value(value, to)
def extract_interval(expression):
 997def extract_interval(expression):
 998    try:
 999        n = int(expression.name)
1000        unit = expression.text("unit").lower()
1001        return interval(unit, n)
1002    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
1003        return None
def date_literal(date):
def date_literal(date):
    """Build a cast of a string literal holding `date`.

    The cast targets DATETIME when `date` is a datetime and DATE otherwise.
    """
    literal = exp.Literal.string(date)

    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` of the given `unit`.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour",
            "minute" or "second".
        n: how many units the offset should span.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported names.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, how many of that keyword make one unit)
    supported = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for unit_name, keyword, factor in supported:
        if unit == unit_name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to round down.
        unit: one of "year", "quarter", "month", "week" or "day".

    Returns:
        The first day of the enclosing year, quarter, month or week (weeks
        start on Monday), or `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported names.
    """
    if unit == "year":
        return d.replace(month=1, day=1)

    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        quarter_start_month = d.month - (d.month - 1) % 3
        return d.replace(month=quarter_start_month, day=1)

    if unit == "month":
        return d.replace(day=1)

    if unit == "week":
        # weekday() is 0 for Monday, so this steps back to the week's Monday.
        days_since_monday = d.weekday()
        return d - datetime.timedelta(days=days_since_monday)

    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary.

    Returns `d` unchanged when it already sits on a boundary; otherwise the
    floor of `d` advanced by one `unit`.
    """
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, else the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Render `expression` as a cheap pseudo-SQL string usable for sorting and deduping.

    The optimizer needs sortable, unique strings for SQL fragments, and the
    real generator is expensive to call, so this is a bare-minimum stand-in.

    None renders as "_", iterables render as their comma-joined items, and
    non-expression values render via `str`. Expression types registered in
    GEN_MAP use their dedicated renderer; any other expression renders as its
    key followed by its rendered argument values.
    """
    if expression is None:
        return "_"

    if is_iterable(expression):
        return ",".join(map(gen, expression))

    if not isinstance(expression, exp.Expression):
        return str(expression)

    try:
        render = GEN_MAP[type(expression)]
    except KeyError:
        # No dedicated renderer for this type: fall back to the generic form.
        return f"{expression.key} {gen(expression.args.values())}"

    return render(expression)

Simple pseudo-SQL generator for quickly generating sortable and unique strings.

Sorting and deduplicating SQL is a necessary step for optimization. Calling the actual generator is expensive, so we have a bare-minimum SQL generator here.

GEN_MAP = {<class 'sqlglot.expressions.Add'>: <function <lambda>>, <class 'sqlglot.expressions.And'>: <function <lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function <lambda>>, <class 'sqlglot.expressions.Between'>: <function <lambda>>, <class 'sqlglot.expressions.Boolean'>: <function <lambda>>, <class 'sqlglot.expressions.Bracket'>: <function <lambda>>, <class 'sqlglot.expressions.Column'>: <function <lambda>>, <class 'sqlglot.expressions.DataType'>: <function <lambda>>, <class 'sqlglot.expressions.Div'>: <function <lambda>>, <class 'sqlglot.expressions.Dot'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.Identifier'>: <function <lambda>>, <class 'sqlglot.expressions.ILike'>: <function <lambda>>, <class 'sqlglot.expressions.In'>: <function <lambda>>, <class 'sqlglot.expressions.Is'>: <function <lambda>>, <class 'sqlglot.expressions.Like'>: <function <lambda>>, <class 'sqlglot.expressions.Literal'>: <function <lambda>>, <class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.Mod'>: <function <lambda>>, <class 'sqlglot.expressions.Mul'>: <function <lambda>>, <class 'sqlglot.expressions.Neg'>: <function <lambda>>, <class 'sqlglot.expressions.NEQ'>: <function <lambda>>, <class 'sqlglot.expressions.Not'>: <function <lambda>>, <class 'sqlglot.expressions.Null'>: <function <lambda>>, <class 'sqlglot.expressions.Or'>: <function <lambda>>, <class 'sqlglot.expressions.Paren'>: <function <lambda>>, <class 'sqlglot.expressions.Sub'>: <function <lambda>>, <class 'sqlglot.expressions.Subquery'>: <function <lambda>>, <class 'sqlglot.expressions.Table'>: <function <lambda>>, <class 'sqlglot.expressions.Var'>: <function <lambda>>}