Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
if t.TYPE_CHECKING:
    # Only needed for annotations, so it is imported for type checkers only.
    from sqlglot.dialects.dialect import DialectType

    # Callable alias: (expression, date, str, dialect, data type) -> optional expression.
    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
    ]

# Logger registered under the "sqlglot" name
logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect resolved with `Dialect.get_or_raise` and passed to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only the outermost connector of a chain is measured (its parent is not a Connector),
        # so the depth check runs once per chain rather than once per nested connector.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes flagged as FINAL (see the GROUP BY handling below) are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for HAVING when it references a group by key
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules below inspect `parent`, so the (possibly new) node inherits the original's parent
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call splices the result back into the tree
        if root:
            expression.replace(node)
        return node

    # NOTE(review): `while_changing` presumably re-applies `_simplify` until the tree stops
    # changing -- confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 136
 137
 138def connector_depth(expression: exp.Expression) -> int:
 139    """
 140    Determine the maximum depth of a tree of Connectors.
 141
 142    For example:
 143        >>> from sqlglot import parse_one
 144        >>> connector_depth(parse_one("a AND b AND c AND d"))
 145        3
 146    """
 147    stack = deque([(expression, 0)])
 148    max_depth = 0
 149
 150    while stack:
 151        expression, depth = stack.pop()
 152
 153        if not isinstance(expression, exp.Connector):
 154            continue
 155
 156        depth += 1
 157        max_depth = max(depth, max_depth)
 158
 159        stack.append((expression.left, depth))
 160        stack.append((expression.right, depth))
 161
 162    return max_depth
 163
 164
 165def catch(*exceptions):
 166    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 167
 168    def decorator(func):
 169        def wrapped(expression, *args, **kwargs):
 170            try:
 171                return func(expression, *args, **kwargs)
 172            except exceptions:
 173                return expression
 174
 175        return wrapped
 176
 177    return decorator
 178
 179
 180def rewrite_between(expression: exp.Expression) -> exp.Expression:
 181    """Rewrite x between y and z to x >= y AND x <= z.
 182
 183    This is done because comparison simplification is only done on lt/lte/gt/gte.
 184    """
 185    if isinstance(expression, exp.Between):
 186        negate = isinstance(expression.parent, exp.Not)
 187
 188        expression = exp.and_(
 189            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 190            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 191            copy=False,
 192        )
 193
 194        if negate:
 195            expression = exp.paren(expression, copy=False)
 196
 197    return expression
 198
 199
# Maps a comparison class to the class of its logical negation.
# `simplify_not` uses it to rewrite e.g. NOT (a < b) as a >= b.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 208
 209
 210def simplify_not(expression):
 211    """
 212    Demorgan's Law
 213    NOT (x OR y) -> NOT x AND NOT y
 214    NOT (x AND y) -> NOT x OR NOT y
 215    """
 216    if isinstance(expression, exp.Not):
 217        this = expression.this
 218        if is_null(this):
 219            return exp.null()
 220        if this.__class__ in COMPLEMENT_COMPARISONS:
 221            return COMPLEMENT_COMPARISONS[this.__class__](
 222                this=this.this, expression=this.expression
 223            )
 224        if isinstance(this, exp.Paren):
 225            condition = this.unnest()
 226            if isinstance(condition, exp.And):
 227                return exp.paren(
 228                    exp.or_(
 229                        exp.not_(condition.left, copy=False),
 230                        exp.not_(condition.right, copy=False),
 231                        copy=False,
 232                    )
 233                )
 234            if isinstance(condition, exp.Or):
 235                return exp.paren(
 236                    exp.and_(
 237                        exp.not_(condition.left, copy=False),
 238                        exp.not_(condition.right, copy=False),
 239                        copy=False,
 240                    )
 241                )
 242            if is_null(condition):
 243                return exp.null()
 244        if always_true(this):
 245            return exp.false()
 246        if is_false(this):
 247            return exp.true()
 248        if isinstance(this, exp.Not):
 249            # double negation
 250            # NOT NOT x -> x
 251            return this.this
 252    return expression
 253
 254
 255def flatten(expression):
 256    """
 257    A AND (B AND C) -> A AND B AND C
 258    A OR (B OR C) -> A OR B OR C
 259    """
 260    if isinstance(expression, exp.Connector):
 261        for node in expression.args.values():
 262            child = node.unnest()
 263            if isinstance(child, expression.__class__):
 264                node.replace(child)
 265    return expression
 266
 267
 268def simplify_connectors(expression, root=True):
 269    def _simplify_connectors(expression, left, right):
 270        if isinstance(expression, exp.And):
 271            if is_false(left) or is_false(right):
 272                return exp.false()
 273            if is_zero(left) or is_zero(right):
 274                return exp.false()
 275            if is_null(left) or is_null(right):
 276                return exp.null()
 277            if always_true(left) and always_true(right):
 278                return exp.true()
 279            if always_true(left):
 280                return right
 281            if always_true(right):
 282                return left
 283            return _simplify_comparison(expression, left, right)
 284        elif isinstance(expression, exp.Or):
 285            if always_true(left) or always_true(right):
 286                return exp.true()
 287            if (
 288                (is_null(left) and is_null(right))
 289                or (is_null(left) and always_false(right))
 290                or (always_false(left) and is_null(right))
 291            ):
 292                return exp.null()
 293            if is_false(left):
 294                return right
 295            if is_false(right):
 296                return left
 297            return _simplify_comparison(expression, left, right, or_=True)
 298        elif isinstance(expression, exp.Xor):
 299            if left == right:
 300                return exp.false()
 301
 302    if isinstance(expression, exp.Connector):
 303        return _flat_simplify(expression, _simplify_connectors, root)
 304    return expression
 305
 306
# Inequality classes grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Expression classes that the rules in this module treat as comparisons
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps each inequality class to the one pointing the other way (LT <-> GT, LTE <-> GTE)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Operands containing any of these are excluded when `_simplify_comparison` merges comparisons
NONDETERMINISTIC = (exp.Rand, exp.Randn)
# The connectors handled by the complement and absorption/elimination rules
AND_OR = (exp.And, exp.Or)
 327
 328
def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons against the same operand that are joined by AND (or by OR
    when `or_` is set).

    Args:
        expression: the connector joining `left` and `right`.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for AND.

    Returns:
        One of `left` / `right` when one comparison makes the other redundant,
        `exp.false()` when an AND of the two can never hold, `expression` itself when
        no constant operand could be extracted, or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear in both comparisons; only non-constant, deterministic
        # ones count as the shared "column" being compared
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand of each comparison is the value compared against
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values to comparable python objects: numbers, strings or dates
            if l.is_number and r.is_number:
                l = l.to_py()
                r = r.to_py()
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try both orderings so each rule only has to be written for one (a, b) arrangement
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two upper bounds (or two lower bounds): keep a single one of them
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Contradictory bounds under AND can never hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other comparison or subsumes it
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 389
 390
 391def remove_complements(expression, root=True):
 392    """
 393    Removing complements.
 394
 395    A AND NOT A -> FALSE
 396    A OR NOT A -> TRUE
 397    """
 398    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 399        ops = set(expression.flatten())
 400        for op in ops:
 401            if isinstance(op, exp.Not) and op.this in ops:
 402                return exp.false() if isinstance(expression, exp.And) else exp.true()
 403
 404    return expression
 405
 406
 407def uniq_sort(expression, root=True):
 408    """
 409    Uniq and sort a connector.
 410
 411    C AND A AND B AND B -> A AND B AND C
 412    """
 413    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 414        flattened = tuple(expression.flatten())
 415
 416        if isinstance(expression, exp.Xor):
 417            result_func = exp.xor
 418            # Do not deduplicate XOR as A XOR A != A if A == True
 419            deduped = None
 420            arr = tuple((gen(e), e) for e in flattened)
 421        else:
 422            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 423            deduped = {gen(e): e for e in flattened}
 424            arr = tuple(deduped.items())
 425
 426        # check if the operands are already sorted, if not sort them
 427        # A AND C AND B -> A AND B AND C
 428        for i, (sql, e) in enumerate(arr[1:]):
 429            if sql < arr[i][0]:
 430                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 431                break
 432        else:
 433            # we didn't have to sort but maybe we need to dedup
 434            if deduped and len(deduped) < len(flattened):
 435                expression = result_func(*deduped.values(), copy=False)
 436
 437    return expression
 438
 439
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place (via `replace`) with constants or with the
    surviving operand; the connector itself is returned unchanged.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands we look into
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Index the operand by the pair it would form if its negated side were not negated,
            # remembering which side survives an elimination
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated side whose complement is a top-level operand becomes the neutral
            # constant of the nested connector
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # A nested operand that is a strict superset of another operand is redundant
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
 510
 511
 512def propagate_constants(expression, root=True):
 513    """
 514    Propagate constants for conjunctions in DNF:
 515
 516    SELECT * FROM t WHERE a = b AND b = 5 becomes
 517    SELECT * FROM t WHERE a = 5 AND b = 5
 518
 519    Reference: https://www.sqlite.org/optoverview.html
 520    """
 521
 522    if (
 523        isinstance(expression, exp.And)
 524        and (root or not expression.same_parent)
 525        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 526    ):
 527        constant_mapping = {}
 528        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 529            if isinstance(expr, exp.EQ):
 530                l, r = expr.left, expr.right
 531
 532                # TODO: create a helper that can be used to detect nested literal expressions such
 533                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 534                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 535                    constant_mapping[l] = (id(l), r)
 536
 537        if constant_mapping:
 538            for column in find_all_in_scope(expression, exp.Column):
 539                parent = column.parent
 540                column_id, constant = constant_mapping.get(column) or (None, None)
 541                if (
 542                    column_id is not None
 543                    and id(column) != column_id
 544                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 545                ):
 546                    column.replace(constant.copy())
 547
 548    return expression
 549
 550
# Date arithmetic functions mapped to the binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 563
 564
 565def _is_number(expression: exp.Expression) -> bool:
 566    return expression.is_number
 567
 568
 569def _is_interval(expression: exp.Expression) -> bool:
 570    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 571
 572
 573@catch(ModuleNotFoundError, UnsupportedUnit)
 574def simplify_equality(expression: exp.Expression) -> exp.Expression:
 575    """
 576    Use the subtraction and addition properties of equality to simplify expressions:
 577
 578        x + 1 = 3 becomes x = 2
 579
 580    There are two binary operations in the above expression: + and =
 581    Here's how we reference all the operands in the code below:
 582
 583          l     r
 584        x + 1 = 3
 585        a   b
 586    """
 587    if isinstance(expression, COMPARISONS):
 588        l, r = expression.left, expression.right
 589
 590        if l.__class__ not in INVERSE_OPS:
 591            return expression
 592
 593        if r.is_number:
 594            a_predicate = _is_number
 595            b_predicate = _is_number
 596        elif _is_date_literal(r):
 597            a_predicate = _is_date_literal
 598            b_predicate = _is_interval
 599        else:
 600            return expression
 601
 602        if l.__class__ in INVERSE_DATE_OPS:
 603            l = t.cast(exp.IntervalOp, l)
 604            a = l.this
 605            b = l.interval()
 606        else:
 607            l = t.cast(exp.Binary, l)
 608            a, b = l.left, l.right
 609
 610        if not a_predicate(a) and b_predicate(b):
 611            pass
 612        elif not a_predicate(b) and b_predicate(a):
 613            a, b = b, a
 614        else:
 615            return expression
 616
 617        return expression.__class__(
 618            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 619        )
 620    return expression
 621
 622
 623def simplify_literals(expression, root=True):
 624    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 625        return _flat_simplify(expression, _simplify_binary, root)
 626
 627    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 628        return expression.this.this
 629
 630    if type(expression) in INVERSE_DATE_OPS:
 631        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 632
 633    return expression
 634
 635
# Operators that `_simplify_binary` must not fold to NULL when one of the operands is NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 637
 638
 639def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 640    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 641        this = _simplify_integer_cast(expr.this)
 642    else:
 643        this = expr.this
 644
 645    if isinstance(expr, exp.Cast) and this.is_int:
 646        num = this.to_py()
 647
 648        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 649        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 650        # engine-dependent
 651        if (
 652            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 653        ) or (
 654            UTINYINT_MIN <= num <= UTINYINT_MAX
 655            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 656        ):
 657            return this
 658
 659    return expr
 660
 661
 662def _simplify_binary(expression, a, b):
 663    if isinstance(expression, COMPARISONS):
 664        a = _simplify_integer_cast(a)
 665        b = _simplify_integer_cast(b)
 666
 667    if isinstance(expression, exp.Is):
 668        if isinstance(b, exp.Not):
 669            c = b.this
 670            not_ = True
 671        else:
 672            c = b
 673            not_ = False
 674
 675        if is_null(c):
 676            if isinstance(a, exp.Literal):
 677                return exp.true() if not_ else exp.false()
 678            if is_null(a):
 679                return exp.false() if not_ else exp.true()
 680    elif isinstance(expression, NULL_OK):
 681        return None
 682    elif is_null(a) or is_null(b):
 683        return exp.null()
 684
 685    if a.is_number and b.is_number:
 686        num_a = a.to_py()
 687        num_b = b.to_py()
 688
 689        if isinstance(expression, exp.Add):
 690            return exp.Literal.number(num_a + num_b)
 691        if isinstance(expression, exp.Mul):
 692            return exp.Literal.number(num_a * num_b)
 693
 694        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 695        if isinstance(expression, exp.Sub):
 696            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 697        if isinstance(expression, exp.Div):
 698            # engines have differing int div behavior so intdiv is not safe
 699            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 700                return None
 701            return exp.Literal.number(num_a / num_b)
 702
 703        boolean = eval_boolean(expression, num_a, num_b)
 704
 705        if boolean:
 706            return boolean
 707    elif a.is_string and b.is_string:
 708        boolean = eval_boolean(expression, a.this, b.this)
 709
 710        if boolean:
 711            return boolean
 712    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 713        date, b = extract_date(a), extract_interval(b)
 714        if date and b:
 715            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 716                return date_literal(date + b, extract_type(a))
 717            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 718                return date_literal(date - b, extract_type(a))
 719    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 720        a, date = extract_interval(a), extract_date(b)
 721        # you cannot subtract a date from an interval
 722        if a and b and isinstance(expression, exp.Add):
 723            return date_literal(a + date, extract_type(b))
 724    elif _is_date_literal(a) and _is_date_literal(b):
 725        if isinstance(expression, exp.Predicate):
 726            a, b = extract_date(a), extract_date(b)
 727            boolean = eval_boolean(expression, a, b)
 728            if boolean:
 729                return boolean
 730
 731    return None
 732
 733
def simplify_parens(expression):
    """
    Remove a pair of parentheses when the checks below consider them redundant.

    Returns the wrapped expression when the Paren can be dropped, otherwise the
    Paren itself. Non-Paren expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Never unwrap a parenthesized SELECT or the operand of a subquery predicate
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # The parent is neither a Condition nor a Binary
            not isinstance(parent, (exp.Condition, exp.Binary))
            # Nested parentheses
            or isinstance(parent, exp.Paren)
            # The content is not a Binary, and is not a NOT / IS inside a predicate
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # A predicate whose parent is not itself a predicate
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # Same-operator nesting for addition and multiplication
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            # A multiplication nested in an addition or subtraction
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
 760
 761
 762def _is_nonnull_constant(expression: exp.Expression) -> bool:
 763    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 764
 765
 766def _is_constant(expression: exp.Expression) -> bool:
 767    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 768
 769
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons that involve them.

    A COALESCE with a single argument, or whose first argument is a non-null constant,
    is replaced by that argument. A comparison between a COALESCE and a constant is
    split into an IS NOT NULL branch and an IS NULL branch at the first constant
    argument of the COALESCE.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE side of the comparison; `other` is the opposite side
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it; this mutates the COALESCE that is
    # still part of `expression`, so the copy taken below sees the truncated form
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 826
 827
# Concatenation expression types handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
 829
 830
 831def simplify_concat(expression):
 832    """Reduces all groups that contain string literals by concatenating them."""
 833    if not isinstance(expression, CONCATS) or (
 834        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 835        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 836    ):
 837        return expression
 838
 839    if isinstance(expression, exp.ConcatWs):
 840        sep_expr, *expressions = expression.expressions
 841        sep = sep_expr.name
 842        concat_type = exp.ConcatWs
 843        args = {}
 844    else:
 845        expressions = expression.expressions
 846        sep = ""
 847        concat_type = exp.Concat
 848        args = {
 849            "safe": expression.args.get("safe"),
 850            "coalesce": expression.args.get("coalesce"),
 851        }
 852
 853    new_args = []
 854    for is_string_group, group in itertools.groupby(
 855        expressions or expression.flatten(), lambda e: e.is_string
 856    ):
 857        if is_string_group:
 858            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 859        else:
 860            new_args.extend(group)
 861
 862    if len(new_args) == 1 and new_args[0].is_string:
 863        return new_args[0]
 864
 865    if concat_type is exp.ConcatWs:
 866        new_args = [sep_expr] + new_args
 867    elif isinstance(expression, exp.DPipe):
 868        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 869
 870    return concat_type(expressions=new_args, **args)
 871
 872
 873def simplify_conditionals(expression):
 874    """Simplifies expressions like IF, CASE if their condition is statically known."""
 875    if isinstance(expression, exp.Case):
 876        this = expression.this
 877        for case in expression.args["ifs"]:
 878            cond = case.this
 879            if this:
 880                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 881                cond = cond.replace(this.pop().eq(cond))
 882
 883            if always_true(cond):
 884                return case.args["true"]
 885
 886            if always_false(cond):
 887                case.pop()
 888                if not expression.args["ifs"]:
 889                    return expression.args.get("default") or exp.null()
 890    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 891        if always_true(expression.this):
 892            return expression.args["true"]
 893        if always_false(expression.this):
 894            return expression.args.get("false") or exp.null()
 895
 896    return expression
 897
 898
 899def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 900    """
 901    Reduces a prefix check to either TRUE or FALSE if both the string and the
 902    prefix are statically known.
 903
 904    Example:
 905        >>> from sqlglot import parse_one
 906        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 907        'TRUE'
 908    """
 909    if (
 910        isinstance(expression, exp.StartsWith)
 911        and expression.this.is_string
 912        and expression.expression.is_string
 913    ):
 914        return exp.convert(expression.name.startswith(expression.expression.name))
 915
 916    return expression
 917
 918
# A pair of dates; _datetrunc_range documents its result as a half-open [min, max) range.
DateRange = t.Tuple[datetime.date, datetime.date]
 920
 921
 922def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 923    """
 924    Get the date range for a DATE_TRUNC equality comparison:
 925
 926    Example:
 927        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 928    Returns:
 929        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 930    """
 931    floor = date_floor(date, unit, dialect)
 932
 933    if date != floor:
 934        # This will always be False, except for NULL values.
 935        return None
 936
 937    return floor, floor + interval(unit)
 938
 939
 940def _datetrunc_eq_expression(
 941    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 942) -> exp.Expression:
 943    """Get the logical expression for a date range"""
 944    return exp.and_(
 945        left >= date_literal(drange[0], target_type),
 946        left < date_literal(drange[1], target_type),
 947        copy=False,
 948    )
 949
 950
 951def _datetrunc_eq(
 952    left: exp.Expression,
 953    date: datetime.date,
 954    unit: str,
 955    dialect: Dialect,
 956    target_type: t.Optional[exp.DataType],
 957) -> t.Optional[exp.Expression]:
 958    drange = _datetrunc_range(date, unit, dialect)
 959    if not drange:
 960        return None
 961
 962    return _datetrunc_eq_expression(left, drange, target_type)
 963
 964
 965def _datetrunc_neq(
 966    left: exp.Expression,
 967    date: datetime.date,
 968    unit: str,
 969    dialect: Dialect,
 970    target_type: t.Optional[exp.DataType],
 971) -> t.Optional[exp.Expression]:
 972    drange = _datetrunc_range(date, unit, dialect)
 973    if not drange:
 974        return None
 975
 976    return exp.and_(
 977        left < date_literal(drange[0], target_type),
 978        left >= date_literal(drange[1], target_type),
 979        copy=False,
 980    )
 981
 982
 983DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 984    exp.LT: lambda l, dt, u, d, t: l
 985    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 986    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 987    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 988    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 989    exp.EQ: _datetrunc_eq,
 990    exp.NEQ: _datetrunc_neq,
 991}
 992DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 993DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 994
 995
 996def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 997    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 998
 999
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`.

    Three shapes are handled: a truncation applied to a date literal (folded
    into a literal), a binary comparison between a truncation and a date
    literal, and an IN list of date literals. Anything else, or a case that
    cannot be rewritten, returns `expression` unchanged.
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        operand = expression.this
        trunc_type = extract_type(operand)
        constant = extract_date(operand)
        if constant and expression.unit:
            floored = date_floor(constant, expression.unit.name.lower(), dialect)
            return date_literal(floored, trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        lhs, rhs = expression.left, expression.right

        if not _is_datetrunc_predicate(lhs, rhs):
            return expression

        lhs = t.cast(exp.DateTrunc, lhs)
        trunc_arg = lhs.this
        unit = lhs.unit.name.lower()
        boundary = extract_date(rhs)

        if not boundary:
            return expression

        transform = DATETRUNC_BINARY_COMPARISONS[comparison]
        rewritten = transform(trunc_arg, boundary, unit, dialect, extract_type(rhs))
        # The transforms return None when no rewrite is possible
        return rewritten or expression

    if isinstance(expression, exp.In):
        lhs = expression.this
        candidates = expression.expressions

        if not candidates or not all(_is_datetrunc_predicate(lhs, c) for c in candidates):
            return expression

        lhs = t.cast(exp.DateTrunc, lhs)
        unit = lhs.unit.name.lower()

        ranges = []
        for candidate in candidates:
            boundary = extract_date(candidate)
            if not boundary:
                return expression
            drange = _datetrunc_range(boundary, unit, dialect)
            # Values that are not truncation boundaries can never match; drop them
            if drange:
                ranges.append(drange)

        if not ranges:
            return expression

        ranges = merge_ranges(ranges)
        target_type = extract_type(*candidates)

        range_checks = [_datetrunc_eq_expression(lhs, drange, target_type) for drange in ranges]
        return exp.or_(*range_checks, copy=False)

    return expression
1063
1064
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right. When neither rule
    decides, the operands are ordered by their `gen` strings. Swapping the
    operands also swaps the operator through INVERSE_COMPARISONS.
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if needs_swap:
        swapped_type = INVERSE_COMPARISONS.get(comparison, comparison)
        return swapped_type(this=right, expression=left)

    return expression
1080
1081
# (side, kind) pairs of the joins that remove_where_true may turn into a CROSS join.
# A CROSS join results in an empty table if the right table is empty,
# so only certain types of joins can be simplified to CROSS.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1091
1092
def remove_where_true(expression):
    """
    Drop WHERE clauses whose condition is always true, and turn eligible joins
    with an always-true ON condition into CROSS joins. Mutates `expression`.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        on = join.args.get("on")
        if (
            not always_true(on)
            or join.args.get("using")
            or join.args.get("method")
            or (join.side, join.kind) not in JOINS
        ):
            continue

        on.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1107
1108
def always_true(expression):
    """Return a truthy value for a TRUE boolean or for any literal that is not zero."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
1113
1114
def always_false(expression):
    """Return True if `expression` is FALSE, NULL or a zero literal."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
1117
1118
def is_zero(expression):
    """Return whether `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
1121
1122
def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1125
1126
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1129
1130
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1133
1134
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal, or None if `expression` is not one of the
    supported comparison types.
    """
    # The comparisons are wrapped in lambdas so that only the matching one runs
    comparators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
1149
1150
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a date.

    Datetimes are reduced to their date part, dates are returned as they are,
    and anything else is parsed as an ISO formatted string. Returns None if the
    string cannot be parsed.
    """
    # datetime is a subclass of date, so it has to be tested first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
1160
1161
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a datetime.

    Datetimes are returned as they are, dates become midnight of that day, and
    anything else is parsed as an ISO formatted string. Returns None if the
    string cannot be parsed.
    """
    # datetime is a subclass of date, so it has to be tested first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
1171
1172
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to a Python date or datetime according to the type `to`.

    Args:
        value: the value to convert; a falsy value yields None
        to: the target data type
    Returns:
        A date when `to` is DATE, a datetime when `to` is another temporal
        type, and None for any other type or when the conversion fails.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1181
1182
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract the Python date or datetime described by a cast of a literal.

    Handles CAST nodes and TsOrDsToDate nodes without a format. The operand
    must be a literal or another such cast, which is resolved recursively.
    Returns None for any other shape or when the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: resolve the inner one first, then convert to the outer type
        value = extract_date(operand)
    else:
        return None

    return cast_value(value, to)
1198
1199
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if extract_date can resolve `expression` to a date."""
    extracted = extract_date(expression)
    return extracted is not None
1202
1203
def extract_interval(expression):
    """
    Convert an interval expression into the delta returned by `interval`.

    Returns None if the unit is unsupported, dateutil is not installed, or the
    amount cannot be converted to an int.
    """
    try:
        amount = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        return interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1211
1212
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For a CAST the target type is used, otherwise the expression's own `type`.
    If nothing truthy is found, the last inspected value is returned (None when
    no expressions are given).
    """
    found = None
    for expression in expressions:
        found = expression.to if isinstance(expression, exp.Cast) else expression.type
        if found:
            return found

    return found
1221
1222
def date_literal(date, target_type=None):
    """
    Build a CAST of `date` (as a string literal) to a temporal type.

    `target_type` is used if it is a temporal type. Otherwise the type is
    DATETIME for datetime values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
1232
1233
def interval(unit: str, n: int = 1):
    """
    Return a dateutil `relativedelta` spanning `n` times `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units
        ModuleNotFoundError: if dateutil is not installed
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    conversions = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in conversions:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1255
1256
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` down to the start of the `unit` that contains it.

    Args:
        d: the date or datetime to truncate
        unit: one of "year", "quarter", "month", "week" or "day"
        dialect: provides WEEK_OFFSET, the first day of the week relative to
            Monday (0 is Monday, -1 is Sunday)
    Returns:
        The truncated value, of the same kind (date or datetime) as `d`.
    Raises:
        UnsupportedUnit: if `unit` is not supported
    """
    if isinstance(d, datetime.datetime):
        # Every supported unit is at least a day long, so a truncated datetime
        # always sits at midnight. Without this the time of day would leak into
        # the result and the "floor" would not be a boundary of the unit.
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday and 6 for Sunday. The modulo keeps the
        # distance to the start of the week within 0..6, so the result never
        # lands in a previous week (e.g. a Sunday with a Sunday-based week) or
        # in the future (positive offsets).
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1278
1279
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next boundary of `unit`.

    A value that already is a boundary is returned unchanged.
    """
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1287
1288
def boolean_literal(condition):
    """Return a TRUE expression if `condition` is truthy, otherwise a FALSE one."""
    if condition:
        return exp.true()
    return exp.false()
1291
1292
1293def _flat_simplify(expression, simplifier, root=True):
1294    if root or not expression.same_parent:
1295        operands = []
1296        queue = deque(expression.flatten(unnest=False))
1297        size = len(queue)
1298
1299        while queue:
1300            a = queue.popleft()
1301
1302            for b in queue:
1303                result = simplifier(expression, a, b)
1304
1305                if result and result is not expression:
1306                    queue.remove(b)
1307                    queue.appendleft(result)
1308                    break
1309            else:
1310                operands.append(a)
1311
1312        if len(operands) < size:
1313            return functools.reduce(
1314                lambda a, b: expression.__class__(this=a, expression=b), operands
1315            )
1316    return expression
1317
1318
def gen(expression: t.Any) -> str:
    """Cheap pseudo SQL generator that produces sortable and unique strings.

    Optimization needs to sort and dedupe sql, and running the real generator
    for that is expensive, so this relies on the bare minimum `Gen` class.
    """
    generator = Gen()
    return generator.gen(expression)
1326
1327
class Gen:
    """
    Bare minimum pseudo SQL generator used to build sortable, comparable strings.

    Generation is iterative. `stack` holds the pending work items (expressions,
    lists and plain strings) and `sqls` collects the emitted fragments. Since
    the stack is last-in first-out, every handler pushes its pieces in reverse
    output order.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Dispatch to `<key>_sql` if there is one, then fall back to the
                # generic function form, then to the key followed by the args
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # _args pushes first, so the key ends up on top and is emitted first
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    # Drop the separator pushed after the last item
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted names keep their case, unquoted ones are upper-cased
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." pushed after the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): unlike the other operators this one is not upper-cased.
        # It is left alone because the output is used for ordering, so changing
        # the casing could change how operands get sorted.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." pushed after the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this <op> expression`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op>this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(arg, ...)` built from all arg values of `e`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the args of `node` that are set, as `:name` / value pairs.

        Args:
            node: the expression whose args are pushed
            arg_index: number of leading entries of `node.arg_types` to skip
        Returns:
            True if anything was pushed onto the stack, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date or interval unit is not supported."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, max_depth: Optional[int] = None):
 40def simplify(
 41    expression: exp.Expression,
 42    constant_propagation: bool = False,
 43    dialect: DialectType = None,
 44    max_depth: t.Optional[int] = None,
 45):
 46    """
 47    Rewrite sqlglot AST to simplify expressions.
 48
 49    Example:
 50        >>> import sqlglot
 51        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 52        >>> simplify(expression).sql()
 53        'TRUE'
 54
 55    Args:
 56        expression: expression to simplify
 57        constant_propagation: whether the constant propagation rule should be used
 58        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
 59    Returns:
 60        sqlglot.Expression: simplified expression
 61    """
 62
 63    dialect = Dialect.get_or_raise(dialect)
 64
 65    def _simplify(expression, root=True):
 66        if (
 67            max_depth
 68            and isinstance(expression, exp.Connector)
 69            and not isinstance(expression.parent, exp.Connector)
 70        ):
 71            depth = connector_depth(expression)
 72            if depth > max_depth:
 73                logger.info(
 74                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
 75                )
 76                return expression
 77
 78        if expression.meta.get(FINAL):
 79            return expression
 80
 81        # group by expressions cannot be simplified, for example
 82        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 83        # the projection must exactly match the group by key
 84        group = expression.args.get("group")
 85
 86        if group and hasattr(expression, "selects"):
 87            groups = set(group.expressions)
 88            group.meta[FINAL] = True
 89
 90            for e in expression.selects:
 91                for node in e.walk():
 92                    if node in groups:
 93                        e.meta[FINAL] = True
 94                        break
 95
 96            having = expression.args.get("having")
 97            if having:
 98                for node in having.walk():
 99                    if node in groups:
100                        having.meta[FINAL] = True
101                        break
102
103        # Pre-order transformations
104        node = expression
105        node = rewrite_between(node)
106        node = uniq_sort(node, root)
107        node = absorb_and_eliminate(node, root)
108        node = simplify_concat(node)
109        node = simplify_conditionals(node)
110
111        if constant_propagation:
112            node = propagate_constants(node, root)
113
114        exp.replace_children(node, lambda e: _simplify(e, False))
115
116        # Post-order transformations
117        node = simplify_not(node)
118        node = flatten(node)
119        node = simplify_connectors(node, root)
120        node = remove_complements(node, root)
121        node = simplify_coalesce(node)
122        node.parent = expression.parent
123        node = simplify_literals(node, root)
124        node = simplify_equality(node)
125        node = simplify_parens(node)
126        node = simplify_datetrunc(node, dialect)
127        node = sort_comparison(node)
128        node = simplify_startswith(node)
129
130        if root:
131            expression.replace(node)
132        return node
133
134    expression = while_changing(expression, _simplify)
135    remove_where_true(expression)
136    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • max_depth: Chains of Connectors (AND, OR, etc) exceeding max_depth will be skipped
Returns:

sqlglot.Expression: simplified expression

def connector_depth(expression: sqlglot.expressions.Expression) -> int:
def connector_depth(expression: exp.Expression) -> int:
    """
    Compute how deeply Connectors are nested below (and including) `expression`.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    todo = [(expression, 0)]

    while todo:
        node, level = todo.pop()

        # Anything that is not a Connector ends the chain on this path
        if isinstance(node, exp.Connector):
            level += 1
            deepest = max(level, deepest)
            todo.extend(((node.left, level), (node.right, level)))

    return deepest

Determine the maximum depth of a tree of Connectors.

For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of `exceptions`, the wrapper returns
    its first argument (the expression) unchanged. Other exceptions propagate.
    """

    def decorator(func):
        # Preserve the name, docstring and other metadata of the wrapped function
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, hence the rewrite.
    The result is wrapped in parentheses if the BETWEEN is the operand of a NOT.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negate = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    if negate:
        rewritten = exp.paren(rewritten, copy=False)

    return rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify NOT expressions.

    Applies De Morgan's laws to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    It also replaces a negated comparison with its complement, folds NOT over
    NULL and over constant conditions, and removes double negations.
    """
    if not isinstance(expression, exp.Not):
        return expression

    this = expression.this

    if is_null(this):
        return exp.null()

    if this.__class__ in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[this.__class__]
        return complement(this=this.this, expression=this.expression)

    if isinstance(this, exp.Paren):
        condition = this.unnest()

        # De Morgan: each connector is rebuilt with the other one
        for connector, build_dual in ((exp.And, exp.or_), (exp.Or, exp.and_)):
            if isinstance(condition, connector):
                negated_left = exp.not_(condition.left, copy=False)
                negated_right = exp.not_(condition.right, copy=False)
                return exp.paren(build_dual(negated_left, negated_right, copy=False))

        if is_null(condition):
            return exp.null()

    if always_true(this):
        return exp.false()

    if is_false(this):
        return exp.true()

    if isinstance(this, exp.Not):
        # double negation
        # NOT NOT x -> x
        return this.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
def flatten(expression):
    """
    Unwrap parenthesized operands that are the same kind of connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors whose operands are statically known.

    Non-connector expressions are returned unchanged; connectors are handed to
    `_flat_simplify` together with the pairwise folding rule defined below.
    """

    def _fold_pair(node, lhs, rhs):
        # Returns the folded replacement for `lhs <connector> rhs`, or None when
        # the XOR operands differ and no rule applies.
        if isinstance(node, exp.And):
            if is_false(lhs) or is_false(rhs) or is_zero(lhs) or is_zero(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()

            lhs_true = always_true(lhs)
            rhs_true = always_true(rhs)
            if lhs_true:
                return exp.true() if rhs_true else rhs
            if rhs_true:
                return lhs
            return _simplify_comparison(node, lhs, rhs)

        if isinstance(node, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_null = is_null(lhs)
            rhs_null = is_null(rhs)
            if (
                (lhs_null and rhs_null)
                or (lhs_null and always_false(rhs))
                or (always_false(lhs) and rhs_null)
            ):
                return exp.null()

            if is_false(lhs):
                return rhs
            if is_false(rhs):
                return lhs
            return _simplify_comparison(node, lhs, rhs, or_=True)

        if isinstance(node, exp.Xor) and lhs == rhs:
            return exp.false()

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _fold_pair, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse an AND / OR chain that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their `gen` string. AND / OR operands are deduplicated
    on that key; XOR operands are only sorted, never deduplicated. The connector
    is rebuilt only when the operand order or count actually changes.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        already_sorted = all(prev[0] <= cur[0] for prev, cur in zip(arr, arr[1:]))

        if not already_sorted:
            # Sort on the generated SQL only. Sorting the (sql, expression) tuples
            # directly would fall back to comparing the expression objects whenever
            # two operands share the same SQL key, which is not a meaningful ordering
            # (and fails if the objects do not support `<`). The sort is stable, so
            # operands with equal keys keep their original relative order.
            ordered = sorted(arr, key=lambda pair: pair[0])
            expression = result_func(*(e for _, e in ordered), copy=False)
        elif deduped and len(deduped) < len(flattened):
            # we didn't have to sort but maybe we need to dedup
            expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND / OR chain, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are not removed here: absorbed or eliminated nodes are replaced with
    boolean constants or with the surviving operand, and the (mutated) input
    expression is returned.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: nested operands of this type are the
        # candidates for absorption / elimination.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Register `op` under the pair it would form with its negated side
            # un-negated, remembering the operand that survives elimination.
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated side whose complement is a top-level operand is replaced by
            # the neutral constant of `kind`, leaving only the other side.
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # If some recorded operand set is a strict subset of this operand's
            # parts, the whole operand is replaced by a boolean constant.
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # Both members of a complementary pair are replaced by the shared operand.
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only `column = literal` equalities are used as a source of constants, and
    the column occurrence that defines the constant is itself left untouched,
    as are columns directly under an `IS NULL` check.

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not isinstance(expression, exp.And):
        return expression
    if not (root or not expression.same_parent):
        return expression
    if not sqlglot.optimizer.normalize.normalized(expression, dnf=True):
        return expression

    # Maps a column to (id of the defining column node, literal it equals).
    literal_by_column = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            literal_by_column[lhs] = (id(lhs), rhs)

    if not literal_by_column:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        source_id, literal = literal_by_column.get(column) or (None, None)

        # Skip columns without a known constant and the defining occurrence itself.
        if source_id is None or id(column) == source_id:
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

`SELECT * FROM t WHERE a = b AND b = 5` becomes `SELECT * FROM t WHERE a = 5 AND b = 5`

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
        def wrapped(expression, *args, **kwargs):
            # Call `func` with the given arguments; if it raises one of `exceptions`,
            # fall back to returning the input expression unchanged.
            # NOTE(review): `func` and `exceptions` are free variables defined outside
            # this view (presumably by an enclosing decorator) - confirm there.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binary expressions.

    Binary nodes that are not connectors are delegated to `_flat_simplify`, a
    doubly negated value is unwrapped, and node types listed in
    INVERSE_DATE_OPS are folded with `_simplify_binary` when it yields a result.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        inner = expression.this
        if isinstance(inner, exp.Neg):
            # -(-x) -> x
            return inner.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop a Paren node when the conditions below deem the parentheses redundant.

    Returns the wrapped expression when the parentheses can be removed, and the
    input expression unchanged otherwise (including when it is not a Paren).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Never unwrap a parenthesized SELECT or the operand of a subquery predicate.
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # The parent is neither a condition nor a binary operator.
            not isinstance(parent, (exp.Condition, exp.Binary))
            # Nested parentheses: ((x)) -> (x)
            or isinstance(parent, exp.Paren)
            # The content is not a binary operation, and it is not a NOT / IS
            # sitting under a predicate.
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # A predicate whose parent is not itself a predicate.
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # Same-operator nesting for + and *, and a product under + or -.
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons that involve a COALESCE.

    A COALESCE with no fallback arguments, or whose first argument is a non-null
    constant, is replaced by its first argument (unless its parent is a Hint).
    A comparison between a COALESCE and a constant, where the COALESCE has a
    constant fallback argument, is rewritten into a parenthesized OR of two AND
    branches: one for the case where the leading arguments are not NULL and one
    for the case where they are. Everything else is returned unchanged.

    Note that the COALESCE node of the input is modified in place (its fallback
    arguments are truncated) before the rewritten expression is built.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE operand; the left side takes precedence if both are COALESCE.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the fallback arguments that precede the first constant one.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    if with_separator:
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        if not expression.expressions[0].is_string:
            return expression

        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    # Fall back to the flattened operands when there is no `expressions` list.
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(sep.join(literal.name for literal in run)))
        else:
            merged.extend(run)

    # Everything collapsed into a single string literal.
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if with_separator:
        return exp.ConcatWs(expressions=[sep_expr, *merged])

    if isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return exp.Concat(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed, and the
    CASE collapses to a branch value only when that branch's condition is always
    true and no earlier branch with an undecided condition remains (an earlier
    undecided branch could still match first at runtime). If every branch is
    removed, the default (or NULL) is returned. A standalone IF is replaced by
    its true / false value when its condition is statically known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # True while every branch seen so far has been removed, i.e. the current
        # branch is the first one that can match.
        is_first_reachable = True

        # Iterate over a snapshot: `case.pop()` below removes the branch from the
        # live `ifs` list, and mutating a list while iterating it skips elements.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if is_first_reachable:
                    return case.args["true"]
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                is_first_reachable = False
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH call into TRUE or FALSE when both of its arguments are
    string literals; any other expression is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    prefix = expression.expression
    if expression.this.is_string and prefix.is_string:
        matches = expression.name.startswith(prefix.name)
        return exp.convert(matches)

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>}
def simplify_datetrunc(expression, *args, **kwargs):
        def wrapped(expression, *args, **kwargs):
            # Call `func` with the given arguments; if it raises one of `exceptions`,
            # fall back to returning the input expression unchanged.
            # NOTE(review): `func` and `exceptions` are free variables defined outside
            # this view (presumably by an enclosing decorator) - confirm there.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    A comparison is left alone when a column is already on the left of a
    non-column, or a constant is already on the right of a non-constant. It is
    flipped (using the inverse operator where one is registered) when the
    opposite holds, or when the left operand's `gen` string sorts after the
    right one's.
    """
    comparison_cls = expression.__class__
    if comparison_cls not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (lhs_is_column and not rhs_is_column) or (
        rhs_is_const and not lhs_is_const
    )
    if already_ordered:
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if needs_swap:
        swapped_cls = INVERSE_COMPARISONS.get(comparison_cls, comparison_cls)
        return swapped_cls(this=rhs, expression=lhs)

    return expression
JOINS = {('', 'INNER'), ('RIGHT', 'OUTER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Remove WHERE clauses whose condition is always true and turn eligible joins
    with an always-true ON condition into CROSS joins. Mutates `expression` in
    place and returns nothing.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        if not always_true(on_condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        # Only the (side, kind) combinations listed in JOINS are rewritten.
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is a TRUE boolean or a non-zero literal."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is FALSE, NULL or a zero literal."""
    checks = (is_false, is_null, is_zero)
    return any(check(expression) for check in checks)
def is_zero(expression):
def is_zero(expression):
    """Return True if `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a, b):
def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values
    `a` and `b` and return the matching boolean literal, or None when the
    expression is not one of the supported comparison types.
    """
    # Checked in order; the comparison is only evaluated for the matching entry.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Args:
        value: a datetime (its date part is returned), a date (returned as is),
            or an ISO-8601 formatted string.

    Returns:
        The resulting date, or None if `value` cannot be interpreted as one.
    """
    # datetime must be checked first because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not a valid ISO string; TypeError: not a string at all.
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of
            that day), or an ISO-8601 formatted string.

    Returns:
        The resulting datetime, or None if `value` cannot be interpreted as one.
    """
    # datetime must be checked first because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not a valid ISO string; TypeError: not a string at all.
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Cast `value` to the Python temporal type that corresponds to the data type `to`.

    Args:
        value: the value to cast; falsy values yield None.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal
        type, and None when the value is falsy, the type is not temporal, or
        the value cannot be converted.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract the Python date / datetime denoted by a (possibly nested) cast of a literal.

    Args:
        cast: a Cast node, or a TsOrDsToDate node without a `format` argument.

    Returns:
        The value converted according to the target type of the cast, or None if
        `cast` is not a supported node, its operand is neither a literal nor
        another supported cast, or the conversion fails.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # TsOrDsToDate without a format is treated like a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner one first, then apply this one.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Build the time delta described by an interval expression.

    Returns None when the unit is unsupported, the amount cannot be converted
    to an int, or the module needed by `interval` is not installed.
    """
    try:
        amount = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
def extract_type(*expressions):
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For a Cast the target type (`to`) is used, for anything else its `type`
    attribute. If none is truthy, the value of the last candidate is returned
    (None when no expressions are given).
    """
    found = None
    for candidate in expressions:
        if isinstance(candidate, exp.Cast):
            found = candidate.to
        else:
            found = candidate.type

        if found:
            return found

    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """
    Build a cast of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type; otherwise DATETIME is
    chosen for datetime values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        resolved_type = target_type
    elif isinstance(date, datetime.datetime):
        resolved_type = exp.DataType.Type.DATETIME
    else:
        resolved_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, resolved_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` spanning `n` of the given `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, how many of that keyword make one unit)
    supported_units = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in supported_units:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only `dialect.WEEK_OFFSET` is read, to shift the first day of
            the week relative to Monday.

    Returns:
        The first day of the enclosing year / quarter / month / week, or `d`
        itself for "day".

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The modulo keeps the step in
        # the 0..6 range so the result is always the most recent week start on or
        # before `d`; without it a non-zero offset could step back a whole extra
        # week (offset -1 on a Sunday) or forward in time (positive offsets).
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary; a date that already sits on a
    boundary is returned unchanged.
    """
    floored = date_floor(d, unit, dialect)
    on_boundary = floored == d
    return d if on_boundary else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, otherwise FALSE."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Cheap pseudo SQL generator that produces sortable, dedupable strings.

    Optimization needs to sort and deduplicate SQL. The real generator is
    expensive to call, so a bare minimum generator is used here instead.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """
    Bare minimum pseudo SQL generator.

    Generation is iterative rather than recursive: pending items (expressions,
    lists and plain tokens) live on a LIFO stack, so every handler pushes its
    pieces in the reverse of the order in which they should appear in the output.
    """

    def __init__(self):
        # Pending work items; the last element is processed next.
        self.stack = []
        # Output fragments, in emission order.
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Dispatch on the node key: dedicated handler, generic function
                # call syntax, or "KEY <args>" as the fallback.
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the non-None items followed by a "," each, then drop the
                # separator that ends up on top of the stack (before the first item).
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Only pop when a separator was actually pushed; a list made up
                # solely of None values must not remove an unrelated stack entry.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            # Quoted identifiers keep their case, unquoted ones are upper-cased.
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Push "part." for every part, then drop the trailing "." left on top.
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case " Like " differs from the other operators; it is
        # kept byte-for-byte because the output is used as a sort/dedup key.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        # Push "part." for every part, then drop the trailing "." left on top.
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `NAME(<all arg values, comma separated>)`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Queue the set args of `node` as `:key value` pairs, skipping the first
        `arg_index` declared arg types.

        Returns:
            True if at least one arg was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
def gen(self, expression: exp.Expression) -> str:
    """
    Render `expression` to a string by draining an explicit work stack.

    The stack holds expressions, lists and plain values. Handlers push their pieces
    in reverse of the desired output order because items are popped
    last-in-first-out.

    Args:
        expression: the root node to render.

    Returns:
        The concatenation of every rendered piece.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                # No dedicated handler: emit the upper-cased key followed by any args.
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Only drop the extra separator if one was actually pushed. A list that
            # holds nothing but None pushes nothing, so popping here would discard
            # an unrelated entry or fail on an empty stack.
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
def add_sql(self, e: exp.Add) -> None:
    """Queue an addition node, joining its operands with `` + ``."""
    operator = " + "
    self._binary(e, operator)
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
def alias_sql(self, e: exp.Alias) -> None:
    """
    Queue an alias node as aliased expression, `` AS ``, alias.

    The alias is pushed first because the stack is popped last-in-first-out.
    """
    node_args = e.args
    alias = node_args.get("alias")
    aliased = node_args.get("this")
    self.stack.extend([alias, " AS ", aliased])
def and_sql(self, e: sqlglot.expressions.And) -> None:
def and_sql(self, e: exp.And) -> None:
    """Queue a conjunction node, joining its operands with `` AND ``."""
    operator = " AND "
    self._binary(e, operator)
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Queue an anonymous function call as ``NAME(`` + its expressions + ``)``.

    A string name is upper-cased. An identifier name is double-quoted when it is
    quoted and upper-cased otherwise.

    Raises:
        ValueError: if `e.this` is neither a str nor an Identifier.
    """
    this = e.this

    if isinstance(this, str):
        name = this.upper()
    elif isinstance(this, exp.Identifier):
        raw_name = this.this
        if this.quoted:
            name = f'"{raw_name}"'
        else:
            name = raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    # Reverse of the output order: the stack is popped last-in-first-out.
    self.stack.extend([")", e.expressions, "(", name])
def between_sql(self, e: sqlglot.expressions.Between) -> None:
def between_sql(self, e: exp.Between) -> None:
    """
    Queue a range check as subject, `` BETWEEN ``, low, `` AND ``, high.

    Pieces are pushed in reverse of the output order because the stack is popped
    last-in-first-out.
    """
    node_args = e.args
    high = node_args.get("high")
    low = node_args.get("low")
    self.stack.extend([high, " AND ", low, " BETWEEN ", e.this])
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
def boolean_sql(self, e: exp.Boolean) -> None:
    """Queue ``TRUE`` when `e.this` is truthy and ``FALSE`` otherwise."""
    if e.this:
        keyword = "TRUE"
    else:
        keyword = "FALSE"
    self.stack.append(keyword)
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
def bracket_sql(self, e: exp.Bracket) -> None:
    """
    Queue a bracket access as subject, ``[``, its expressions, ``]``.

    Pieces are pushed in reverse of the output order because the stack is popped
    last-in-first-out.
    """
    subscripts = e.expressions
    self.stack.extend(["]", subscripts, "[", e.this])
def column_sql(self, e: sqlglot.expressions.Column) -> None:
def column_sql(self, e: exp.Column) -> None:
    """Queue the column's parts so that they are emitted joined by ``.``."""
    pending = self.stack
    for part in reversed(e.parts):
        pending.append(part)
        pending.append(".")
    # Remove the separator pushed after the first part.
    pending.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
def datatype_sql(self, e: exp.DataType) -> None:
    """Queue the type's name followed by a space, then its args from position 1 onwards."""
    # The args are emitted after the name, so they go on the stack first.
    self._args(e, 1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
def div_sql(self, e: exp.Div) -> None:
    """Queue a division node, joining its operands with `` / ``."""
    operator = " / "
    self._binary(e, operator)
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
def dot_sql(self, e: exp.Dot) -> None:
    """Queue a dot access node, joining its operands with ``.``."""
    operator = "."
    self._binary(e, operator)
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
def eq_sql(self, e: exp.EQ) -> None:
    """Queue an equality node, joining its operands with `` = ``."""
    operator = " = "
    self._binary(e, operator)
def from_sql(self, e: sqlglot.expressions.From) -> None:
def from_sql(self, e: exp.From) -> None:
    """Queue ``FROM `` followed by the node's `this`."""
    source = e.this
    self.stack.extend([source, "FROM "])
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
def gt_sql(self, e: exp.GT) -> None:
    """Queue a greater-than node, joining its operands with `` > ``."""
    operator = " > "
    self._binary(e, operator)
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
def gte_sql(self, e: exp.GTE) -> None:
    """Queue a greater-than-or-equal node, joining its operands with `` >= ``."""
    operator = " >= "
    self._binary(e, operator)
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
def identifier_sql(self, e: exp.Identifier) -> None:
    """Queue the identifier's name, wrapped in double quotes when it is quoted."""
    if e.quoted:
        rendered = f'"{e.this}"'
    else:
        rendered = e.this
    self.stack.append(rendered)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
def ilike_sql(self, e: exp.ILike) -> None:
    """Queue a case-insensitive match node, joining its operands with `` ILIKE ``."""
    operator = " ILIKE "
    self._binary(e, operator)
def in_sql(self, e: sqlglot.expressions.In) -> None:
def in_sql(self, e: exp.In) -> None:
    """
    Queue a membership test as subject, `` IN ``, ``(``, remaining args, ``)``.

    Pieces are pushed in reverse of the output order because the stack is popped
    last-in-first-out.
    """
    self.stack.append(")")
    # Everything after the first arg goes between the parentheses.
    self._args(e, 1)
    for piece in ("(", " IN ", e.this):
        self.stack.append(piece)
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Queue an integer division node, joining its operands with `` DIV ``."""
    operator = " DIV "
    self._binary(e, operator)
def is_sql(self, e: sqlglot.expressions.Is) -> None:
def is_sql(self, e: exp.Is) -> None:
    """Queue an IS comparison node, joining its operands with `` IS ``."""
    operator = " IS "
    self._binary(e, operator)
def like_sql(self, e: sqlglot.expressions.Like) -> None:
def like_sql(self, e: exp.Like) -> None:
    """Queue a pattern match node, joining its operands with `` LIKE ``."""
    # Upper-cased to match every other keyword this generator emits (" ILIKE ", " IS ", ...).
    self._binary(e, " LIKE ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
def literal_sql(self, e: exp.Literal) -> None:
    """Queue the literal's value, wrapped in single quotes when it is a string literal."""
    if e.is_string:
        rendered = f"'{e.this}'"
    else:
        rendered = e.this
    self.stack.append(rendered)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
def lt_sql(self, e: exp.LT) -> None:
    """Queue a less-than node, joining its operands with `` < ``."""
    operator = " < "
    self._binary(e, operator)
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
def lte_sql(self, e: exp.LTE) -> None:
    """Queue a less-than-or-equal node, joining its operands with `` <= ``."""
    operator = " <= "
    self._binary(e, operator)
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
def mod_sql(self, e: exp.Mod) -> None:
    """Queue a modulo node, joining its operands with `` % ``."""
    operator = " % "
    self._binary(e, operator)
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
def mul_sql(self, e: exp.Mul) -> None:
    """Queue a multiplication node, joining its operands with `` * ``."""
    operator = " * "
    self._binary(e, operator)
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
def neg_sql(self, e: exp.Neg) -> None:
    """Queue a negation node, prefixing its operand with ``-``."""
    operator = "-"
    self._unary(e, operator)
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
def neq_sql(self, e: exp.NEQ) -> None:
    """Queue an inequality node, joining its operands with `` <> ``."""
    operator = " <> "
    self._binary(e, operator)
def not_sql(self, e: sqlglot.expressions.Not) -> None:
def not_sql(self, e: exp.Not) -> None:
    """Queue a logical negation node, prefixing its operand with ``NOT ``."""
    operator = "NOT "
    self._unary(e, operator)
def null_sql(self, e: sqlglot.expressions.Null) -> None:
def null_sql(self, e: exp.Null) -> None:
    """Queue the ``NULL`` keyword; the node itself carries nothing to render."""
    keyword = "NULL"
    self.stack.append(keyword)
def or_sql(self, e: sqlglot.expressions.Or) -> None:
def or_sql(self, e: exp.Or) -> None:
    """Queue a disjunction node, joining its operands with `` OR ``."""
    operator = " OR "
    self._binary(e, operator)
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
def paren_sql(self, e: exp.Paren) -> None:
    """
    Queue the wrapped expression between ``(`` and ``)``.

    The closing parenthesis is pushed first because the stack is popped
    last-in-first-out.
    """
    inner = e.this
    self.stack.extend([")", inner, "("])
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
def sub_sql(self, e: exp.Sub) -> None:
    """Queue a subtraction node, joining its operands with `` - ``."""
    operator = " - "
    self._binary(e, operator)
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
def subquery_sql(self, e: exp.Subquery) -> None:
    """
    Queue a subquery as ``(`` + inner expression + ``)``, then its alias and remaining args.

    Pieces are pushed in reverse of the output order because the stack is popped
    last-in-first-out.
    """
    # Args from position 2 onwards are emitted last, so they go on the stack first.
    self._args(e, 2)

    subquery_alias = e.args.get("alias")
    wrapped = [")", e.this, "("]
    self.stack.extend([subquery_alias, *wrapped] if subquery_alias else wrapped)
def table_sql(self, e: sqlglot.expressions.Table) -> None:
def table_sql(self, e: exp.Table) -> None:
    """
    Queue a table reference: its dot-separated parts, then its alias and remaining args.

    Pieces are pushed in reverse of the output order because the stack is popped
    last-in-first-out.
    """
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)

    parts = e.parts
    if parts:
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the separator pushed after the first part. This must be skipped when
        # there are no parts, otherwise an unrelated stack entry would be removed.
        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Queue `` AS `` followed by the alias name and, if present, its parenthesized columns."""
    column_names = e.columns
    if column_names:
        # Pushed first so that the column list is emitted after the alias name.
        self.stack.extend([")", column_names, "("])

    self.stack.extend([e.this, " AS "])
def var_sql(self, e: sqlglot.expressions.Var) -> None:
def var_sql(self, e: exp.Var) -> None:
    """Queue the variable's `this` value unchanged."""
    value = e.this
    self.stack.append(value)