Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Signature of the DATE_TRUNC comparison rewriters: (expression being truncated, date it is
    # compared against, truncation unit, dialect) -> replacement predicate, or None when the
    # comparison cannot be rewritten.
    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
    ]

# Final means that an expression should not be simplified.
# The flag is stored under this key in `Expression.meta`; `simplify` returns such nodes as-is.
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers.
# A CAST of an integer literal inside these ranges to a signed/unsigned integer type cannot
# overflow, so `_simplify_integer_cast` may drop it.
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  30
  31
  32class UnsupportedUnit(Exception):
  33    pass
  34
  35
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect handed to dialect-sensitive rules (in this function, only
            `simplify_datetrunc`); resolved with `Dialect.get_or_raise`.

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # One simplification pass over `expression` and, recursively, its children.
        # `root` is True only for the node the pass starts from; several rules only act on
        # the root or on the top of a connector chain.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a grouping key anywhere inside it
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # HAVING may reference the grouping keys as well, so freeze it on the same terms
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse: simplify every child as a non-root node and splice the results back in
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have returned a new, detached node. Give it the original parent
        # so later rules that inspect `parent` (e.g. simplify_parens) still see the context.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the starting node is spliced back into the tree; children were already
        # replaced by exp.replace_children
        if root:
            expression.replace(node)
        return node

    # while_changing re-runs the pass (presumably until the tree stops changing);
    # remove_where_true is defined elsewhere in this module
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 118
 119
 120def catch(*exceptions):
 121    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 122
 123    def decorator(func):
 124        def wrapped(expression, *args, **kwargs):
 125            try:
 126                return func(expression, *args, **kwargs)
 127            except exceptions:
 128                return expression
 129
 130        return wrapped
 131
 132    return decorator
 133
 134
 135def rewrite_between(expression: exp.Expression) -> exp.Expression:
 136    """Rewrite x between y and z to x >= y AND x <= z.
 137
 138    This is done because comparison simplification is only done on lt/lte/gt/gte.
 139    """
 140    if isinstance(expression, exp.Between):
 141        negate = isinstance(expression.parent, exp.Not)
 142
 143        expression = exp.and_(
 144            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 145            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 146            copy=False,
 147        )
 148
 149        if negate:
 150            expression = exp.paren(expression, copy=False)
 151
 152    return expression
 153
 154
# Logical negation of each comparison operator. simplify_not uses it to rewrite NOT (a < b) as
# a >= b, NOT (a = b) as a <> b, and so on; lookups are by the node's exact class.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 163
 164
 165def simplify_not(expression):
 166    """
 167    Demorgan's Law
 168    NOT (x OR y) -> NOT x AND NOT y
 169    NOT (x AND y) -> NOT x OR NOT y
 170    """
 171    if isinstance(expression, exp.Not):
 172        this = expression.this
 173        if is_null(this):
 174            return exp.null()
 175        if this.__class__ in COMPLEMENT_COMPARISONS:
 176            return COMPLEMENT_COMPARISONS[this.__class__](
 177                this=this.this, expression=this.expression
 178            )
 179        if isinstance(this, exp.Paren):
 180            condition = this.unnest()
 181            if isinstance(condition, exp.And):
 182                return exp.paren(
 183                    exp.or_(
 184                        exp.not_(condition.left, copy=False),
 185                        exp.not_(condition.right, copy=False),
 186                        copy=False,
 187                    )
 188                )
 189            if isinstance(condition, exp.Or):
 190                return exp.paren(
 191                    exp.and_(
 192                        exp.not_(condition.left, copy=False),
 193                        exp.not_(condition.right, copy=False),
 194                        copy=False,
 195                    )
 196                )
 197            if is_null(condition):
 198                return exp.null()
 199        if always_true(this):
 200            return exp.false()
 201        if is_false(this):
 202            return exp.true()
 203        if isinstance(this, exp.Not):
 204            # double negation
 205            # NOT NOT x -> x
 206            return this.this
 207    return expression
 208
 209
 210def flatten(expression):
 211    """
 212    A AND (B AND C) -> A AND B AND C
 213    A OR (B OR C) -> A OR B OR C
 214    """
 215    if isinstance(expression, exp.Connector):
 216        for node in expression.args.values():
 217            child = node.unnest()
 218            if isinstance(child, expression.__class__):
 219                node.replace(child)
 220    return expression
 221
 222
 223def simplify_connectors(expression, root=True):
 224    def _simplify_connectors(expression, left, right):
 225        if left == right:
 226            return left
 227        if isinstance(expression, exp.And):
 228            if is_false(left) or is_false(right):
 229                return exp.false()
 230            if is_null(left) or is_null(right):
 231                return exp.null()
 232            if always_true(left) and always_true(right):
 233                return exp.true()
 234            if always_true(left):
 235                return right
 236            if always_true(right):
 237                return left
 238            return _simplify_comparison(expression, left, right)
 239        elif isinstance(expression, exp.Or):
 240            if always_true(left) or always_true(right):
 241                return exp.true()
 242            if is_false(left) and is_false(right):
 243                return exp.false()
 244            if (
 245                (is_null(left) and is_null(right))
 246                or (is_null(left) and is_false(right))
 247                or (is_false(left) and is_null(right))
 248            ):
 249                return exp.null()
 250            if is_false(left):
 251                return right
 252            if is_false(right):
 253                return left
 254            return _simplify_comparison(expression, left, right, or_=True)
 255
 256    if isinstance(expression, exp.Connector):
 257        return _flat_simplify(expression, _simplify_connectors, root)
 258    return expression
 259
 260
# Upper-bound comparisons (x < c, x <= c)
LT_LTE = (exp.LT, exp.LTE)
# Lower-bound comparisons (x > c, x >= c)
GT_GTE = (exp.GT, exp.GTE)

# Comparison node types inspected by the comparison-folding rules in this module
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# The operator obtained when a comparison's operands are swapped (a < b is b > a); presumably
# used when operands get reordered (e.g. by sort_comparison)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Functions whose value changes between evaluations; an operand containing one is never treated
# as the shared "column" in `_simplify_comparison`
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 280
 281
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that constrain the same operand.

    `expression` is the AND connector (the OR connector when `or_=True`); `left` and `right`
    are two of its operands. When both are comparisons sharing a non-constant, deterministic
    operand (the "column") and their other operands are comparable constants (numbers,
    strings or dates), the pair is reduced, e.g.

        x < 5 AND x < 3  -> x < 3        x < 5 OR x < 3  -> x < 5
        x > 5 AND x < 3  -> FALSE        x = 2 AND x < 3 -> x = 2

    Returns:
        The node replacing the pair; `expression` when a comparison has no operand besides
        the shared one; None when no rule applies.

    NOTE(review): the bound direction is read from the node class, so this assumes the shared
    operand is on the left of both comparisons (`x < 5`, not `5 > x`), presumably ensured by
    sort_comparison having run on the operands — TODO confirm. Equal bounds are likewise resolved
    in favour of `left`, which is only right when the strict comparison comes first (as
    uniq_sort's ordering appears to guarantee).
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        # each comparison is assumed to carry exactly two args (this, expression)
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # operands appearing in both comparisons that are neither constants nor contain a
        # non-deterministic call
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            # the value each comparison compares the shared operand against
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # convert both values to comparable Python objects, or give up
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # try both orderings of the pair, since the rules below are asymmetric
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # same-direction bounds: AND keeps the tighter one, OR the looser one
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # disjoint ranges -> contradiction
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    # an equality either satisfies the other condition (keep it) or
                    # contradicts it
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 342
 343
 344def remove_complements(expression, root=True):
 345    """
 346    Removing complements.
 347
 348    A AND NOT A -> FALSE
 349    A OR NOT A -> TRUE
 350    """
 351    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 352        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 353
 354        for a, b in itertools.permutations(expression.flatten(), 2):
 355            if is_complement(a, b):
 356                return complement
 357    return expression
 358
 359
 360def uniq_sort(expression, root=True):
 361    """
 362    Uniq and sort a connector.
 363
 364    C AND A AND B AND B -> A AND B AND C
 365    """
 366    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 367        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 368        flattened = tuple(expression.flatten())
 369        deduped = {gen(e): e for e in flattened}
 370        arr = tuple(deduped.items())
 371
 372        # check if the operands are already sorted, if not sort them
 373        # A AND C AND B -> A AND B AND C
 374        for i, (sql, e) in enumerate(arr[1:]):
 375            if sql < arr[i][0]:
 376                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 377                break
 378        else:
 379            # we didn't have to sort but maybe we need to dedup
 380            if len(deduped) < len(flattened):
 381                expression = result_func(*deduped.values(), copy=False)
 382
 383    return expression
 384
 385
 386def absorb_and_eliminate(expression, root=True):
 387    """
 388    absorption:
 389        A AND (A OR B) -> A
 390        A OR (A AND B) -> A
 391        A AND (NOT A OR B) -> A AND B
 392        A OR (NOT A AND B) -> A OR B
 393    elimination:
 394        (A AND B) OR (A AND NOT B) -> A
 395        (A OR B) AND (A OR NOT B) -> A
 396    """
 397    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 398        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 399
 400        for a, b in itertools.permutations(expression.flatten(), 2):
 401            if isinstance(a, kind):
 402                aa, ab = a.unnest_operands()
 403
 404                # absorb
 405                if is_complement(b, aa):
 406                    aa.replace(exp.true() if kind == exp.And else exp.false())
 407                elif is_complement(b, ab):
 408                    ab.replace(exp.true() if kind == exp.And else exp.false())
 409                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
 410                    a.replace(exp.false() if kind == exp.And else exp.true())
 411                elif isinstance(b, kind):
 412                    # eliminate
 413                    rhs = b.unnest_operands()
 414                    ba, bb = rhs
 415
 416                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
 417                        a.replace(aa)
 418                        b.replace(aa)
 419                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
 420                        a.replace(ab)
 421                        b.replace(ab)
 422
 423    return expression
 424
 425
 426def propagate_constants(expression, root=True):
 427    """
 428    Propagate constants for conjunctions in DNF:
 429
 430    SELECT * FROM t WHERE a = b AND b = 5 becomes
 431    SELECT * FROM t WHERE a = 5 AND b = 5
 432
 433    Reference: https://www.sqlite.org/optoverview.html
 434    """
 435
 436    if (
 437        isinstance(expression, exp.And)
 438        and (root or not expression.same_parent)
 439        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 440    ):
 441        constant_mapping = {}
 442        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 443            if isinstance(expr, exp.EQ):
 444                l, r = expr.left, expr.right
 445
 446                # TODO: create a helper that can be used to detect nested literal expressions such
 447                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 448                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 449                    constant_mapping[l] = (id(l), r)
 450
 451        if constant_mapping:
 452            for column in find_all_in_scope(expression, exp.Column):
 453                parent = column.parent
 454                column_id, constant = constant_mapping.get(column) or (None, None)
 455                if (
 456                    column_id is not None
 457                    and id(column) != column_id
 458                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 459                ):
 460                    column.replace(constant.copy())
 461
 462    return expression
 463
 464
# Interval operations mapped to the arithmetic that undoes them. simplify_equality uses this to
# move the interval across a comparison (DATE_ADD(x, i) = d -> x = d - i); simplify_literals
# also folds these operations when their operands are literals.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every operation simplify_equality can invert (the interval ops plus numeric +/-)
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 477
 478
 479def _is_number(expression: exp.Expression) -> bool:
 480    return expression.is_number
 481
 482
 483def _is_interval(expression: exp.Expression) -> bool:
 484    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 485
 486
 487@catch(ModuleNotFoundError, UnsupportedUnit)
 488def simplify_equality(expression: exp.Expression) -> exp.Expression:
 489    """
 490    Use the subtraction and addition properties of equality to simplify expressions:
 491
 492        x + 1 = 3 becomes x = 2
 493
 494    There are two binary operations in the above expression: + and =
 495    Here's how we reference all the operands in the code below:
 496
 497          l     r
 498        x + 1 = 3
 499        a   b
 500    """
 501    if isinstance(expression, COMPARISONS):
 502        l, r = expression.left, expression.right
 503
 504        if l.__class__ not in INVERSE_OPS:
 505            return expression
 506
 507        if r.is_number:
 508            a_predicate = _is_number
 509            b_predicate = _is_number
 510        elif _is_date_literal(r):
 511            a_predicate = _is_date_literal
 512            b_predicate = _is_interval
 513        else:
 514            return expression
 515
 516        if l.__class__ in INVERSE_DATE_OPS:
 517            l = t.cast(exp.IntervalOp, l)
 518            a = l.this
 519            b = l.interval()
 520        else:
 521            l = t.cast(exp.Binary, l)
 522            a, b = l.left, l.right
 523
 524        if not a_predicate(a) and b_predicate(b):
 525            pass
 526        elif not a_predicate(b) and b_predicate(a):
 527            a, b = b, a
 528        else:
 529            return expression
 530
 531        return expression.__class__(
 532            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 533        )
 534    return expression
 535
 536
 537def simplify_literals(expression, root=True):
 538    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 539        return _flat_simplify(expression, _simplify_binary, root)
 540
 541    if isinstance(expression, exp.Neg):
 542        this = expression.this
 543        if this.is_number:
 544            value = this.name
 545            if value[0] == "-":
 546                return exp.Literal.number(value[1:])
 547            return exp.Literal.number(f"-{value}")
 548
 549    if type(expression) in INVERSE_DATE_OPS:
 550        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 551
 552    return expression
 553
 554
# Binary nodes that `_simplify_binary` never folds. The null-safe comparisons do not propagate
# NULL; PropertyEQ is presumably an assignment (name := value) rather than a comparison.
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 556
 557
 558def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 559    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 560        this = _simplify_integer_cast(expr.this)
 561    else:
 562        this = expr.this
 563
 564    if isinstance(expr, exp.Cast) and this.is_int:
 565        num = int(this.name)
 566
 567        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 568        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 569        # engine-dependent
 570        if (
 571            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 572        ) or (
 573            UTINYINT_MIN <= num <= UTINYINT_MAX
 574            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 575        ):
 576            return this
 577
 578    return expr
 579
 580
 581def _simplify_binary(expression, a, b):
 582    if isinstance(expression, COMPARISONS):
 583        a = _simplify_integer_cast(a)
 584        b = _simplify_integer_cast(b)
 585
 586    if isinstance(expression, exp.Is):
 587        if isinstance(b, exp.Not):
 588            c = b.this
 589            not_ = True
 590        else:
 591            c = b
 592            not_ = False
 593
 594        if is_null(c):
 595            if isinstance(a, exp.Literal):
 596                return exp.true() if not_ else exp.false()
 597            if is_null(a):
 598                return exp.false() if not_ else exp.true()
 599    elif isinstance(expression, NULL_OK):
 600        return None
 601    elif is_null(a) or is_null(b):
 602        return exp.null()
 603
 604    if a.is_number and b.is_number:
 605        num_a = int(a.name) if a.is_int else Decimal(a.name)
 606        num_b = int(b.name) if b.is_int else Decimal(b.name)
 607
 608        if isinstance(expression, exp.Add):
 609            return exp.Literal.number(num_a + num_b)
 610        if isinstance(expression, exp.Mul):
 611            return exp.Literal.number(num_a * num_b)
 612
 613        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 614        if isinstance(expression, exp.Sub):
 615            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 616        if isinstance(expression, exp.Div):
 617            # engines have differing int div behavior so intdiv is not safe
 618            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 619                return None
 620            return exp.Literal.number(num_a / num_b)
 621
 622        boolean = eval_boolean(expression, num_a, num_b)
 623
 624        if boolean:
 625            return boolean
 626    elif a.is_string and b.is_string:
 627        boolean = eval_boolean(expression, a.this, b.this)
 628
 629        if boolean:
 630            return boolean
 631    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 632        a, b = extract_date(a), extract_interval(b)
 633        if a and b:
 634            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 635                return date_literal(a + b)
 636            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 637                return date_literal(a - b)
 638    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 639        a, b = extract_interval(a), extract_date(b)
 640        # you cannot subtract a date from an interval
 641        if a and b and isinstance(expression, exp.Add):
 642            return date_literal(a + b)
 643    elif _is_date_literal(a) and _is_date_literal(b):
 644        if isinstance(expression, exp.Predicate):
 645            a, b = extract_date(a), extract_date(b)
 646            boolean = eval_boolean(expression, a, b)
 647            if boolean:
 648                return boolean
 649
 650    return None
 651
 652
def simplify_parens(expression):
    """
    Remove parentheses that do not affect evaluation order.

    Subquery parens (around a SELECT) are always kept. Otherwise `(this)` is unwrapped when any
    of the clauses below holds; the parent is read from `expression.parent`.

    NOTE(review): a predicate under an arithmetic or unary parent is unwrapped too, e.g.
    `(a = b) + 1` -> `a = b + 1` and `-(a = b)` -> `-a = b`, which changes the meaning — confirm
    whether such parents need to be excluded.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        # the parent is not an operator/condition node (or there is no parent)
        not isinstance(parent, (exp.Condition, exp.Binary))
        # directly nested parens: ((x)) -> (x)
        or isinstance(parent, exp.Paren)
        # the content is not a binary operator (e.g. a column, literal or function call),
        # except NOT / IS directly under a predicate, where the parens are kept
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate (e.g. a = b) under a non-predicate parent such as AND / OR
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # same or higher precedence than the parent: (a + b) + c, (a * b) * c, (a * b) +/- c
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
 675
 676
 677def _is_nonnull_constant(expression: exp.Expression) -> bool:
 678    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 679
 680
 681def _is_constant(expression: exp.Expression) -> bool:
 682    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 683
 684
 685def simplify_coalesce(expression):
 686    # COALESCE(x) -> x
 687    if (
 688        isinstance(expression, exp.Coalesce)
 689        and (not expression.expressions or _is_nonnull_constant(expression.this))
 690        # COALESCE is also used as a Spark partitioning hint
 691        and not isinstance(expression.parent, exp.Hint)
 692    ):
 693        return expression.this
 694
 695    if not isinstance(expression, COMPARISONS):
 696        return expression
 697
 698    if isinstance(expression.left, exp.Coalesce):
 699        coalesce = expression.left
 700        other = expression.right
 701    elif isinstance(expression.right, exp.Coalesce):
 702        coalesce = expression.right
 703        other = expression.left
 704    else:
 705        return expression
 706
 707    # This transformation is valid for non-constants,
 708    # but it really only does anything if they are both constants.
 709    if not _is_constant(other):
 710        return expression
 711
 712    # Find the first constant arg
 713    for arg_index, arg in enumerate(coalesce.expressions):
 714        if _is_constant(arg):
 715            break
 716    else:
 717        return expression
 718
 719    coalesce.set("expressions", coalesce.expressions[:arg_index])
 720
 721    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 722    # since we already remove COALESCE at the top of this function.
 723    coalesce = coalesce if coalesce.expressions else coalesce.this
 724
 725    # This expression is more complex than when we started, but it will get simplified further
 726    return exp.paren(
 727        exp.or_(
 728            exp.and_(
 729                coalesce.is_(exp.null()).not_(copy=False),
 730                expression.copy(),
 731                copy=False,
 732            ),
 733            exp.and_(
 734                coalesce.is_(exp.null()),
 735                type(expression)(this=arg.copy(), expression=other.copy()),
 736                copy=False,
 737            ),
 738            copy=False,
 739        )
 740    )
 741
 742
# Node types simplify_concat can reduce (CONCAT / CONCAT_WS calls and the `||` operator)
CONCATS = (exp.Concat, exp.DPipe)
 744
 745
 746def simplify_concat(expression):
 747    """Reduces all groups that contain string literals by concatenating them."""
 748    if not isinstance(expression, CONCATS) or (
 749        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 750        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 751    ):
 752        return expression
 753
 754    if isinstance(expression, exp.ConcatWs):
 755        sep_expr, *expressions = expression.expressions
 756        sep = sep_expr.name
 757        concat_type = exp.ConcatWs
 758        args = {}
 759    else:
 760        expressions = expression.expressions
 761        sep = ""
 762        concat_type = exp.Concat
 763        args = {
 764            "safe": expression.args.get("safe"),
 765            "coalesce": expression.args.get("coalesce"),
 766        }
 767
 768    new_args = []
 769    for is_string_group, group in itertools.groupby(
 770        expressions or expression.flatten(), lambda e: e.is_string
 771    ):
 772        if is_string_group:
 773            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 774        else:
 775            new_args.extend(group)
 776
 777    if len(new_args) == 1 and new_args[0].is_string:
 778        return new_args[0]
 779
 780    if concat_type is exp.ConcatWs:
 781        new_args = [sep_expr] + new_args
 782
 783    return concat_type(expressions=new_args, **args)
 784
 785
 786def simplify_conditionals(expression):
 787    """Simplifies expressions like IF, CASE if their condition is statically known."""
 788    if isinstance(expression, exp.Case):
 789        this = expression.this
 790        for case in expression.args["ifs"]:
 791            cond = case.this
 792            if this:
 793                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 794                cond = cond.replace(this.pop().eq(cond))
 795
 796            if always_true(cond):
 797                return case.args["true"]
 798
 799            if always_false(cond):
 800                case.pop()
 801                if not expression.args["ifs"]:
 802                    return expression.args.get("default") or exp.null()
 803    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 804        if always_true(expression.this):
 805            return expression.args["true"]
 806        if always_false(expression.this):
 807            return expression.args.get("false") or exp.null()
 808
 809    return expression
 810
 811
 812def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 813    """
 814    Reduces a prefix check to either TRUE or FALSE if both the string and the
 815    prefix are statically known.
 816
 817    Example:
 818        >>> from sqlglot import parse_one
 819        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 820        'TRUE'
 821    """
 822    if (
 823        isinstance(expression, exp.StartsWith)
 824        and expression.this.is_string
 825        and expression.expression.is_string
 826    ):
 827        return exp.convert(expression.name.startswith(expression.expression.name))
 828
 829    return expression
 830
 831
# Half-open date range [start, end), as produced by `_datetrunc_range`
DateRange = t.Tuple[datetime.date, datetime.date]
 833
 834
 835def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 836    """
 837    Get the date range for a DATE_TRUNC equality comparison:
 838
 839    Example:
 840        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 841    Returns:
 842        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 843    """
 844    floor = date_floor(date, unit, dialect)
 845
 846    if date != floor:
 847        # This will always be False, except for NULL values.
 848        return None
 849
 850    return floor, floor + interval(unit)
 851
 852
 853def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 854    """Get the logical expression for a date range"""
 855    return exp.and_(
 856        left >= date_literal(drange[0]),
 857        left < date_literal(drange[1]),
 858        copy=False,
 859    )
 860
 861
 862def _datetrunc_eq(
 863    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 864) -> t.Optional[exp.Expression]:
 865    drange = _datetrunc_range(date, unit, dialect)
 866    if not drange:
 867        return None
 868
 869    return _datetrunc_eq_expression(left, drange)
 870
 871
 872def _datetrunc_neq(
 873    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 874) -> t.Optional[exp.Expression]:
 875    drange = _datetrunc_range(date, unit, dialect)
 876    if not drange:
 877        return None
 878
 879    return exp.and_(
 880        left < date_literal(drange[0]),
 881        left >= date_literal(drange[1]),
 882        copy=False,
 883    )
 884
 885
 886DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 887    exp.LT: lambda l, dt, u, d: l
 888    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
 889    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
 890    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
 891    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
 892    exp.EQ: _datetrunc_eq,
 893    exp.NEQ: _datetrunc_neq,
 894}
 895DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 896DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 897
 898
 899def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 900    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 901
 902
 903@catch(ModuleNotFoundError, UnsupportedUnit)
 904def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 905    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 906    comparison = expression.__class__
 907
 908    if isinstance(expression, DATETRUNCS):
 909        date = extract_date(expression.this)
 910        if date and expression.unit:
 911            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
 912    elif comparison not in DATETRUNC_COMPARISONS:
 913        return expression
 914
 915    if isinstance(expression, exp.Binary):
 916        l, r = expression.left, expression.right
 917
 918        if not _is_datetrunc_predicate(l, r):
 919            return expression
 920
 921        l = t.cast(exp.DateTrunc, l)
 922        unit = l.unit.name.lower()
 923        date = extract_date(r)
 924
 925        if not date:
 926            return expression
 927
 928        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
 929    elif isinstance(expression, exp.In):
 930        l = expression.this
 931        rs = expression.expressions
 932
 933        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 934            l = t.cast(exp.DateTrunc, l)
 935            unit = l.unit.name.lower()
 936
 937            ranges = []
 938            for r in rs:
 939                date = extract_date(r)
 940                if not date:
 941                    return expression
 942                drange = _datetrunc_range(date, unit, dialect)
 943                if drange:
 944                    ranges.append(drange)
 945
 946            if not ranges:
 947                return expression
 948
 949            ranges = merge_ranges(ranges)
 950
 951            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 952
 953    return expression
 954
 955
 956def sort_comparison(expression: exp.Expression) -> exp.Expression:
 957    if expression.__class__ in COMPLEMENT_COMPARISONS:
 958        l, r = expression.this, expression.expression
 959        l_column = isinstance(l, exp.Column)
 960        r_column = isinstance(r, exp.Column)
 961        l_const = _is_constant(l)
 962        r_const = _is_constant(r)
 963
 964        if (l_column and not r_column) or (r_const and not l_const):
 965            return expression
 966        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 967            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 968                this=r, expression=l
 969            )
 970    return expression
 971
 972
 973# CROSS joins result in an empty table if the right table is empty.
 974# So we can only simplify certain types of joins to CROSS.
 975# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 976JOINS = {
 977    ("", ""),
 978    ("", "INNER"),
 979    ("RIGHT", ""),
 980    ("RIGHT", "OUTER"),
 981}
 982
 983
 984def remove_where_true(expression):
 985    for where in expression.find_all(exp.Where):
 986        if always_true(where.this):
 987            where.pop()
 988    for join in expression.find_all(exp.Join):
 989        if (
 990            always_true(join.args.get("on"))
 991            and not join.args.get("using")
 992            and not join.args.get("method")
 993            and (join.side, join.kind) in JOINS
 994        ):
 995            join.args["on"].pop()
 996            join.set("side", None)
 997            join.set("kind", "CROSS")
 998
 999
def always_true(expression):
    """
    Return True if `expression` is a literal that is always true in a boolean context.

    TRUE, string literals and non-zero numeric literals count as always true.
    A numeric zero literal does not: engines that accept numeric conditions
    (e.g. MySQL, SQLite) treat 0 as false. So `WHERE 0`, `NOT 0` or `x AND 0`
    must not be simplified as if the literal were TRUE.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal):
        if expression.is_string:
            return True
        try:
            return Decimal(expression.this) != 0
        except (ArithmeticError, TypeError, ValueError):
            # Not parseable as a number: keep treating it as a truthy literal.
            return True
    return False
1004
1005
def always_false(expression):
    """Return True if `expression` is a FALSE literal or a NULL literal."""
    return any(check(expression) for check in (is_false, is_null))
1008
1009
def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1012
1013
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a FALSE boolean literal (subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1016
1017
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a NULL literal node (subclasses excluded)."""
    return exp.Null is type(a)
1020
1021
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on Python values `a` and `b`.

    Returns:
        A TRUE/FALSE literal, or None if `expression` is not a supported comparison
        (=, IS, <>, >, >=, <, <=).
    """
    checks = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in checks:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
1036
1037
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are reduced to their date part and dates pass through. Anything else
    is parsed as an ISO-8601 string; unparseable strings yield None.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so strip the time part if present.
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
1047
1048
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a naive-or-aware `datetime.datetime`.

    Datetimes pass through and plain dates become midnight of that day. Anything
    else is parsed as an ISO-8601 string; unparseable strings yield None.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime.combine(value, datetime.time())

    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
1058
1059
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically convert a literal `value` according to the target type `to`.

    Returns:
        A `datetime.date` for DATE targets or a `datetime.datetime` for other
        temporal targets. Returns None when `value` is empty, the target is not
        temporal, or the value cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1068
1069
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a date-producing expression.

    Supported shapes:
      * `CAST(<literal> AS <type>)`;
      * `exp.TsOrDsToDate` without a `format`, treated as a cast to DATE;
      * nested combinations of the two, e.g. `CAST(CAST('2021-01-01' AS DATE) AS TIMESTAMP)`.

    Returns:
        The resulting date/datetime, or None for any other shape or an unconvertible value.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Evaluate the inner conversion first, then apply the outer target type.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1085
1086
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if `expression` can be statically evaluated to a date/datetime."""
    date = extract_date(expression)
    return date is not None
1089
1090
def extract_interval(expression):
    """
    Convert an interval expression into the delta returned by `interval`.

    The amount is read from the expression's name and the unit from its `unit` arg.
    Returns None if the amount is not an integer, the unit is unsupported, or
    dateutil is not installed.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1098
1099
def date_literal(date):
    """
    Build `CAST('<date>' AS DATE)`, or `CAST('<date>' AS DATETIME)` when `date` is
    a `datetime.datetime`.
    """
    if isinstance(date, datetime.datetime):
        target = exp.DataType.Type.DATETIME
    else:
        target = exp.DataType.Type.DATE
    return exp.cast(exp.Literal.string(date), target)
1109
1110
def interval(unit: str, n: int = 1):
    """
    Return a dateutil `relativedelta` spanning `n` units of `unit`.

    Raises:
        ModuleNotFoundError: if dateutil is not installed.
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week, day,
            hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier per unit)
    spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spec
    return relativedelta(**{keyword: factor * n})
1132
1133
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` to the start of its enclosing `unit`, like DATE_TRUNC.

    Args:
        d: the date or datetime to truncate. A datetime keeps its type, but its time
            of day is reset to midnight, since every supported unit is at least a day long.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides WEEK_OFFSET, the weekday index (Monday = 0, as in
            `date.weekday()`) on which weeks start. Negative values count back from
            Monday, e.g. -1 is Sunday.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days since the most recent week start. The modulo keeps this in 0..6 so a
        # date that already is a week start (e.g. a Sunday when WEEK_OFFSET == -1)
        # maps to itself rather than to the previous week.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1155
1156
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a date already on a boundary is returned as-is."""
    start = date_floor(d, unit, dialect)
    return d if start == d else start + interval(unit)
1164
1165
def boolean_literal(condition):
    """Map a Python truth value to a TRUE or FALSE literal expression."""
    if condition:
        return exp.true()
    return exp.false()
1168
1169
def _flat_simplify(expression, simplifier, root=True):
    """
    Pairwise-simplify the operands of a flattened connector.

    Each operand `a` is tried against every remaining operand `b` with
    `simplifier(expression, a, b)`. If that returns a truthy value other than
    `expression` itself, `a` and `b` are both replaced by the result. The result is
    pushed to the front of the queue so it can be combined again. Otherwise `a` is
    kept as an operand.

    Args:
        expression: the connector whose operands are simplified.
        simplifier: callable `(expression, a, b)` returning the merged node. It returns
            a falsy value or `expression` itself when `a` and `b` cannot be merged.
        root: whether `expression` is the root of the current pass. A non-root
            connector with `same_parent` set is left alone (presumably because the
            enclosing connector of the same type covers it after flattening).

    Returns:
        If any operands were merged, a new left-deep chain of `expression.__class__`
        over the surviving operands. Otherwise `expression` unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # Mutating the deque while iterating it is only safe because
                    # we break out of the loop immediately afterwards.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # No partner merged with `a`: it survives unchanged.
                operands.append(a)

        # Rebuild only if at least one merge reduced the operand count.
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1194
1195
def gen(expression: t.Any) -> str:
    """Render `expression` with the lightweight pseudo-SQL generator.

    The optimizer sorts and dedupes SQL fragments frequently. Running the real
    generator for that is costly, so this produces a cheap, stable string key instead.
    """
    generator = Gen()
    return generator.gen(expression)
1203
1204
class Gen:
    """
    Minimal stack-based pseudo-SQL generator used to build sort/dedup keys.

    Stack items are handled by type:
      * expressions are expanded by a `<key>_sql` handler, by `_function` for
        functions, or by a generic `KEY :arg,...` fallback;
      * lists are expanded into comma-separated items;
      * any other non-None value is appended verbatim via `str()`.

    The stack is LIFO, so handlers push their pieces in reverse output order.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push each non-None item with a comma above it. Reversing makes the
                # items pop in order. Then discard the comma that would precede the
                # first item. Only discard it if something was pushed: for a list of
                # all-None items, popping would remove an unrelated entry (e.g. a
                # closing parenthesis) from the stack.
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dot-separated parts; the trailing pop drops the separator before the first part.
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE: the mixed-case spelling is part of the sort key; changing it could
        # reorder simplified predicates.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed right-to-left so the output reads `this <op> expression`.
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push `[":<arg name>", value]` pairs for the node's non-None args, skipping
        the first `arg_index` entries of `node.arg_types`.

        Returns:
            True if anything was pushed, otherwise False.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
# Key in an expression's `meta`; when set, `simplify` leaves that expression untouched.
FINAL = 'final'
# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by the date helpers (e.g. `interval`, `date_floor`) for a date/time unit they cannot handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect (name, class or instance) resolved via `Dialect.get_or_raise`.
            It is passed to dialect-sensitive rules such as DATE_TRUNC simplification,
            which uses it for week boundaries.

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes marked FINAL (e.g. GROUP BY keys) are left exactly as written.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as roots.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Re-attach the (possibly new) node to the original parent, presumably so the
        # following parent-aware rules (e.g. simplify_parens) see the right context.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Presumably re-runs _simplify until the tree stops changing (see helper.while_changing).
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
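Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • dialect: the SQL dialect used by dialect-sensitive rules (e.g. the first day of the week in DATE_TRUNC simplification)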
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that makes a simplification function return its input `expression`
    unchanged when any of `exceptions` is raised.

    The wrapper copies the wrapped function's name, docstring and signature metadata
    (via functools.wraps). Without it, introspection and documentation tools would
    only see the generic wrapper.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of the given `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` as `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded
    first. Under a NOT, the result is parenthesized so the negation still covers both
    bounds.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negate = isinstance(expression.parent, exp.Not)
    target = expression.this
    lower = exp.GTE(this=target.copy(), expression=expression.args["low"])
    upper = exp.LTE(this=target.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower, upper, copy=False)

    return exp.paren(rewritten, copy=False) if negate else rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Push NOT inwards and fold it where possible:
      * NOT NULL -> NULL, NOT TRUE-like literal -> FALSE, NOT FALSE -> TRUE
      * NOT (a < b) -> a >= b (and the other complementary comparisons)
      * De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y, NOT (x AND y) -> NOT x OR NOT y
      * NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    this = expression.this
    if is_null(this):
        return exp.null()

    complement = COMPLEMENT_COMPARISONS.get(this.__class__)
    if complement is not None:
        return complement(this=this.this, expression=this.expression)

    if isinstance(this, exp.Paren):
        condition = this.unnest()
        if isinstance(condition, (exp.And, exp.Or)):
            # De Morgan: swap the connector and negate each side.
            combine = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if is_null(condition):
            return exp.null()

    if always_true(this):
        return exp.false()
    if is_false(this):
        return exp.true()
    if isinstance(this, exp.Not):
        # Double negation cancels out.
        return this.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove redundant parentheses around nested connectors of the same kind:
    A AND (B AND C) -> A AND B AND C, and A OR (B OR C) -> A OR B OR C.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, expression.__class__):
            operand.replace(inner)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Simplify AND/OR connectors with literal operands using SQL three-valued logic.

    Folded pairs:
        x AND x -> x, FALSE AND x -> FALSE, TRUE AND x -> x, NULL AND NULL -> NULL
        x OR x -> x,  TRUE OR x -> TRUE,    FALSE OR x -> x, NULL OR FALSE -> NULL
    `NULL AND x` with a non-literal x is *not* folded: it is FALSE when x is FALSE.
    Other pairs are handed to `_simplify_comparison`.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # Only fold to NULL when both sides are NULL. NULL AND FALSE is FALSE
            # (handled above), and NULL AND TRUE reduces to NULL via the TRUE rules below.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation:
    A AND NOT A -> FALSE, and A OR NOT A -> TRUE.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are deduplicated and ordered by their `gen` pseudo-SQL key. A new
    connector is built only if the operands were out of order or contained
    duplicates; otherwise `expression` is returned unchanged.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # Keyed by gen() string. A repeated key keeps its first position but takes the
        # last matching expression as its value.
        deduped = {gen(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            # arr[i] is the entry immediately before (sql, e)
            if sql < arr[i][0]:
                # Keys are unique, so sorting the pairs never has to compare expressions.
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place, by replacing operand nodes with TRUE/FALSE
    placeholders or with the surviving operand. The placeholders are left for later
    simplification passes to fold away. Returns `expression` itself.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: nested operands of this type may be absorbed.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: replace it with the nested connector's identity
                    # (FALSE inside OR, TRUE inside AND).
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so a is redundant. Replace
                    # it with the outer connector's identity (TRUE for AND, FALSE for OR).
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Only `column = literal` equalities seed the mapping, and IF nodes are pruned from that
    search. Occurrences inside `<column> IS NULL` are left untouched. The tree is mutated in
    place and the input node is returned.
    """
    if not (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        return expression

    # column -> (id of the column node that defines the constant, literal)
    constants = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constants[lhs] = (id(lhs), rhs)

    if not constants:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        source_id, literal = constants.get(column) or (None, None)

        # Skip unknown columns and the defining occurrence itself.
        if source_id is None or id(column) == source_id:
            continue
        # Keep `col IS NULL` checks intact.
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

# NOTE(review): this entry does not show simplify_equality's own body. The numbered lines
# below are the inner `wrapped` function of what appears to be an exception-catching
# decorator. `func` and `exceptions` are closure variables of that decorator. The wrapper
# calls `func` and returns `expression` unchanged if `func` raises one of `exceptions`.
def simplify_equality(expression, *args, **kwargs):
125        def wrapped(expression, *args, **kwargs):
126            try:
127                return func(expression, *args, **kwargs)
128            except exceptions:
129                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below (l is `x + 1`, r is `3`, a is `x` and b is `1`):

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """Fold literal arithmetic/comparisons, negate numeric literals and simplify date ops.

    - Non-connector binary nodes are delegated to `_flat_simplify` with `_simplify_binary`.
    - `-<number>` collapses into a single numeric literal with its sign flipped.
    - Node types listed in `INVERSE_DATE_OPS` go through `_simplify_binary` with their
      interval; the original node is kept when that yields nothing.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        simplified = _simplify_binary(expression, expression.this, expression.interval())
        return simplified or expression

    return expression
def simplify_parens(expression):
    """Drop a Paren wrapper when the parentheses cannot affect meaning.

    A parenthesized SELECT is never unwrapped. Otherwise the inner node is returned when any
    of these hold:
    - the parent is not a Condition/Binary, or is itself a Paren;
    - the inner node is not binary, unless it is a NOT/IS under a predicate;
    - the inner node is a predicate and the parent is not;
    - the arithmetic nesting is order-safe (`+` in `+`, `*` in `*`, `*` in `+`/`-`).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent
    outer_is_predicate = isinstance(outer, exp.Predicate)

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons against COALESCE.

    - `COALESCE(x)`, or a COALESCE whose first arg is a non-null constant, becomes that
      first arg. COALESCE used inside a Hint (a Spark partitioning hint) is left alone.
    - `COALESCE(a, ..., c, ...) <cmp> k` with constant `k` and first constant arg `c` becomes
      `(a' IS NOT NULL AND <original comparison>) OR (a' IS NULL AND c <cmp> k)`, where `a'`
      is the COALESCE truncated before `c`. The result is larger but simplifies further.
    """
    # COALESCE(x) -> x
    if isinstance(expression, exp.Coalesce):
        collapsible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if collapsible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        ((index, arg) for index, arg in enumerate(coalesce.expressions) if _is_constant(arg)),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
# Node types simplify_concat operates on: function-style CONCAT and the `||` operator.
# (The rendered line here was a class repr, which is not valid Python.)
CONCATS = (exp.Concat, exp.DPipe)
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged; for CONCAT_WS they are joined with the (literal)
    separator. If everything collapses into a single string literal, that literal is
    returned. `||` chains are flattened and rebuilt as CONCAT, carrying over the `safe`
    and `coalesce` args.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    for is_string_group, group in itertools.groupby(
        operands or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            merged.append(exp.Literal.string(sep.join(literal.name for literal in group)))
        else:
            merged.extend(group)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        merged.insert(0, sep_expr)

    return concat_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
    - `CASE x WHEN v ...` branches are rewritten to `CASE WHEN x = v ...`.
    - Statically false/NULL branches are removed. If none remain, the ELSE value (or NULL)
      is returned.
    - A statically true branch wins outright only if every branch before it was removed.
      Otherwise it becomes the ELSE and the branches after it are dropped, because earlier
      branches whose outcome is unknown must still be evaluated first.

    IF (not nested directly in a CASE) collapses to its true/false branch, or to NULL when
    the false branch is missing.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: statically false branches are popped from `ifs` below,
        # and mutating the list being iterated would silently skip the following branch.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                ifs = expression.args["ifs"]
                if ifs[0] is case:
                    # Every earlier branch was statically false and removed.
                    return case.args["true"]

                # Earlier branches are undecided: keep them, make this branch the ELSE and
                # drop everything after it (unreachable).
                index = next(i for i, branch in enumerate(ifs) if branch is case)
                expression.set("ifs", ifs[:index])
                expression.set("default", case.args["true"])
                return expression

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    both_literal = expression.this.is_string and expression.expression.is_string
    if not both_literal:
        return expression

    prefix = expression.expression.name
    return exp.convert(expression.name.startswith(prefix))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
# A pair of dates, presumably the (start, end) bounds of a range; confirm at call sites.
# Uses the module's `t` alias; a bare `typing.` name is not imported in this module.
DateRange = t.Tuple[datetime.date, datetime.date]
# NOTE(review): the line below is a rendered repr, not source. The lambda bodies for
# LT/GT/LTE/GTE are not shown, so the mapping cannot be reconstructed from here. EQ maps to
# _datetrunc_eq and NEQ to _datetrunc_neq. Per the annotation, each transform takes
# (expression, date, unit, dialect) and returns an optional expression.
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
# Comparison node types presumably handled by simplify_datetrunc: IN plus the binary
# comparisons. Members are taken verbatim from the rendered repr, which was not valid Python.
DATETRUNC_COMPARISONS = {exp.In, exp.EQ, exp.NEQ, exp.LT, exp.LTE, exp.GT, exp.GTE}
# NOTE(review): this entry does not show simplify_datetrunc's own body. The numbered lines
# below are the inner `wrapped` function of what appears to be an exception-catching
# decorator. `func` and `exceptions` are closure variables of that decorator. The wrapper
# calls `func` and returns `expression` unchanged if `func` raises one of `exceptions`.
def simplify_datetrunc(expression, *args, **kwargs):
125        def wrapped(expression, *args, **kwargs):
126            try:
127                return func(expression, *args, **kwargs)
128            except exceptions:
129                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put a comparison's operands into canonical order.

    Only classes in `COMPLEMENT_COMPARISONS` are affected. The node is kept as-is when only
    the left side is a column or only the right side is a constant. Otherwise the operands
    are swapped (using the inverse operator from `INVERSE_COMPARISONS` when one exists) if
    only the right side is a column, only the left side is a constant, or the left operand's
    `gen` key sorts after the right's.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    if (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const):
        return expression

    should_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or gen(lhs) > gen(rhs)
    )
    if not should_swap:
        return expression

    flipped_class = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return flipped_class(this=rhs, expression=lhs)
# (side, kind) pairs whose `ON TRUE` join remove_where_true rewrites into a CROSS JOIN.
# NOTE(review): for RIGHT [OUTER] JOIN ... ON TRUE, a CROSS JOIN drops the NULL-extended rows
# a RIGHT JOIN yields when the left input is empty. Confirm the RIGHT entries are intended.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
def remove_where_true(expression):
    """Drop always-true filters from `expression`, in place.

    Every WHERE whose condition is always true is removed. Every join whose ON condition is
    always true becomes a CROSS JOIN, provided it has no USING/method and its (side, kind)
    is in `JOINS`. Returns None.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        convertible = (
            always_true(on_condition)
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if not convertible:
            continue

        on_condition.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """Return True if `expression` is statically known to be truthy.

    That is the boolean TRUE or a literal, except a numeric literal equal to zero: in
    dialects that accept numbers as conditions (e.g. `WHERE 0`), zero is false, and treating
    it as always true let callers drop filters such as `WHERE 0` entirely.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # Number formats Decimal cannot parse keep the previous behaviour (truthy).
            return True

    return True
def always_false(expression):
    """Return True if `expression` is the boolean FALSE or a NULL literal."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
    """Return True if `b` is exactly `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """Evaluate comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE boolean literal for EQ/IS, NEQ, GT, GTE, LT and LTE nodes, and None
    for any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Interpret `value` as a `datetime.date`.

    Datetimes are truncated to their date, dates are returned unchanged, and anything else
    is parsed with `datetime.fromisoformat`. Returns None when the value cannot be parsed,
    including non-string inputs, which previously escaped as TypeError.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Interpret `value` as a `datetime.datetime`.

    Datetimes are returned unchanged, dates become midnight of that day, and anything else
    is parsed with `datetime.fromisoformat`. Returns None when the value cannot be parsed,
    including non-string inputs, which previously escaped as TypeError.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Interpret `value` as the temporal type `to`.

    Returns a `datetime.date` for DATE targets and a `datetime.datetime` for any other
    temporal target. Returns None for falsy values, non-temporal targets, or values the
    `cast_as_*` helpers cannot parse. (The previous annotation `Union[date, date]` was a
    typo for `Union[date, datetime]`.)
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a date/time cast of a literal.

    Handles `CAST(<literal> AS <type>)`, `TsOrDsToDate(<literal>)` without a format
    (treated as a DATE cast), and nestings of the two. Returns None for any other input or
    when the value cannot be interpreted as the target type. (The previous annotation
    `Union[date, date]` was a typo for `Union[date, datetime]`.)
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    inner = cast.this
    if isinstance(inner, exp.Literal):
        value: t.Any = inner.name
    elif isinstance(inner, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated innermost-first.
        value = extract_date(inner)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """Convert an interval node into a `relativedelta`, or return None.

    The amount is the node's name parsed as an int, and the unit is its lower-cased "unit"
    text. Non-integer amounts, unsupported units and a missing `dateutil` install all yield
    None.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def date_literal(date):
    """Build `CAST('<date>' AS DATETIME)` for datetimes and `CAST('<date>' AS DATE)` otherwise."""
    as_string = exp.Literal.string(date)
    if isinstance(date, datetime.datetime):
        target = exp.DataType.Type.DATETIME
    else:
        target = exp.DataType.Type.DATE
    return exp.cast(as_string, target)
def interval(unit: str, n: int = 1):
    """Return a `dateutil.relativedelta` spanning `n` of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute and second.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    unit_spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if unit_spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = unit_spec
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` to the start of the year/quarter/month/week/day containing it.

    Weeks start on weekday `dialect.WEEK_OFFSET`, counted like `date.weekday()`
    (0 = Monday). Offsets are taken modulo 7, so -1 means Sunday.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start, always in [0, 6]. Without the modulo,
        # an offset of -1 pushed Sundays back a whole week and positive offsets could move
        # the date forward.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a date already on a boundary is returned as-is."""
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """Map a Python truth value to a TRUE/FALSE boolean expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
    """Minimal pseudo-SQL generator backing `gen`.

    Work is kept on an explicit LIFO stack (`stack`) instead of recursing. Entries are
    expressions, lists of children, `None` (skipped) or plain string fragments. Because the
    stack pops from the end, every `*_sql` handler pushes its pieces in reverse output order.
    Emitted fragments accumulate in `sqls` and are joined once the stack is empty. The output
    only needs to be deterministic and comparable; it is not valid SQL.
    """

    def __init__(self):
        self.stack = []  # pending work, popped from the end
        self.sqls = []  # emitted fragments, in output order

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` (or a list / string fragment) to a pseudo-SQL string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if type(node) is list:
                # Each non-null item is queued with a "," that is emitted just before it.
                queued = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        queued = True
                # Drop the separator in front of the first item, but only if this branch
                # pushed it. For an all-None list, an unconditional pop would discard a
                # fragment owned by an enclosing node (or fail on an empty stack).
                if queued:
                    self.stack.pop()
            elif isinstance(node, str):
                self.sqls.append(str(node))
            elif node is None:
                continue
            elif isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Fallback: upper-cased node key, followed by its non-null args if any.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            else:
                # Any other value (e.g. a raw arg value queued by `_args`) is str()-ed.
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." queued in front of the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case " Like " (vs " ILIKE ") is kept verbatim because these
        # strings are sort/dedup keys; changing it would reorder existing outputs.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `<this><op><expression>` (pushed in reverse)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op><this>` (pushed in reverse)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `NAME(<all arg values>)` for function nodes without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Queue `[":name", value]` pairs for `node`'s non-null args.

        Args are taken in `arg_types` order, skipping the first `arg_index` of them.
        Returns True if anything was queued.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression) -> str:
    """Render `expression` (or a list / string fragment) to a pseudo-SQL string.

    Pending work lives on `self.stack` (popped from the end); emitted fragments collect in
    `self.sqls` and are joined at the end.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if type(node) is list:
            # Each non-null item is queued with a "," that is emitted just before it.
            queued = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    queued = True
            # Only drop the leading separator if this branch pushed one. Otherwise an all-None
            # list would discard a fragment owned by an enclosing node (or fail on an empty stack).
            if queued:
                self.stack.pop()
        elif isinstance(node, str):
            self.sqls.append(str(node))
        elif node is None:
            continue
        elif isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        else:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Queue `this + expression` through the shared binary helper."""
    self._binary(e, op=" + ")
def alias_sql(self, e: exp.Alias) -> None:
    """Queue `<this> AS <alias>` (pushed in reverse because the stack is LIFO)."""
    args = e.args
    pending = (args.get("alias"), " AS ", args.get("this"))
    self.stack.extend(pending)
def and_sql(self, e: exp.And) -> None:
    """Queue `this AND expression` through the shared binary helper."""
    self._binary(e, op=" AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Queue `NAME(args...)` for an anonymous function call.

    A plain-string name or an unquoted identifier is upper-cased; a quoted identifier keeps
    its case and is wrapped in double quotes. Any other `this` raises ValueError.
    """
    func = e.this
    if isinstance(func, str):
        rendered = func.upper()
    elif isinstance(func, exp.Identifier):
        rendered = f'"{func.this}"' if func.quoted else func.this.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{func.__class__.__name__}'."
        )

    self.stack.extend((")", e.expressions, "(", rendered))
def between_sql(self, e: exp.Between) -> None:
    """Queue `<this> BETWEEN <low> AND <high>` (pushed in reverse)."""
    bounds = e.args
    high, low = bounds.get("high"), bounds.get("low")
    self.stack.extend((high, " AND ", low, " BETWEEN ", e.this))
def boolean_sql(self, e: exp.Boolean) -> None:
    """Queue TRUE or FALSE depending on the node's value."""
    if e.this:
        self.stack.append("TRUE")
    else:
        self.stack.append("FALSE")
def bracket_sql(self, e: exp.Bracket) -> None:
    """Queue `<this>[<expressions>]` (pushed in reverse)."""
    indices, target = e.expressions, e.this
    for piece in ("]", indices, "[", target):
        self.stack.append(piece)
def column_sql(self, e: exp.Column) -> None:
    """Queue the column's parts joined by "." (pushed in reverse)."""
    for part in reversed(e.parts):
        self.stack.append(part)
        self.stack.append(".")
    # Drop the "." queued in front of the first part.
    self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Queue `<TYPE NAME> ` followed by the node's args after the first."""
    self._args(e, arg_index=1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Queue `this / expression` through the shared binary helper."""
    self._binary(e, op=" / ")
def dot_sql(self, e: exp.Dot) -> None:
    """Queue `this.expression` through the shared binary helper."""
    self._binary(e, op=".")
def eq_sql(self, e: exp.EQ) -> None:
    """Queue `this = expression` through the shared binary helper."""
    self._binary(e, op=" = ")
def from_sql(self, e: exp.From) -> None:
    """Queue `FROM <this>` (pushed in reverse)."""
    self.stack.append(e.this)
    self.stack.append("FROM ")
def gt_sql(self, e: exp.GT) -> None:
    """Queue `this > expression` through the shared binary helper."""
    self._binary(e, op=" > ")
def gte_sql(self, e: exp.GTE) -> None:
    """Queue `this >= expression` through the shared binary helper."""
    self._binary(e, op=" >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
    """Queue the identifier text, wrapped in double quotes when it is quoted."""
    text = e.this
    if e.quoted:
        text = f'"{text}"'
    self.stack.append(text)
def ilike_sql(self, e: exp.ILike) -> None:
    """Queue `this ILIKE expression` through the shared binary helper."""
    self._binary(e, op=" ILIKE ")
def in_sql(self, e: exp.In) -> None:
    """Queue ``<this> IN (<args...>)`` on the render stack.

    The closing paren goes first. Then every argument from index 1 onward is
    queued through ``_args``. Finally the opening paren, the ``" IN "`` keyword
    and ``e.this`` are pushed, so ``e.this`` ends up on top.
    """
    self.stack.append(")")
    self._args(e, 1)
    for token in ("(", " IN ", e.this):
        self.stack.append(token)
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Render integer division via ``_binary`` with the ``" DIV "`` keyword token."""
    operator = " DIV "
    self._binary(e, operator)
def is_sql(self, e: exp.Is) -> None:
    """Render an ``IS`` test via ``_binary`` with the ``" IS "`` token."""
    operator = " IS "
    self._binary(e, operator)
def like_sql(self, e: exp.Like) -> None:
    """Render a LIKE comparison by delegating to ``_binary``.

    The operator token is upper-case ``" LIKE "``, matching the other keyword
    operators in this generator (``" ILIKE "``, ``" IS "``, ``" OR "``,
    ``" DIV "``) rather than the inconsistent mixed-case ``" Like "``.
    """
    self._binary(e, " LIKE ")
def literal_sql(self, e: exp.Literal) -> None:
    """Push a literal's text; string literals are wrapped in single quotes.

    Embedded single quotes in a string literal are doubled (``'`` -> ``''``).
    Without that, the string ``a' = 'b`` would render exactly like a comparison
    of the two literals ``'a'`` and ``'b'``, making distinct expressions produce
    the same text. Non-string literals (e.g. numbers) are pushed unchanged.
    """
    if not e.is_string:
        self.stack.append(e.this)
        return
    escaped = f"{e.this}".replace("'", "''")
    self.stack.append(f"'{escaped}'")
def lt_sql(self, e: exp.LT) -> None:
    """Render a less-than comparison via ``_binary`` with the ``" < "`` token."""
    operator = " < "
    self._binary(e, operator)
def lte_sql(self, e: exp.LTE) -> None:
    """Render a less-or-equal comparison via ``_binary`` with the ``" <= "`` token."""
    operator = " <= "
    self._binary(e, operator)
def mod_sql(self, e: exp.Mod) -> None:
    """Render a modulo operation via ``_binary`` with the ``" % "`` token."""
    operator = " % "
    self._binary(e, operator)
def mul_sql(self, e: exp.Mul) -> None:
    """Render a multiplication via ``_binary`` with the ``" * "`` token."""
    operator = " * "
    self._binary(e, operator)
def neg_sql(self, e: exp.Neg) -> None:
    """Render arithmetic negation by handing ``e`` and the ``"-"`` prefix to ``_unary``."""
    prefix = "-"
    self._unary(e, prefix)
def neq_sql(self, e: exp.NEQ) -> None:
    """Render an inequality comparison via ``_binary`` with the ``" <> "`` token."""
    operator = " <> "
    self._binary(e, operator)
def not_sql(self, e: exp.Not) -> None:
    """Render logical negation by handing ``e`` and the ``"NOT "`` prefix to ``_unary``."""
    prefix = "NOT "
    self._unary(e, prefix)
def null_sql(self, e: exp.Null) -> None:
    """Push the ``NULL`` keyword; nothing from ``e`` itself is read."""
    self.stack.extend(("NULL",))
def or_sql(self, e: exp.Or) -> None:
    """Render a disjunction via ``_binary`` with the ``" OR "`` keyword token."""
    operator = " OR "
    self._binary(e, operator)
def paren_sql(self, e: exp.Paren) -> None:
    """Queue ``e.this`` wrapped in parentheses, pushing the closing paren first."""
    closing, opening = ")", "("
    self.stack.extend([closing, e.this, opening])
def sub_sql(self, e: exp.Sub) -> None:
    """Render a subtraction via ``_binary`` with the ``" - "`` token."""
    operator = " - "
    self._binary(e, operator)
def subquery_sql(self, e: exp.Subquery) -> None:
    """Queue a parenthesised subquery, its optional alias, and any trailing args.

    Arguments from index 2 onward are queued first through ``_args``. Then the
    alias (only if truthy), ")", ``e.this`` and "(" are pushed in that order.
    """
    self._args(e, 2)
    tail = [")", e.this, "("]
    alias = e.args.get("alias")
    if alias:
        tail.insert(0, alias)
    self.stack.extend(tail)
def table_sql(self, e: exp.Table) -> None:
    """Queue a table reference: dotted name parts, optional alias, trailing args.

    Arguments from index 4 onward are queued first through ``_args``. The alias
    follows if present. Finally the name parts are pushed last-to-first with "."
    separators: for parts ``[a, b, c]`` the segment is ``[c, ".", b, ".", a]``.

    The trailing-separator pop only happens when at least one part was pushed.
    With no parts, an unconditional pop would discard the alias or an item
    queued by ``_args``.
    """
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)
    parts = e.parts
    if not parts:
        return
    for part in reversed(parts):
        self.stack.extend((part, "."))
    # The separator pushed after the first part has nothing before it; drop it.
    self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Queue `` AS <this>``, followed by ``(<columns>)`` when the alias lists columns.

    The tokens are assembled in textual order and then pushed reversed, so the
    ``" AS "`` keyword ends up on top of the stack.
    """
    pending = [" AS ", e.this]
    columns = e.columns
    if columns:
        pending += ["(", columns, ")"]
    self.stack.extend(reversed(pending))
def var_sql(self, e: exp.Var) -> None:
    """Push the variable's raw text (``e.this``) onto the render stack as-is."""
    self.stack.extend([e.this])