Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Signature of the DATE_TRUNC comparison rewriters stored in
    # DATETRUNC_BINARY_COMPARISONS:
    # (left operand, date, unit, dialect, target type) -> replacement expression or None
    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
    ]

# Final means that an expression should not be simplified
# (stored as a key in an expression's `meta` dict; `simplify` skips such nodes)
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
# (integer literals within these ranges may drop an integer CAST, see _simplify_integer_cast)
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  30
  31
  32class UnsupportedUnit(Exception):
  33    pass
  34
  35
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect passed to dialect-sensitive rules (here: the DATE_TRUNC
            simplification); resolved with ``Dialect.get_or_raise``.

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged FINAL (e.g. GROUP BY keys, see below) are never rewritten.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a GROUP BY key anywhere in its subtree
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Likewise freeze HAVING when it references a GROUP BY key
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children (as non-root) before the post-order rules run here
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # `node` may be a freshly built replacement; give it the original parent so the
        # parent-sensitive rules below (e.g. simplify_parens) see the real context.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the top-level call splices the result back into the tree; children are
        # replaced by exp.replace_children above.
        if root:
            expression.replace(node)
        return node

    # NOTE(review): while_changing / remove_where_true are defined elsewhere; presumably
    # the former re-runs _simplify until the tree stops changing.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 118
 119
 120def catch(*exceptions):
 121    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 122
 123    def decorator(func):
 124        def wrapped(expression, *args, **kwargs):
 125            try:
 126                return func(expression, *args, **kwargs)
 127            except exceptions:
 128                return expression
 129
 130        return wrapped
 131
 132    return decorator
 133
 134
 135def rewrite_between(expression: exp.Expression) -> exp.Expression:
 136    """Rewrite x between y and z to x >= y AND x <= z.
 137
 138    This is done because comparison simplification is only done on lt/lte/gt/gte.
 139    """
 140    if isinstance(expression, exp.Between):
 141        negate = isinstance(expression.parent, exp.Not)
 142
 143        expression = exp.and_(
 144            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 145            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 146            copy=False,
 147        )
 148
 149        if negate:
 150            expression = exp.paren(expression, copy=False)
 151
 152    return expression
 153
 154
# Each comparison mapped to its logical negation; simplify_not uses this to rewrite
# NOT (a < b) as a >= b, NOT (a = b) as a <> b, and so on.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 163
 164
def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Other folds applied to a NOT node:
        NOT NULL                -> NULL
        NOT (a < b)             -> a >= b   (any pair in COMPLEMENT_COMPARISONS)
        NOT <always_true(x)>    -> FALSE
        NOT FALSE               -> TRUE
        NOT NOT x               -> x

    Any other expression is returned unchanged.
    """
    if isinstance(expression, exp.Not):
        this = expression.this
        if is_null(this):
            return exp.null()
        # Replace the negated comparison by its complement instead of keeping the NOT
        if this.__class__ in COMPLEMENT_COMPARISONS:
            return COMPLEMENT_COMPARISONS[this.__class__](
                this=this.this, expression=this.expression
            )
        if isinstance(this, exp.Paren):
            condition = this.unnest()
            # De Morgan; the result stays parenthesized so it binds correctly in its parent
            if isinstance(condition, exp.And):
                return exp.paren(
                    exp.or_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if isinstance(condition, exp.Or):
                return exp.paren(
                    exp.and_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if is_null(condition):
                return exp.null()
        if always_true(this):
            return exp.false()
        if is_false(this):
            return exp.true()
        if isinstance(this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return this.this
    return expression
 208
 209
 210def flatten(expression):
 211    """
 212    A AND (B AND C) -> A AND B AND C
 213    A OR (B OR C) -> A OR B OR C
 214    """
 215    if isinstance(expression, exp.Connector):
 216        for node in expression.args.values():
 217            child = node.unnest()
 218            if isinstance(child, expression.__class__):
 219                node.replace(child)
 220    return expression
 221
 222
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR connectors whose operands are statically known.

    Identical operands collapse (x AND x -> x).
    AND: x AND FALSE -> FALSE, then x AND NULL -> NULL, TRUE operands are dropped.
    OR:  x OR TRUE -> TRUE, NULL OR NULL / NULL OR FALSE -> NULL, FALSE operands are dropped.
    Remaining operand pairs are handed to _simplify_comparison for range merging.

    Operand pairs are combined through _flat_simplify; non-connectors are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Returns a replacement for `left <op> right`, or None / `expression` for "no change"
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NOTE(review): NULL AND x is FALSE when x is FALSE at runtime; folding to NULL is
            # only exact where NULL and FALSE behave alike (e.g. filters) - confirm intent.
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
 259
 260
# Comparison groups used when merging range predicates (upper bounds / lower bounds)
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# All binary comparison node types the simplifier reasons about
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Each ordering comparison mapped to the operator obtained by swapping its operands
# (a < b is equivalent to b > a)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Functions whose value differs per evaluation; _simplify_comparison will not treat
# operands containing them as the same "column"
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 280
 281
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge the comparisons `left` and `right` joined by the connector `expression`.

    Applies only when both operands are comparisons sharing a deterministic, non-constant
    operand (e.g. a column) that is compared against values of the same kind (numbers,
    strings, or date literals). Examples:

        x < 5 AND x < 3  -> x < 3
        x < 5 OR x < 3   -> x < 5
        x < 1 AND x > 3  -> FALSE
        x = 5 AND x > 3  -> x = 5

    Args:
        expression: the AND / OR node being simplified.
        left: its first operand.
        right: its second operand.
        or_: True when `expression` is an OR.

    Returns:
        A replacement expression, or `expression` / None when nothing was merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        # The two operands of each comparison, e.g. {x, 5} for x < 5
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands shared by both comparisons; only deterministic non-constants qualify
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The value each comparison tests the shared operand against
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values into comparable Python values, or give up
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Same direction: keep the tighter bound for AND, the looser one for OR.
                # NOTE(review): on a tie (av == bv) with mixed strictness (x <= 5 AND x < 5)
                # the result depends on operand order; it is only correct when LT/GT comes
                # before LTE/GTE, which uniq_sort's SQL ordering appears to give - confirm.
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Disjoint ranges (x < 1 AND x > 3) can never both hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # x = v AND <bound>: FALSE if v violates the bound, otherwise x = v
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 342
 343
 344def remove_complements(expression, root=True):
 345    """
 346    Removing complements.
 347
 348    A AND NOT A -> FALSE
 349    A OR NOT A -> TRUE
 350    """
 351    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 352        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 353
 354        for a, b in itertools.permutations(expression.flatten(), 2):
 355            if is_complement(a, b):
 356                return complement
 357    return expression
 358
 359
 360def uniq_sort(expression, root=True):
 361    """
 362    Uniq and sort a connector.
 363
 364    C AND A AND B AND B -> A AND B AND C
 365    """
 366    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 367        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 368        flattened = tuple(expression.flatten())
 369        deduped = {gen(e): e for e in flattened}
 370        arr = tuple(deduped.items())
 371
 372        # check if the operands are already sorted, if not sort them
 373        # A AND C AND B -> A AND B AND C
 374        for i, (sql, e) in enumerate(arr[1:]):
 375            if sql < arr[i][0]:
 376                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 377                break
 378        else:
 379            # we didn't have to sort but maybe we need to dedup
 380            if len(deduped) < len(flattened):
 381                expression = result_func(*deduped.values(), copy=False)
 382
 383    return expression
 384
 385
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites are applied in place by replacing operands of `expression` with TRUE/FALSE
    or with a shared operand; the (mutated) `expression` is returned. Cleaning up the
    resulting TRUE/FALSE/duplicate operands is left to later rules.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: for an AND we inspect OR operands and vice versa
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # b <op> (NOT b <kind> x): replace NOT b by the identity of `kind`
                # (TRUE inside an AND, FALSE inside an OR), leaving b <op> x
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's: `a` is absorbed, so replace it by
                # the identity of the outer connector
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # (aa <kind> ab) <op> (aa <kind> NOT ab) -> aa, for either operand order
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 424
 425
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only `column = literal` equalities (column on the left) are used as sources, IF
    subtrees are not searched, and columns in `IS NULL` checks are never replaced.
    The expression is mutated in place and returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the defining column node, literal it is equal to)
        constant_mapping = {}
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the defining occurrence itself (same node id) and `col IS NULL` checks
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 463
 464
# Operation that undoes each date/interval arithmetic op; simplify_equality uses it to move
# the interval across a comparison (DATE_ADD(x, i) = d -> x = d - i)
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# INVERSE_DATE_OPS plus plain arithmetic (x + 1 = 3 -> x = 3 - 1)
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 477
 478
 479def _is_number(expression: exp.Expression) -> bool:
 480    return expression.is_number
 481
 482
 483def _is_interval(expression: exp.Expression) -> bool:
 484    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 485
 486
 487@catch(ModuleNotFoundError, UnsupportedUnit)
 488def simplify_equality(expression: exp.Expression) -> exp.Expression:
 489    """
 490    Use the subtraction and addition properties of equality to simplify expressions:
 491
 492        x + 1 = 3 becomes x = 2
 493
 494    There are two binary operations in the above expression: + and =
 495    Here's how we reference all the operands in the code below:
 496
 497          l     r
 498        x + 1 = 3
 499        a   b
 500    """
 501    if isinstance(expression, COMPARISONS):
 502        l, r = expression.left, expression.right
 503
 504        if l.__class__ not in INVERSE_OPS:
 505            return expression
 506
 507        if r.is_number:
 508            a_predicate = _is_number
 509            b_predicate = _is_number
 510        elif _is_date_literal(r):
 511            a_predicate = _is_date_literal
 512            b_predicate = _is_interval
 513        else:
 514            return expression
 515
 516        if l.__class__ in INVERSE_DATE_OPS:
 517            l = t.cast(exp.IntervalOp, l)
 518            a = l.this
 519            b = l.interval()
 520        else:
 521            l = t.cast(exp.Binary, l)
 522            a, b = l.left, l.right
 523
 524        if not a_predicate(a) and b_predicate(b):
 525            pass
 526        elif not a_predicate(b) and b_predicate(a):
 527            a, b = b, a
 528        else:
 529            return expression
 530
 531        return expression.__class__(
 532            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 533        )
 534    return expression
 535
 536
 537def simplify_literals(expression, root=True):
 538    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 539        return _flat_simplify(expression, _simplify_binary, root)
 540
 541    if isinstance(expression, exp.Neg):
 542        this = expression.this
 543        if this.is_number:
 544            value = this.name
 545            if value[0] == "-":
 546                return exp.Literal.number(value[1:])
 547            return exp.Literal.number(f"-{value}")
 548
 549    if type(expression) in INVERSE_DATE_OPS:
 550        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 551
 552    return expression
 553
 554
# Binary ops that must not be folded to NULL when an operand is NULL: null-safe
# (in)equality, plus PropertyEQ (presumably `name := value` assignment syntax).
# _simplify_binary leaves these untouched.
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 556
 557
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """
    Drop a CAST of a small integer literal to an integer type when it cannot change the value.

    ``CAST(5 AS BIGINT)`` -> ``5``: values in the TINYINT range may drop a cast to a type
    in ``exp.DataType.SIGNED_INTEGER_TYPES``; values in 0..255 may drop a cast to a type in
    ``exp.DataType.UNSIGNED_INTEGER_TYPES``. Nested casts are simplified innermost-first.
    Anything else is returned unchanged.
    """
    # For CAST(CAST(...)), simplify the inner cast first so the outer one can see a literal
    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
        this = _simplify_integer_cast(expr.this)
    else:
        this = expr.this

    if isinstance(expr, exp.Cast) and this.is_int:
        num = int(this.name)

        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
        # engine-dependent
        if (
            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
        ) or (
            UTINYINT_MIN <= num <= UTINYINT_MAX
            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
        ):
            return this

    return expr
 579
 580
 581def _simplify_binary(expression, a, b):
 582    if isinstance(expression, COMPARISONS):
 583        a = _simplify_integer_cast(a)
 584        b = _simplify_integer_cast(b)
 585
 586    if isinstance(expression, exp.Is):
 587        if isinstance(b, exp.Not):
 588            c = b.this
 589            not_ = True
 590        else:
 591            c = b
 592            not_ = False
 593
 594        if is_null(c):
 595            if isinstance(a, exp.Literal):
 596                return exp.true() if not_ else exp.false()
 597            if is_null(a):
 598                return exp.false() if not_ else exp.true()
 599    elif isinstance(expression, NULL_OK):
 600        return None
 601    elif is_null(a) or is_null(b):
 602        return exp.null()
 603
 604    if a.is_number and b.is_number:
 605        num_a = int(a.name) if a.is_int else Decimal(a.name)
 606        num_b = int(b.name) if b.is_int else Decimal(b.name)
 607
 608        if isinstance(expression, exp.Add):
 609            return exp.Literal.number(num_a + num_b)
 610        if isinstance(expression, exp.Mul):
 611            return exp.Literal.number(num_a * num_b)
 612
 613        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 614        if isinstance(expression, exp.Sub):
 615            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 616        if isinstance(expression, exp.Div):
 617            # engines have differing int div behavior so intdiv is not safe
 618            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 619                return None
 620            return exp.Literal.number(num_a / num_b)
 621
 622        boolean = eval_boolean(expression, num_a, num_b)
 623
 624        if boolean:
 625            return boolean
 626    elif a.is_string and b.is_string:
 627        boolean = eval_boolean(expression, a.this, b.this)
 628
 629        if boolean:
 630            return boolean
 631    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 632        date, b = extract_date(a), extract_interval(b)
 633        if date and b:
 634            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 635                return date_literal(date + b, extract_type(a))
 636            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 637                return date_literal(date - b, extract_type(a))
 638    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 639        a, date = extract_interval(a), extract_date(b)
 640        # you cannot subtract a date from an interval
 641        if a and b and isinstance(expression, exp.Add):
 642            return date_literal(a + date, extract_type(b))
 643    elif _is_date_literal(a) and _is_date_literal(b):
 644        if isinstance(expression, exp.Predicate):
 645            a, b = extract_date(a), extract_date(b)
 646            boolean = eval_boolean(expression, a, b)
 647            if boolean:
 648                return boolean
 649
 650    return None
 651
 652
def simplify_parens(expression):
    """
    Remove parentheses that do not affect meaning.

    Parenthesized SELECTs (subqueries) are always kept. Otherwise the parentheses are
    dropped when any of the conditions below holds; in all other cases `expression`
    is returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        # the parent is not an operator (not a Condition/Binary), so no precedence applies
        not isinstance(parent, (exp.Condition, exp.Binary))
        # double parentheses
        or isinstance(parent, exp.Paren)
        # atoms need no parentheses, except NOT / IS directly inside a predicate
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate used by a non-predicate operator, e.g. (a = b) AND c
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # associative: a + (b + c), a * (b * c)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # multiplication binds tighter than + and -
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
 675
 676
 677def _is_nonnull_constant(expression: exp.Expression) -> bool:
 678    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 679
 680
 681def _is_constant(expression: exp.Expression) -> bool:
 682    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 683
 684
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - COALESCE(x) and COALESCE(<non-null constant>, ...) -> their first argument
      (unless the COALESCE is a Spark partitioning hint).
    - `COALESCE(x, ..., c, ...) <op> k` with constants `c` (the first constant argument)
      and `k` becomes
          (COALESCE(x, ...) IS NOT NULL AND COALESCE(x, ...) <op> k)
          OR (COALESCE(x, ...) IS NULL AND c <op> k)
      where arguments from `c` onward are dropped from the COALESCE. The
      `c <op> k` part can then be folded by later passes.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Mutates the COALESCE inside `expression`: drop the first constant and everything after it
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 741
 742
# String-concatenation node types handled by simplify_concat (CONCAT(...) and the || operator)
CONCATS = (exp.Concat, exp.DPipe)
 744
 745
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string-literal arguments are merged (joined with the separator for
    CONCAT_WS). If everything merges into a single string literal, that literal is
    returned. `||` chains are rebuilt as a CONCAT that keeps the original `safe` and
    `coalesce` flags.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # `expression.flatten()` covers the || form, which has operands instead of `expressions`
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    # CONCAT_WS keeps its separator as the first argument
    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)
 784
 785
 786def simplify_conditionals(expression):
 787    """Simplifies expressions like IF, CASE if their condition is statically known."""
 788    if isinstance(expression, exp.Case):
 789        this = expression.this
 790        for case in expression.args["ifs"]:
 791            cond = case.this
 792            if this:
 793                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 794                cond = cond.replace(this.pop().eq(cond))
 795
 796            if always_true(cond):
 797                return case.args["true"]
 798
 799            if always_false(cond):
 800                case.pop()
 801                if not expression.args["ifs"]:
 802                    return expression.args.get("default") or exp.null()
 803    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 804        if always_true(expression.this):
 805            return expression.args["true"]
 806        if always_false(expression.this):
 807            return expression.args.get("false") or exp.null()
 808
 809    return expression
 810
 811
 812def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 813    """
 814    Reduces a prefix check to either TRUE or FALSE if both the string and the
 815    prefix are statically known.
 816
 817    Example:
 818        >>> from sqlglot import parse_one
 819        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 820        'TRUE'
 821    """
 822    if (
 823        isinstance(expression, exp.StartsWith)
 824        and expression.this.is_string
 825        and expression.expression.is_string
 826    ):
 827        return exp.convert(expression.name.startswith(expression.expression.name))
 828
 829    return expression
 830
 831
# Half-open date interval [start, end) produced by _datetrunc_range
DateRange = t.Tuple[datetime.date, datetime.date]
 833
 834
 835def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 836    """
 837    Get the date range for a DATE_TRUNC equality comparison:
 838
 839    Example:
 840        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 841    Returns:
 842        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 843    """
 844    floor = date_floor(date, unit, dialect)
 845
 846    if date != floor:
 847        # This will always be False, except for NULL values.
 848        return None
 849
 850    return floor, floor + interval(unit)
 851
 852
 853def _datetrunc_eq_expression(
 854    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 855) -> exp.Expression:
 856    """Get the logical expression for a date range"""
 857    return exp.and_(
 858        left >= date_literal(drange[0], target_type),
 859        left < date_literal(drange[1], target_type),
 860        copy=False,
 861    )
 862
 863
 864def _datetrunc_eq(
 865    left: exp.Expression,
 866    date: datetime.date,
 867    unit: str,
 868    dialect: Dialect,
 869    target_type: t.Optional[exp.DataType],
 870) -> t.Optional[exp.Expression]:
 871    drange = _datetrunc_range(date, unit, dialect)
 872    if not drange:
 873        return None
 874
 875    return _datetrunc_eq_expression(left, drange, target_type)
 876
 877
 878def _datetrunc_neq(
 879    left: exp.Expression,
 880    date: datetime.date,
 881    unit: str,
 882    dialect: Dialect,
 883    target_type: t.Optional[exp.DataType],
 884) -> t.Optional[exp.Expression]:
 885    drange = _datetrunc_range(date, unit, dialect)
 886    if not drange:
 887        return None
 888
 889    return exp.and_(
 890        left < date_literal(drange[0], target_type),
 891        left >= date_literal(drange[1], target_type),
 892        copy=False,
 893    )
 894
 895
 896DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 897    exp.LT: lambda l, dt, u, d, t: l
 898    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 899    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 900    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 901    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 902    exp.EQ: _datetrunc_eq,
 903    exp.NEQ: _datetrunc_neq,
 904}
 905DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 906DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 907
 908
 909def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 910    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 911
 912
 913@catch(ModuleNotFoundError, UnsupportedUnit)
 914def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 915    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 916    comparison = expression.__class__
 917
 918    if isinstance(expression, DATETRUNCS):
 919        this = expression.this
 920        trunc_type = extract_type(this)
 921        date = extract_date(this)
 922        if date and expression.unit:
 923            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
 924    elif comparison not in DATETRUNC_COMPARISONS:
 925        return expression
 926
 927    if isinstance(expression, exp.Binary):
 928        l, r = expression.left, expression.right
 929
 930        if not _is_datetrunc_predicate(l, r):
 931            return expression
 932
 933        l = t.cast(exp.DateTrunc, l)
 934        trunc_arg = l.this
 935        unit = l.unit.name.lower()
 936        date = extract_date(r)
 937
 938        if not date:
 939            return expression
 940
 941        return (
 942            DATETRUNC_BINARY_COMPARISONS[comparison](
 943                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
 944            )
 945            or expression
 946        )
 947
 948    if isinstance(expression, exp.In):
 949        l = expression.this
 950        rs = expression.expressions
 951
 952        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 953            l = t.cast(exp.DateTrunc, l)
 954            unit = l.unit.name.lower()
 955
 956            ranges = []
 957            for r in rs:
 958                date = extract_date(r)
 959                if not date:
 960                    return expression
 961                drange = _datetrunc_range(date, unit, dialect)
 962                if drange:
 963                    ranges.append(drange)
 964
 965            if not ranges:
 966                return expression
 967
 968            ranges = merge_ranges(ranges)
 969            target_type = extract_type(l, *rs)
 970
 971            return exp.or_(
 972                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
 973            )
 974
 975    return expression
 976
 977
 978def sort_comparison(expression: exp.Expression) -> exp.Expression:
 979    if expression.__class__ in COMPLEMENT_COMPARISONS:
 980        l, r = expression.this, expression.expression
 981        l_column = isinstance(l, exp.Column)
 982        r_column = isinstance(r, exp.Column)
 983        l_const = _is_constant(l)
 984        r_const = _is_constant(r)
 985
 986        if (l_column and not r_column) or (r_const and not l_const):
 987            return expression
 988        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 989            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 990                this=r, expression=l
 991            )
 992    return expression
 993
 994
 995# CROSS joins result in an empty table if the right table is empty.
 996# So we can only simplify certain types of joins to CROSS.
 997# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 998JOINS = {
 999    ("", ""),
1000    ("", "INNER"),
1001    ("RIGHT", ""),
1002    ("RIGHT", "OUTER"),
1003}
1004
1005
def remove_where_true(expression):
    """
    Drop trivially-true filters from `expression`, in place.

    - A WHERE clause whose condition is always true is removed.
    - A join whose ON condition is always true, with no USING and no join method, and
      whose (side, kind) is listed in JOINS, becomes a CROSS JOIN without ON.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1020
1021
def always_true(expression):
    """
    Return True if `expression` is a constant that is truthy in a boolean context.

    - `TRUE` counts as always true; `FALSE` does not.
    - String literals and non-zero numeric literals count as always true.
    - A numeric literal equal to zero (`0`, `0.0`, ...) does NOT. Otherwise `WHERE 0`
      would be dropped, `x AND 0` would collapse to `x`, and `NOT 0` would become FALSE.
    - Anything else (including NULL) returns False.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if isinstance(expression, exp.Literal):
        if expression.is_string:
            return True
        try:
            return Decimal(expression.this) != 0
        except ArithmeticError:
            # decimal.InvalidOperation (an ArithmeticError) for text Decimal can't
            # parse: keep the previous behaviour and treat the literal as truthy.
            return True

    return False
1026
1027
def always_false(expression):
    """True if `expression` is the FALSE literal or NULL; both reject every row in a filter."""
    if is_false(expression):
        return True
    return is_null(expression)
1030
1031
def is_complement(a, b):
    """Return True if `b` is exactly `NOT a` (structural equality on the negated operand)."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1034
1035
def is_false(a: exp.Expression) -> bool:
    """True only for an exact `exp.Boolean` node holding FALSE (subclasses are excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1038
1039
def is_null(a: exp.Expression) -> bool:
    """True only for an exact `exp.Null` node (subclasses are excluded)."""
    node_type = type(a)
    return node_type is exp.Null
1042
1043
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison kind of `expression` on Python values `a` and `b`.

    Returns the TRUE/FALSE literal, or None if `expression` is not one of
    =, IS, <>, >, >=, < or <=.
    """
    rules = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in rules:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
1058
1059
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce `value` to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as-is, and anything else
    is parsed with `datetime.fromisoformat`. Unparseable text yields None.
    """
    if isinstance(value, datetime.date):
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
1069
1070
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce `value` to a `datetime.datetime`.

    Datetimes are returned as-is and dates become midnight of that day. Anything else
    is parsed with `datetime.fromisoformat`; unparseable text yields None.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())

    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
1080
1081
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Interpret `value` as a constant of the temporal type `to`.

    Returns:
        - a `datetime.date` when `to` is DATE (via `cast_as_date`);
        - a `datetime.datetime` for any other temporal type (via `cast_as_datetime`);
        - None when `value` is empty, cannot be parsed, or `to` is not temporal.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1090
1091
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a constant date/timestamp expression.

    Supported forms:
    - `CAST(<literal> AS <temporal type>)`;
    - `TS_OR_DS_TO_DATE(<literal>)` without a format argument;
    - nestings of the two, e.g. `CAST(CAST('2021-01-01' AS DATE) AS TIMESTAMP)`.

    Returns:
        The Python date/datetime value, or None if `cast` is not one of the forms above
        or its literal cannot be interpreted as the target type.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format the argument is evaluated as a plain DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    inner = cast.this
    if isinstance(inner, exp.Literal):
        value: t.Any = inner.name
    elif isinstance(inner, (exp.Cast, exp.TsOrDsToDate)):
        # Evaluate the nested cast first, then convert its result to this node's type
        value = extract_date(inner)
    else:
        return None
    return cast_value(value, to)
1107
1108
def _is_date_literal(expression: exp.Expression) -> bool:
    """True when `extract_date` can resolve `expression` to a constant date/datetime."""
    resolved = extract_date(expression)
    return resolved is not None
1111
1112
def extract_interval(expression):
    """
    Convert an interval-like expression into a `relativedelta`.

    The quantity is read from `expression.name` and must be an integer. The unit is read
    from its `unit` arg (presumably an `exp.Interval`; confirm at call sites).
    Returns None for non-integer quantities, unsupported units, or when
    python-dateutil is unavailable.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1120
1121
def extract_type(*expressions):
    """
    Return the first truthy data type found among `expressions`.

    For a Cast the target type (`.to`) is used, otherwise the annotated `.type`.
    If none is truthy, the last candidate looked at is returned (None when no
    expressions were given).
    """
    result = None
    for result in (e.to if isinstance(e, exp.Cast) else e.type for e in expressions):
        if result:
            break
    return result
1130
1131
def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date/datetime value.

    `target_type` is honoured only when it is a temporal type. Otherwise DATETIME is
    used for `datetime.datetime` values and DATE for everything else.
    """
    keep_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not keep_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1141
1142
def interval(unit: str, n: int = 1):
    """
    Return a `dateutil.relativedelta.relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.

    Raises:
        ModuleNotFoundError: if python-dateutil is not installed.
        UnsupportedUnit: for any other unit.
    """
    # Imported lazily so that python-dateutil stays an optional dependency
    from dateutil.relativedelta import relativedelta

    for name, field, scale in (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    ):
        if unit == name:
            return relativedelta(**{field: scale * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1164
1165
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` down to the start of its `unit` (year, quarter, month, week or day).

    `datetime.datetime` inputs also have their time of day reset to midnight, because
    every supported unit is at least one day long. The result keeps the input's type
    (and tzinfo).

    For weeks, `dialect.WEEK_OFFSET` shifts the first day of the week relative to Monday,
    using `date.weekday()` numbering: 0 makes weeks start on Monday, -1 on Sunday.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if isinstance(d, datetime.datetime):
        # Truncating to a day-or-coarser unit never keeps a time component
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start. The modulo keeps this in 0..6
        # for any offset. Without it, a Sunday with WEEK_OFFSET=-1 maps to the previous
        # Sunday instead of itself, and positive offsets move forward in time.
        days_since_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1187
1188
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to a `unit` boundary: `d` itself if already aligned, else the next one."""
    start = date_floor(d, unit, dialect)
    return d if start == d else start + interval(unit)
1196
1197
def boolean_literal(condition):
    """Map a Python truth value to the SQL TRUE/FALSE literal node."""
    if condition:
        return exp.true()
    return exp.false()
1200
1201
1202def _flat_simplify(expression, simplifier, root=True):
1203    if root or not expression.same_parent:
1204        operands = []
1205        queue = deque(expression.flatten(unnest=False))
1206        size = len(queue)
1207
1208        while queue:
1209            a = queue.popleft()
1210
1211            for b in queue:
1212                result = simplifier(expression, a, b)
1213
1214                if result and result is not expression:
1215                    queue.remove(b)
1216                    queue.appendleft(result)
1217                    break
1218            else:
1219                operands.append(a)
1220
1221        if len(operands) < size:
1222            return functools.reduce(
1223                lambda a, b: expression.__class__(this=a, expression=b), operands
1224            )
1225    return expression
1226
1227
def gen(expression: t.Any) -> str:
    """
    Render `expression` with the lightweight `Gen` pseudo-SQL generator.

    The output is a cheap, deterministic string key for sorting and de-duplicating
    expressions during optimization. It is not intended to be valid SQL. Running the
    full generator here would be too expensive. A fresh `Gen` is used for every call.
    """
    renderer = Gen()
    return renderer.gen(expression)
1235
1236
class Gen:
    """
    Tiny pseudo-SQL generator used by `gen` to build cheap, deterministic sort/dedup keys.

    Nodes are rendered with an explicit work stack instead of recursion. Handlers push
    the pieces of a node in reverse order (last piece first), so popping emits them
    left-to-right. Non-expression, non-list values popped from the stack are appended
    to the output via `str()`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` (an Expression, a list, or a plain value) to a string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by its ":arg" pairs
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Emit the non-None items separated by ",". Each item is pushed with a
                # preceding ","; the one pushed last (in front of the first item) is
                # dropped afterwards.
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Only drop a separator we actually pushed. For an all-None list an
                # unconditional pop would discard an unrelated stack entry, such as the
                # "(" of an enclosing function call.
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed right-to-left so that popping yields "<this><op><expression>"
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push `node`'s non-None args as `[":name", value]` pairs.

        The first `arg_index` entries of `node.arg_types` are skipped.
        Returns True if anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
# Key stored in `Expression.meta`; nodes flagged with it are returned untouched by `simplify`.
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN, TINYINT_MAX = -128, 127
UTINYINT_MIN, UTINYINT_MAX = 0, 255
class UnsupportedUnit(Exception):
    """Raised by the date/interval helpers when asked to handle a unit they do not support."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect name, class or instance (resolved with `Dialect.get_or_raise`).
            It is passed to the DATE_TRUNC simplification.

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _freeze_group_keys(query):
        # GROUP BY expressions cannot be simplified: in
        #   SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must keep matching the grouping key exactly. So the GROUP BY,
        # and any projection or HAVING that contains one of its keys, is marked FINAL.
        group = query.args.get("group")
        if not group or not hasattr(query, "selects"):
            return

        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in query.selects:
            if any(node in keys for node in projection.walk()):
                projection.meta[FINAL] = True

        having = query.args.get("having")
        if having and any(node in keys for node in having.walk()):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        _freeze_group_keys(expression)

        # Pre-order rewrites: applied before the children are simplified
        node = rewrite_between(expression)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order rewrites: applied once the children are simplified
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the whole pipeline until the tree stops changing
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
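Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • dialect: the SQL dialect (name, class or instance) used by dialect-dependent rules, such as the first day of the week when simplifying DATE_TRUNC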
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
    """
    Decorator that turns a simplification step into a no-op when it fails.

    If the wrapped function raises any of `exceptions`, the first positional argument
    (the expression being simplified) is returned unchanged. Other exceptions propagate.
    The wrapper keeps the wrapped function's name and docstring (`functools.wraps`).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function a no-op on failure: if any of the given `exceptions` is raised, the input expression is returned unchanged.

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands LT/LTE/GT/GTE, so BETWEEN is expanded
    first. When the BETWEEN sits directly under a NOT, the conjunction is parenthesized
    so that the negation still applies to both bounds.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    subject = expression.this

    lower = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower, upper, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Push NOT inward and fold it where possible.

    - NOT NULL -> NULL, and NOT (NULL) -> NULL
    - NOT <comparison> -> complementary comparison (via COMPLEMENT_COMPARISONS)
    - De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
    - NOT <always true> -> FALSE, NOT FALSE -> TRUE
    - NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    complement = COMPLEMENT_COMPARISONS.get(operand.__class__)
    if complement is not None:
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                flipped(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
    """
    Remove redundant parentheses around nested connectors of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    for child in expression.args.values():
        unwrapped = child.unnest()
        if isinstance(unwrapped, expression.__class__):
            child.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND/OR chains.

    Identical operands collapse to one. FALSE, NULL and always-true operands are folded
    by the rules below, and any other pair is handed to `_simplify_comparison`.
    Pairs are combined across the whole flattened chain by `_flat_simplify`.
    """

    def _and_rule(expression, left, right):
        # FALSE dominates, then NULL; an always-true operand is the identity element
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        left_true = always_true(left)
        right_true = always_true(right)
        if left_true and right_true:
            return exp.true()
        if left_true:
            return right
        if right_true:
            return left
        return _simplify_comparison(expression, left, right)

    def _or_rule(expression, left, right):
        # An always-true operand dominates; FALSE is the identity element
        if always_true(left) or always_true(right):
            return exp.true()
        left_false = is_false(left)
        right_false = is_false(right)
        if left_false and right_false:
            return exp.false()
        # NULL OR NULL, NULL OR FALSE and FALSE OR NULL all fold to NULL
        if (is_null(left) or left_false) and (is_null(right) or right_false):
            return exp.null()
        if left_false:
            return right
        if right_false:
            return left
        return _simplify_comparison(expression, left, right, or_=True)

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            return _and_rule(expression, left, right)
        if isinstance(expression, exp.Or):
            return _or_rule(expression, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only direct `x` / `NOT x` pairs among the flattened operands are detected.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = tuple(expression.flatten())
    if any(is_complement(a, b) for a, b in itertools.permutations(operands, 2)):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector by their `gen` key.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when the operands were out of order or contained
    duplicates. Otherwise `expression` is returned as is.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by gen(): a later duplicate replaces the stored node but keeps the
    # position of the first occurrence.
    unique = {}
    for operand in operands:
        unique[gen(operand)] = operand

    keys = list(unique)
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return combine(*(unique[key] for key in sorted(keys)), copy=False)

    if len(unique) < len(operands):
        return combine(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The connector is rewritten in place and returned. Nested connectors sharing their
    parent's type are skipped unless ``root`` is set.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the connector type of the nested operands we look for: the dual of
        # the outer connector (OR inside an AND, AND inside an OR).
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # permutations() materializes its input up front, so the in-place replacements
        # below do not change which operand pairs are visited.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        # Attach a copy: a node can only have one parent, so reusing `aa`
                        # at both positions would leave one of them with a stale parent.
                        b.replace(aa.copy())
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab.copy())

    return expression

absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B
elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Bindings come from ``<column> = <literal>`` equalities found while walking the scope
    (IF subtrees are pruned from the walk). The defining column occurrence itself, and
    columns tested with ``IS NULL``, are never replaced.

    Reference: https://www.sqlite.org/optoverview.html
    """
    is_candidate = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not is_candidate:
        return expression

    # column -> (id of the column node that defines the binding, literal it equals)
    bindings = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            bindings[lhs] = (id(lhs), rhs)

    if not bindings:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        binding = bindings.get(column)
        if binding is None:
            continue

        defining_id, literal = binding
        if id(column) == defining_id:
            # Keep the equality that introduced the binding intact
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5

becomes

SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

# NOTE(review): the source listing below is not the body of simplify_equality. It is the
# inner `wrapped` function of the decorator applied to it: it refers to the closure
# variables `func` (the decorated function) and `exceptions`. It shows that when the
# decorated call raises one of the caught exception types, `expression` is returned
# unchanged. The actual simplify_equality implementation is not visible in this view.
def simplify_equality(expression, *args, **kwargs):
125        def wrapped(expression, *args, **kwargs):
126            try:
127                return func(expression, *args, **kwargs)
128            except exceptions:
129                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression (+ and =). Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """
    Fold literal-only computations.

    - Binary operations other than AND/OR connectors go through ``_flat_simplify`` with
      ``_simplify_binary``.
    - A negated numeric literal becomes a single numeric literal with its sign flipped.
    - Inverse date operations (``INVERSE_DATE_OPS``) are folded by ``_simplify_binary``
      when possible, otherwise left as they are.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """
    Drop parentheses that cannot affect evaluation.

    Parenthesized SELECTs are always kept. Otherwise the parentheses are removed when:
    the parent is not a Condition/Binary (or is itself a Paren); the wrapped node is not
    a binary operator (unless it is NOT/IS directly under a predicate); a predicate is
    wrapped under a non-predicate; or the arithmetic nesting is one of
    a + (b + c), a * (b * c), a + (b * c), a - (b * c).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    outer = expression.parent
    outer_is_predicate = isinstance(outer, exp.Predicate)

    if not isinstance(outer, (exp.Condition, exp.Binary)) or isinstance(outer, exp.Paren):
        return inner

    if not isinstance(inner, exp.Binary) and not (
        outer_is_predicate and isinstance(inner, (exp.Not, exp.Is))
    ):
        return inner

    if isinstance(inner, exp.Predicate) and not outer_is_predicate:
        return inner

    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner

    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - ``COALESCE(x)`` and ``COALESCE(<non-null constant>, ...)`` become their first
      argument, except when the COALESCE is a Spark partitioning hint.
    - ``COALESCE(x, ..., <const>, ...) <cmp> <const>`` is expanded into
      ``(x' IS NOT NULL AND <original comparison>) OR (x' IS NULL AND <const> <cmp> <const>)``.
      Here x' is the COALESCE truncated before its first constant argument. The expansion
      is larger, but later simplification passes fold it further.
    """
    if isinstance(expression, exp.Coalesce):
        collapsible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if collapsible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # The rewrite is valid for non-constants, but only pays off when both sides are constant
    if not _is_constant(other):
        return expression

    position, fallback = next(
        ((i, arg) for i, arg in enumerate(coalesce.expressions) if _is_constant(arg)),
        (None, None),
    )
    if position is None:
        return expression

    coalesce.set("expressions", coalesce.expressions[:position])

    # Drop a now single-argument COALESCE right away instead of waiting for another pass
    if not coalesce.expressions:
        coalesce = coalesce.this

    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=fallback.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
# Concatenation node types handled by simplify_concat (CONCAT_WS presumably matches
# through its Concat base class, since simplify_concat special-cases it after this check).
CONCATS = (exp.Concat, exp.DPipe)
def simplify_concat(expression):
    """
    Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged (joined by the separator for CONCAT_WS). If
    everything collapses into one string literal, that literal is returned. CONCAT_WS is
    left alone when its separator is not a string literal.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_ws and not expression.expressions[0].is_string:
        return expression

    if is_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        target = exp.ConcatWs
        build_kwargs = {}
    else:
        operands = expression.expressions
        sep = ""
        target = exp.Concat
        build_kwargs = {key: expression.args.get(key) for key in ("safe", "coalesce")}

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), key=lambda e: e.is_string)
    for literal_run, members in runs:
        if literal_run:
            merged.append(exp.Literal.string(sep.join(m.name for m in members)))
        else:
            merged.extend(members)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_ws:
        merged.insert(0, sep_expr)

    return target(expressions=merged, **build_kwargs)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    - CASE: the first branch whose condition is always true replaces the whole CASE.
      Branches whose condition is always false (FALSE or NULL) are dropped; if none
      remain, the result is the ELSE value, or NULL when there is no ELSE.
      ``CASE x WHEN v ...`` is first rewritten to ``CASE WHEN x = v ...``.
    - IF (not nested directly in a CASE): collapses to its true value, or to its false
      value (NULL if absent), when the condition is known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below detaches the branch from
        # expression.args["ifs"], and mutating the list being iterated would make the
        # loop skip the branch that follows a removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack, prefix = expression.this, expression.expression
    if haystack.is_string and prefix.is_string:
        return exp.convert(expression.name.startswith(prefix.name))

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
# A pair of dates (presumably a start/end range; usage is not visible here).
# Uses the module's `t` alias for `typing`, which is how the file imports it.
DateRange = t.Tuple[datetime.date, datetime.date]
# Maps each binary comparison type (LT, GT, LTE, GTE, EQ, NEQ) to a rewrite with the
# DateTruncBinaryTransform signature: (expression, date, unit, dialect, target type) ->
# optional replacement expression. Presumably used by simplify_datetrunc for
# DATE_TRUNC(unit, x) <cmp> <date literal>.
# NOTE(review): this is a rendered repr, not source. The LT/GT/LTE/GTE values are lambdas
# whose bodies are not visible here, so the definition cannot be reconstructed from this view.
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
# Comparison node types that may wrap a DATE_TRUNC operand: the binary comparisons plus IN.
DATETRUNC_COMPARISONS = {exp.In, exp.EQ, exp.NEQ, exp.LT, exp.LTE, exp.GT, exp.GTE}
# NOTE(review): the source listing below is not the body of simplify_datetrunc. It is the
# inner `wrapped` function of the decorator applied to it: it refers to the closure
# variables `func` (the decorated function) and `exceptions`. It shows that when the
# decorated call raises one of the caught exception types, `expression` is returned
# unchanged. The actual simplify_datetrunc implementation is not visible in this view.
def simplify_datetrunc(expression, *args, **kwargs):
125        def wrapped(expression, *args, **kwargs):
126            try:
127                return func(expression, *args, **kwargs)
128            except exceptions:
129                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Columns go on the left and constants on the right. Otherwise operands are ordered by
    their pseudo-SQL (``gen``). Swapping operands also flips the operator via
    ``INVERSE_COMPARISONS`` (e.g. ``<`` becomes ``>``).
    """
    cls = expression.__class__
    if cls not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_col = isinstance(left, exp.Column)
    right_is_col = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_canonical = (left_is_col and not right_is_col) or (right_is_const and not left_is_const)
    if already_canonical:
        return expression

    should_swap = (
        (right_is_col and not left_is_col)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not should_swap:
        return expression

    flipped_cls = INVERSE_COMPARISONS.get(cls, cls)
    return flipped_cls(this=right, expression=left)
# (side, kind) pairs for which remove_where_true turns a join whose ON condition is
# always true into a CROSS JOIN.
# NOTE(review): the RIGHT entries look questionable. A RIGHT JOIN ... ON TRUE still
# returns the right-hand rows when the left input is empty, while a CROSS JOIN returns
# none. Confirm this is intended.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
def remove_where_true(expression):
    """
    Remove conditions that are statically always true, in place.

    ``WHERE <always true>`` clauses are dropped. A join whose ON condition is always true,
    that has no USING and no join method, and whose (side, kind) is listed in ``JOINS`` is
    turned into a CROSS JOIN. Nothing is returned.
    """
    for clause in expression.find_all(exp.Where):
        if always_true(clause.this):
            clause.pop()

    for join in expression.find_all(exp.Join):
        condition = join.args.get("on")
        if not always_true(condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        condition.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """
    Return True if ``expression`` is statically known to be a true condition.

    The TRUE boolean constant qualifies, as do string literals and non-zero numeric
    literals. Numeric literals equal to zero (``0``, ``0.0`` ...) do NOT qualify. Treating
    them as true previously allowed ``WHERE 0`` to be dropped and ``x AND 0`` to collapse
    to ``x``.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_string:
        return True

    try:
        return Decimal(expression.this) != 0
    except (ArithmeticError, TypeError, ValueError):
        # Not a plain decimal number (e.g. a hex literal): keep treating it as true,
        # as before, rather than guessing its value.
        return True
def always_false(expression):
    """Return True if ``expression`` is the FALSE constant or NULL (a condition that never holds)."""
    return any(check(expression) for check in (is_false, is_null))
def is_complement(a, b):
    """Return True if ``b`` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True if ``a`` is exactly the FALSE boolean constant (subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True if ``a`` is exactly a NULL node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison type of ``expression`` on the Python values ``a`` and ``b``.

    Returns a TRUE/FALSE literal for EQ/IS, NEQ, GT, GTE, LT and LTE, or None for any
    other expression type. Only the matching comparison is evaluated.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in checks:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce ``value`` to a ``datetime.date``.

    A datetime is truncated to its date part and a date is returned unchanged. Anything
    else is parsed with ``datetime.datetime.fromisoformat``. Returns None when the value
    cannot be interpreted: an unparseable string, or a non-string such as None or an int.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input; treat it like a bad string
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce ``value`` to a ``datetime.datetime``.

    A datetime is returned unchanged and a date becomes midnight of that day. Anything
    else is parsed with ``datetime.datetime.fromisoformat``. Returns None when the value
    cannot be interpreted: an unparseable string, or a non-string such as None or an int.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input; treat it like a bad string
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert ``value`` to the Python temporal value matching the data type ``to``.

    A DATE target yields a ``datetime.date`` (via cast_as_date). Any other temporal
    target yields a ``datetime.datetime`` (via cast_as_datetime). Returns None for a falsy
    value, a non-temporal target type, or a value that cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a CAST (or a format-less TS_OR_DS_TO_DATE) of a literal.

    Nested casts are evaluated innermost first. Returns a ``datetime.date`` or
    ``datetime.datetime`` depending on the target type, or None when the expression is
    not such a cast or its value cannot be determined.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format it is treated as a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """
    Convert an INTERVAL expression into a ``relativedelta``.

    Returns None when the amount is not an integer, the unit is unsupported, or
    ``dateutil`` (imported lazily by ``interval``) is not installed.
    """
    try:
        count = int(expression.name)
        return interval(expression.text("unit").lower(), count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
    """
    Return the first data type found among ``expressions``.

    For a CAST its target type is used; otherwise the node's ``type`` attribute. If none
    is truthy, the last value inspected is returned (None when no expressions are given).
    """
    found = None
    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found
    return found
def date_literal(date, target_type=None):
    """
    Build ``CAST('<date>' AS <type>)`` for a Python date or datetime.

    ``target_type`` is used when it is a temporal type. Otherwise DATETIME is used for
    datetimes and DATE for plain dates.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
    """
    Return a ``dateutil`` ``relativedelta`` spanning ``n`` units of ``unit``.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.
    ``dateutil`` is imported lazily, so a missing installation surfaces as
    ModuleNotFoundError. Raises UnsupportedUnit for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate ``d`` down to the start of the ``unit`` that contains it.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies ``WEEK_OFFSET``, the first day of the week relative to Monday,
            used for the "week" unit (presumably 0 = Monday, -1 = Sunday; confirm
            against Dialect).

    Returns:
        The first day of the enclosing period; never later than ``d``.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        first_month_of_quarter = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month_of_quarter, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is Monday=0 ... Sunday=6. Taking the distance modulo 7 keeps it in
        # [0, 6], so for any WEEK_OFFSET the result is never in the future and a date
        # that already starts a week is returned unchanged. Without the modulo, a Sunday
        # with WEEK_OFFSET=-1 jumps back a full week, and a positive offset can move forward.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round ``d`` up to the start of the next ``unit``, unless it already starts one."""
    start = date_floor(d, unit, dialect)
    return d if start == d else start + interval(unit)
def boolean_literal(condition):
    """Map a Python truth value to the SQL TRUE/FALSE constant."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: t.Any) -> str:
    """
    Produce a cheap pseudo-SQL string for ``expression``, used for sorting and deduplication.

    Sorting and deduping SQL is a necessary step for optimization. Invoking the real SQL
    generator is expensive, so the minimal ``Gen`` stack machine is used instead.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping SQL is a necessary step for optimization. Calling the actual generator is expensive, so a bare-minimum SQL generator is used here instead.

class Gen:
    """
    Minimal, non-recursive pseudo-SQL generator used by ``gen``.

    Work is driven by an explicit stack. Handlers push the pieces of a node in reverse
    output order (last piece first), so popping yields them left to right. Strings that
    are popped are appended to ``sqls``. Expressions are expanded by their ``<key>_sql``
    handler, by ``_function`` for functions, or generically via ``_args``. Lists are
    expanded into comma-separated items.
    """

    def __init__(self):
        # Pending work: expressions, lists and literal string fragments.
        self.stack = []
        # Output fragments, joined at the end of gen().
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render ``expression`` as a pseudo-SQL string (stable and sortable, not valid SQL)."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: KEY followed by its non-null args (queued by _args)
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push items in reverse, each followed by a "," separator, then drop the
                # separator that would precede the first item. Only drop it if something
                # was actually pushed: a list whose items are all None pushes nothing, and
                # popping then would discard an unrelated pending fragment (e.g. the ")"
                # queued by _function).
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." that would precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Queue the args after the first arg type, then the type name (emitted first)
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case " Like " (vs " ILIKE ") is kept as-is; the output only
        # needs to be stable and distinguishing, and changing it would alter sort keys.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Remaining args (after the first two arg types) are emitted last
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Remaining args (after the first four arg types) are emitted last
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue ``<this><op><expression>``."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue ``<op><this>``."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue ``NAME(<all arg values, comma separated>)``."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Queue ``:key`` / value pairs for the node's non-null args.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading arg types to skip (already rendered by the caller).

        Returns:
            True if anything was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression) -> str:
    """
    Render ``expression`` as a pseudo-SQL string using an explicit work stack.

    Expressions are expanded by their ``<key>_sql`` handler, by ``_function`` for
    functions, or generically via ``_args``. Lists become comma-separated items, and
    strings are appended to the output.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                # Generic fallback: KEY followed by its non-null args (queued by _args)
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            # Push items in reverse with "," separators, then drop the leading separator,
            # but only if something was pushed. Otherwise an all-None list would discard
            # an unrelated pending fragment.
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Render an addition as ``<this> + <expression>``."""
    return self._binary(e, " + ")
def alias_sql(self, e: exp.Alias) -> None:
    """Queue ``<this> AS <alias>`` (pieces pushed last-first because the stack is LIFO)."""
    args = e.args
    for piece in (args.get("alias"), " AS ", args.get("this")):
        self.stack.append(piece)
def and_sql(self, e: exp.And) -> None:
    """Render a conjunction as ``<this> AND <expression>``."""
    return self._binary(e, " AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Queue ``NAME(<args>)`` for an anonymous function call.

    A plain-string name is upper-cased. An Identifier keeps its text in double quotes
    when quoted and is upper-cased otherwise. Any other name type raises ValueError.
    """
    ident = e.this
    if isinstance(ident, str):
        label = ident.upper()
    elif isinstance(ident, exp.Identifier):
        label = f'"{ident.this}"' if ident.quoted else ident.this.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{ident.__class__.__name__}'."
        )

    self.stack.extend([")", e.expressions, "(", label])
def between_sql(self, e: exp.Between) -> None:
    """Queue ``<this> BETWEEN <low> AND <high>`` (pushed in reverse order)."""
    bounds = e.args
    upper = bounds.get("high")
    lower = bounds.get("low")
    self.stack.extend([upper, " AND ", lower, " BETWEEN ", e.this])
def boolean_sql(self, e: exp.Boolean) -> None:
    """Push ``TRUE`` or ``FALSE`` according to the truthiness of ``e.this``."""
    if e.this:
        keyword = "TRUE"
    else:
        keyword = "FALSE"
    self.stack.append(keyword)
def bracket_sql(self, e: exp.Bracket) -> None:
    """Queue ``<this>[<expressions>]`` (pushed in reverse order)."""
    subject = e.this
    index = e.expressions
    self.stack.extend(["]", index, "[", subject])
def column_sql(self, e: exp.Column) -> None:
    """Queue the dot-separated parts of a column reference, e.g. ``tbl.col``.

    Each part is pushed in reverse order, followed by a ``"."`` separator.
    The final separator is then dropped, so the consumed output reads
    ``part0.part1...``.
    """
    parts = e.parts
    for part in reversed(parts):
        self.stack.extend((part, "."))
    # With no parts there is no trailing separator. An unconditional pop would
    # remove an unrelated item already on the stack, or raise on an empty one.
    if parts:
        self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Queue the type name (e.g. ``INT ``), then the node's args from index 1 on."""
    self._args(e, 1)
    type_name = e.this.name
    self.stack.append(type_name + " ")
def div_sql(self, e: exp.Div) -> None:
    """Render a division by delegating to ``_binary`` with `` / ``."""
    symbol = " / "
    self._binary(e, symbol)
def dot_sql(self, e: exp.Dot) -> None:
    """Render a dotted access by delegating to ``_binary`` with ``.``."""
    separator = "."
    self._binary(e, separator)
def eq_sql(self, e: exp.EQ) -> None:
    """Render an equality by delegating to ``_binary`` with `` = ``."""
    symbol = " = "
    self._binary(e, symbol)
def from_sql(self, e: exp.From) -> None:
    """Queue ``FROM <this>``; ``this`` is pushed first so it is consumed last."""
    stack = self.stack
    stack.append(e.this)
    stack.append("FROM ")
def gt_sql(self, e: exp.GT) -> None:
    """Render a greater-than comparison via ``_binary`` with `` > ``."""
    symbol = " > "
    self._binary(e, symbol)
def gte_sql(self, e: exp.GTE) -> None:
    """Render a greater-or-equal comparison via ``_binary`` with `` >= ``."""
    symbol = " >= "
    self._binary(e, symbol)
def identifier_sql(self, e: exp.Identifier) -> None:
    """Push an identifier's name, wrapped in double quotes when it is quoted.

    Embedded double quotes are doubled (``a"b`` -> ``"a""b"``). Without this,
    distinct quoted identifiers could render to the same string as other
    output, e.g. the single identifier ``a"."b`` vs the column ``"a"."b"``.
    """
    text = e.this
    if e.quoted:
        text = '"' + text.replace('"', '""') + '"'
    self.stack.append(text)
def ilike_sql(self, e: exp.ILike) -> None:
    """Render a case-insensitive LIKE via ``_binary`` with `` ILIKE ``."""
    keyword = " ILIKE "
    self._binary(e, keyword)
def in_sql(self, e: exp.In) -> None:
    """Queue ``<this> IN (<args>)``.

    The closing paren goes on first and the node's args from index 1 on are
    queued next. The opening paren, keyword, and left operand come last,
    because the stack is consumed from its end.
    """
    self.stack.append(")")
    self._args(e, 1)
    for piece in ("(", " IN ", e.this):
        self.stack.append(piece)
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Render an integer division via ``_binary`` with `` DIV ``."""
    keyword = " DIV "
    self._binary(e, keyword)
def is_sql(self, e: exp.Is) -> None:
    """Render an ``IS`` test via ``_binary`` with `` IS ``."""
    keyword = " IS "
    self._binary(e, keyword)
def like_sql(self, e: exp.Like) -> None:
    """Render a LIKE match via ``_binary`` with `` LIKE ``.

    The keyword is upper-case, matching every other operator keyword this
    generator emits (``ILIKE``, ``IN``, ``IS``, ``AND``, ``OR``, ``DIV``).
    """
    self._binary(e, " LIKE ")
def literal_sql(self, e: exp.Literal) -> None:
    """Push a literal's text; string literals are wrapped in single quotes.

    Embedded single quotes are doubled (``it's`` -> ``'it''s'``). Without this,
    one string literal could render exactly like a list of several, e.g. the
    literal ``a','b`` vs the two literals ``'a','b'``.
    """
    if e.is_string:
        self.stack.append("'" + e.this.replace("'", "''") + "'")
    else:
        self.stack.append(e.this)
def lt_sql(self, e: exp.LT) -> None:
    """Render a less-than comparison via ``_binary`` with `` < ``."""
    symbol = " < "
    self._binary(e, symbol)
def lte_sql(self, e: exp.LTE) -> None:
    """Render a less-or-equal comparison via ``_binary`` with `` <= ``."""
    symbol = " <= "
    self._binary(e, symbol)
def mod_sql(self, e: exp.Mod) -> None:
    """Render a modulo via ``_binary`` with `` % ``."""
    symbol = " % "
    self._binary(e, symbol)
def mul_sql(self, e: exp.Mul) -> None:
    """Render a multiplication via ``_binary`` with `` * ``."""
    symbol = " * "
    self._binary(e, symbol)
def neg_sql(self, e: exp.Neg) -> None:
    """Render an arithmetic negation via ``_unary`` with the ``-`` prefix."""
    prefix = "-"
    self._unary(e, prefix)
def neq_sql(self, e: exp.NEQ) -> None:
    """Render an inequality via ``_binary`` with `` <> ``."""
    symbol = " <> "
    self._binary(e, symbol)
def not_sql(self, e: exp.Not) -> None:
    """Render a logical negation via ``_unary`` with the ``NOT `` prefix."""
    prefix = "NOT "
    self._unary(e, prefix)
def null_sql(self, e: exp.Null) -> None:
    """Push the ``NULL`` keyword; the node itself carries nothing else to render."""
    keyword = "NULL"
    self.stack.append(keyword)
def or_sql(self, e: exp.Or) -> None:
    """Render a disjunction via ``_binary`` with `` OR ``."""
    connective = " OR "
    self._binary(e, connective)
def paren_sql(self, e: exp.Paren) -> None:
    """Queue ``(<this>)``; the closing paren is pushed first so it is consumed last."""
    inner = e.this
    self.stack.extend([")", inner, "("])
def sub_sql(self, e: exp.Sub) -> None:
    """Render a subtraction via ``_binary`` with `` - ``."""
    symbol = " - "
    self._binary(e, symbol)
def subquery_sql(self, e: exp.Subquery) -> None:
    """Queue ``(<this>)`` followed by an optional alias, then args from index 2 on.

    The args are queued first and the parenthesised body last, because the
    stack is consumed from its end.
    """
    self._args(e, 2)
    pieces = [")", e.this, "("]
    alias = e.args.get("alias")
    if alias:
        pieces.insert(0, alias)
    self.stack.extend(pieces)
def table_sql(self, e: exp.Table) -> None:
    """Queue a table reference: dotted name parts, optional alias, then args from index 4 on.

    The name parts are pushed in reverse order, each followed by a ``"."``
    separator, and the final separator is dropped. They are pushed after the
    alias and args, so they are consumed first.
    """
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)
    parts = e.parts
    for part in reversed(parts):
        self.stack.extend((part, "."))
    # With no parts there is no trailing separator to drop. An unconditional pop
    # would instead discard the alias or args output pushed just above.
    if parts:
        self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Queue `` AS <name>``, followed by ``(<columns>)`` when column aliases exist."""
    cols = e.columns
    pieces = [e.this, " AS "]
    if cols:
        # The column list goes underneath so it is consumed after the name.
        pieces = [")", cols, "("] + pieces
    self.stack.extend(pieces)
def var_sql(self, e: exp.Var) -> None:
    """Push a variable node's raw ``this`` value unchanged."""
    raw = e.this
    self.stack.append(raw)