Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from decimal import Decimal
  10from functools import reduce
  11
  12import sqlglot
  13from sqlglot import Dialect, exp
  14from sqlglot.helper import first, merge_ranges, while_changing
  15from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  16
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Signature of a DATE_TRUNC rewrite helper; the Optional return presumably means
    # "no rewrite possible" when None — confirm against the helpers that use it.
    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
    ]

logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified
# (stored as `node.meta[FINAL]`; `simplify` returns such nodes untouched).
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
# (used by `_simplify_integer_cast` to decide when a CAST of an integer literal can be dropped).
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  34
  35
  36class UnsupportedUnit(Exception):
  37    pass
  38
  39
  40def simplify(
  41    expression: exp.Expression,
  42    constant_propagation: bool = False,
  43    dialect: DialectType = None,
  44    max_depth: t.Optional[int] = None,
  45):
  46    """
  47    Rewrite sqlglot AST to simplify expressions.
  48
  49    Example:
  50        >>> import sqlglot
  51        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  52        >>> simplify(expression).sql()
  53        'TRUE'
  54
  55    Args:
  56        expression: expression to simplify
  57        constant_propagation: whether the constant propagation rule should be used
  58        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
  59    Returns:
  60        sqlglot.Expression: simplified expression
  61    """
  62
  63    dialect = Dialect.get_or_raise(dialect)
  64
  65    def _simplify(expression, root=True):
  66        if (
  67            max_depth
  68            and isinstance(expression, exp.Connector)
  69            and not isinstance(expression.parent, exp.Connector)
  70        ):
  71            depth = connector_depth(expression)
  72            if depth > max_depth:
  73                logger.info(
  74                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
  75                )
  76                return expression
  77
  78        if expression.meta.get(FINAL):
  79            return expression
  80
  81        # group by expressions cannot be simplified, for example
  82        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  83        # the projection must exactly match the group by key
  84        group = expression.args.get("group")
  85
  86        if group and hasattr(expression, "selects"):
  87            groups = set(group.expressions)
  88            group.meta[FINAL] = True
  89
  90            for e in expression.selects:
  91                for node in e.walk():
  92                    if node in groups:
  93                        e.meta[FINAL] = True
  94                        break
  95
  96            having = expression.args.get("having")
  97            if having:
  98                for node in having.walk():
  99                    if node in groups:
 100                        having.meta[FINAL] = True
 101                        break
 102
 103        # Pre-order transformations
 104        node = expression
 105        node = rewrite_between(node)
 106        node = uniq_sort(node, root)
 107        node = absorb_and_eliminate(node, root)
 108        node = simplify_concat(node)
 109        node = simplify_conditionals(node)
 110
 111        if constant_propagation:
 112            node = propagate_constants(node, root)
 113
 114        exp.replace_children(node, lambda e: _simplify(e, False))
 115
 116        # Post-order transformations
 117        node = simplify_not(node)
 118        node = flatten(node)
 119        node = simplify_connectors(node, root)
 120        node = remove_complements(node, root)
 121        node = simplify_coalesce(node)
 122        node.parent = expression.parent
 123        node = simplify_literals(node, root)
 124        node = simplify_equality(node)
 125        node = simplify_parens(node)
 126        node = simplify_datetrunc(node, dialect)
 127        node = sort_comparison(node)
 128        node = simplify_startswith(node)
 129
 130        if root:
 131            expression.replace(node)
 132        return node
 133
 134    expression = while_changing(expression, _simplify)
 135    remove_where_true(expression)
 136    return expression
 137
 138
 139def connector_depth(expression: exp.Expression) -> int:
 140    """
 141    Determine the maximum depth of a tree of Connectors.
 142
 143    For example:
 144        >>> from sqlglot import parse_one
 145        >>> connector_depth(parse_one("a AND b AND c AND d"))
 146        3
 147    """
 148    stack = deque([(expression, 0)])
 149    max_depth = 0
 150
 151    while stack:
 152        expression, depth = stack.pop()
 153
 154        if not isinstance(expression, exp.Connector):
 155            continue
 156
 157        depth += 1
 158        max_depth = max(depth, max_depth)
 159
 160        stack.append((expression.left, depth))
 161        stack.append((expression.right, depth))
 162
 163    return max_depth
 164
 165
 166def catch(*exceptions):
 167    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 168
 169    def decorator(func):
 170        def wrapped(expression, *args, **kwargs):
 171            try:
 172                return func(expression, *args, **kwargs)
 173            except exceptions:
 174                return expression
 175
 176        return wrapped
 177
 178    return decorator
 179
 180
 181def rewrite_between(expression: exp.Expression) -> exp.Expression:
 182    """Rewrite x between y and z to x >= y AND x <= z.
 183
 184    This is done because comparison simplification is only done on lt/lte/gt/gte.
 185    """
 186    if isinstance(expression, exp.Between):
 187        negate = isinstance(expression.parent, exp.Not)
 188
 189        expression = exp.and_(
 190            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 191            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 192            copy=False,
 193        )
 194
 195        if negate:
 196            expression = exp.paren(expression, copy=False)
 197
 198    return expression
 199
 200
# Each comparison mapped to its logical negation, e.g. NOT (a < b) -> a >= b and
# NOT (a = b) -> a <> b. Used by simplify_not.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 209
 210
 211def simplify_not(expression):
 212    """
 213    Demorgan's Law
 214    NOT (x OR y) -> NOT x AND NOT y
 215    NOT (x AND y) -> NOT x OR NOT y
 216    """
 217    if isinstance(expression, exp.Not):
 218        this = expression.this
 219        if is_null(this):
 220            return exp.null()
 221        if this.__class__ in COMPLEMENT_COMPARISONS:
 222            return COMPLEMENT_COMPARISONS[this.__class__](
 223                this=this.this, expression=this.expression
 224            )
 225        if isinstance(this, exp.Paren):
 226            condition = this.unnest()
 227            if isinstance(condition, exp.And):
 228                return exp.paren(
 229                    exp.or_(
 230                        exp.not_(condition.left, copy=False),
 231                        exp.not_(condition.right, copy=False),
 232                        copy=False,
 233                    )
 234                )
 235            if isinstance(condition, exp.Or):
 236                return exp.paren(
 237                    exp.and_(
 238                        exp.not_(condition.left, copy=False),
 239                        exp.not_(condition.right, copy=False),
 240                        copy=False,
 241                    )
 242                )
 243            if is_null(condition):
 244                return exp.null()
 245        if always_true(this):
 246            return exp.false()
 247        if is_false(this):
 248            return exp.true()
 249        if isinstance(this, exp.Not):
 250            # double negation
 251            # NOT NOT x -> x
 252            return this.this
 253    return expression
 254
 255
 256def flatten(expression):
 257    """
 258    A AND (B AND C) -> A AND B AND C
 259    A OR (B OR C) -> A OR B OR C
 260    """
 261    if isinstance(expression, exp.Connector):
 262        for node in expression.args.values():
 263            child = node.unnest()
 264            if isinstance(child, expression.__class__):
 265                node.replace(child)
 266    return expression
 267
 268
 269def simplify_connectors(expression, root=True):
 270    def _simplify_connectors(expression, left, right):
 271        if left == right:
 272            if isinstance(expression, exp.Xor):
 273                return exp.false()
 274            return left
 275        if isinstance(expression, exp.And):
 276            if is_false(left) or is_false(right):
 277                return exp.false()
 278            if is_null(left) or is_null(right):
 279                return exp.null()
 280            if always_true(left) and always_true(right):
 281                return exp.true()
 282            if always_true(left):
 283                return right
 284            if always_true(right):
 285                return left
 286            return _simplify_comparison(expression, left, right)
 287        elif isinstance(expression, exp.Or):
 288            if always_true(left) or always_true(right):
 289                return exp.true()
 290            if is_false(left) and is_false(right):
 291                return exp.false()
 292            if (
 293                (is_null(left) and is_null(right))
 294                or (is_null(left) and is_false(right))
 295                or (is_false(left) and is_null(right))
 296            ):
 297                return exp.null()
 298            if is_false(left):
 299                return right
 300            if is_false(right):
 301                return left
 302            return _simplify_comparison(expression, left, right, or_=True)
 303
 304    if isinstance(expression, exp.Connector):
 305        return _flat_simplify(expression, _simplify_connectors, root)
 306    return expression
 307
 308
# Ordering comparisons grouped by direction.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison node types handled by the comparison, equality and COALESCE rules.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Each ordering comparison mapped to the operator obtained by swapping its operands
# (a < b  <=>  b > a).
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Nondeterministic functions: operands containing them are never treated as the shared
# "column" in _simplify_comparison.
NONDETERMINISTIC = (exp.Rand, exp.Randn)
AND_OR = (exp.And, exp.Or)
 329
 330
def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons, joined by AND (or by OR when `or_` is True), that compare the
    same non-constant, deterministic operand against constants.

    Examples (AND): x < 5 AND x < 3 -> x < 3; x = 1 AND x > 2 -> FALSE.
    Example (OR):   x < 5 OR x < 3 -> x < 5.

    Constants are compared as floats (numbers), as raw text (strings), or as datetimes
    (date literals, via extract_date / cast_as_datetime).

    Args:
        expression: the connector node being simplified.
        left: the left comparison operand of the connector.
        right: the right comparison operand of the connector.
        or_: True when the connector is OR.

    Returns:
        The surviving comparison or FALSE, or None when nothing can be merged. The
        unchanged `expression` is returned when every argument of one comparison is a
        shared non-constant (no constant side left to compare).

    NOTE(review): the column's position is ignored, so `x < 5` and `5 < x` are treated
    alike. This is only sound if comparisons are already normalized to
    `column OP constant` (presumably by sort_comparison) — TODO confirm.
    NOTE(review): on a tie (equal constants, e.g. `x <= 5 AND x < 5`) the LEFT comparison
    is kept for AND. That is only correct when the strict operator comes first, which the
    string ordering done by uniq_sort appears to guarantee.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        # Each comparison is assumed to have exactly two args (this, expression).
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        # Shared operands that could vary per row: not constants and not RAND()-like.
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining (constant) side of each comparison.
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try both orders so each asymmetric case below only needs to be written once.
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    # AND keeps the tighter upper bound, OR keeps the looser one.
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    # AND keeps the tighter lower bound, OR keeps the looser one.
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Disjoint ranges / contradictory equalities under AND are FALSE.
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # x = v AND <range>: either contradictory or the equality suffices.
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 391
 392
 393def remove_complements(expression, root=True):
 394    """
 395    Removing complements.
 396
 397    A AND NOT A -> FALSE
 398    A OR NOT A -> TRUE
 399    """
 400    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 401        ops = set(expression.flatten())
 402        for op in ops:
 403            if isinstance(op, exp.Not) and op.this in ops:
 404                return exp.false() if isinstance(expression, exp.And) else exp.true()
 405
 406    return expression
 407
 408
 409def uniq_sort(expression, root=True):
 410    """
 411    Uniq and sort a connector.
 412
 413    C AND A AND B AND B -> A AND B AND C
 414    """
 415    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 416        flattened = tuple(expression.flatten())
 417
 418        if isinstance(expression, exp.Xor):
 419            result_func = exp.xor
 420            # Do not deduplicate XOR as A XOR A != A if A == True
 421            deduped = None
 422            arr = tuple((gen(e), e) for e in flattened)
 423        else:
 424            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 425            deduped = {gen(e): e for e in flattened}
 426            arr = tuple(deduped.items())
 427
 428        # check if the operands are already sorted, if not sort them
 429        # A AND C AND B -> A AND B AND C
 430        for i, (sql, e) in enumerate(arr[1:]):
 431            if sql < arr[i][0]:
 432                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 433                break
 434        else:
 435            # we didn't have to sort but maybe we need to dedup
 436            if deduped and len(deduped) < len(flattened):
 437                expression = result_func(*deduped.values(), copy=False)
 438
 439    return expression
 440
 441
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place: absorbed operands become TRUE/FALSE and eliminated
    pairs are replaced by their shared operand. The post-order rules run afterwards by
    `simplify` (e.g. simplify_connectors) then fold the resulting constants/duplicates.
    Runs only at the root or when `expression.same_parent` is falsy, since the whole
    flattened chain is inspected at once.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # The connector type expected one level down: OR inside an AND chain, and vice versa.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        # Keyed by the operand pair with the negation stripped; each value holds the
        # `X kind NOT Y`-style op together with its non-negated operand.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated operand whose positive form is a sibling becomes neutral
            # (TRUE inside an AND, FALSE inside an OR).
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # If a sibling's operand set is a strict subset of this op's, this op is
            # redundant and becomes the identity of the outer connector.
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # (A kind B) next to (A kind NOT B): both collapse to the shared operand A.
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
 512
 513
 514def propagate_constants(expression, root=True):
 515    """
 516    Propagate constants for conjunctions in DNF:
 517
 518    SELECT * FROM t WHERE a = b AND b = 5 becomes
 519    SELECT * FROM t WHERE a = 5 AND b = 5
 520
 521    Reference: https://www.sqlite.org/optoverview.html
 522    """
 523
 524    if (
 525        isinstance(expression, exp.And)
 526        and (root or not expression.same_parent)
 527        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 528    ):
 529        constant_mapping = {}
 530        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 531            if isinstance(expr, exp.EQ):
 532                l, r = expr.left, expr.right
 533
 534                # TODO: create a helper that can be used to detect nested literal expressions such
 535                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 536                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 537                    constant_mapping[l] = (id(l), r)
 538
 539        if constant_mapping:
 540            for column in find_all_in_scope(expression, exp.Column):
 541                parent = column.parent
 542                column_id, constant = constant_mapping.get(column) or (None, None)
 543                if (
 544                    column_id is not None
 545                    and id(column) != column_id
 546                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 547                ):
 548                    column.replace(constant.copy())
 549
 550    return expression
 551
 552
# Date arithmetic nodes mapped to the binary operator that undoes them when the interval
# is moved to the other side of a comparison (see simplify_equality).
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every operation simplify_equality can move across a comparison, mapped to its inverse.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 565
 566
 567def _is_number(expression: exp.Expression) -> bool:
 568    return expression.is_number
 569
 570
 571def _is_interval(expression: exp.Expression) -> bool:
 572    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 573
 574
 575@catch(ModuleNotFoundError, UnsupportedUnit)
 576def simplify_equality(expression: exp.Expression) -> exp.Expression:
 577    """
 578    Use the subtraction and addition properties of equality to simplify expressions:
 579
 580        x + 1 = 3 becomes x = 2
 581
 582    There are two binary operations in the above expression: + and =
 583    Here's how we reference all the operands in the code below:
 584
 585          l     r
 586        x + 1 = 3
 587        a   b
 588    """
 589    if isinstance(expression, COMPARISONS):
 590        l, r = expression.left, expression.right
 591
 592        if l.__class__ not in INVERSE_OPS:
 593            return expression
 594
 595        if r.is_number:
 596            a_predicate = _is_number
 597            b_predicate = _is_number
 598        elif _is_date_literal(r):
 599            a_predicate = _is_date_literal
 600            b_predicate = _is_interval
 601        else:
 602            return expression
 603
 604        if l.__class__ in INVERSE_DATE_OPS:
 605            l = t.cast(exp.IntervalOp, l)
 606            a = l.this
 607            b = l.interval()
 608        else:
 609            l = t.cast(exp.Binary, l)
 610            a, b = l.left, l.right
 611
 612        if not a_predicate(a) and b_predicate(b):
 613            pass
 614        elif not a_predicate(b) and b_predicate(a):
 615            a, b = b, a
 616        else:
 617            return expression
 618
 619        return expression.__class__(
 620            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 621        )
 622    return expression
 623
 624
 625def simplify_literals(expression, root=True):
 626    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 627        return _flat_simplify(expression, _simplify_binary, root)
 628
 629    if isinstance(expression, exp.Neg):
 630        this = expression.this
 631        if this.is_number:
 632            value = this.name
 633            if value[0] == "-":
 634                return exp.Literal.number(value[1:])
 635            return exp.Literal.number(f"-{value}")
 636
 637    if type(expression) in INVERSE_DATE_OPS:
 638        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 639
 640    return expression
 641
 642
# Node types that _simplify_binary never folds (it returns None for them), so a NULL
# operand is not propagated into a NULL result for these.
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 644
 645
 646def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 647    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 648        this = _simplify_integer_cast(expr.this)
 649    else:
 650        this = expr.this
 651
 652    if isinstance(expr, exp.Cast) and this.is_int:
 653        num = int(this.name)
 654
 655        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 656        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 657        # engine-dependent
 658        if (
 659            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 660        ) or (
 661            UTINYINT_MIN <= num <= UTINYINT_MAX
 662            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 663        ):
 664            return this
 665
 666    return expr
 667
 668
 669def _simplify_binary(expression, a, b):
 670    if isinstance(expression, COMPARISONS):
 671        a = _simplify_integer_cast(a)
 672        b = _simplify_integer_cast(b)
 673
 674    if isinstance(expression, exp.Is):
 675        if isinstance(b, exp.Not):
 676            c = b.this
 677            not_ = True
 678        else:
 679            c = b
 680            not_ = False
 681
 682        if is_null(c):
 683            if isinstance(a, exp.Literal):
 684                return exp.true() if not_ else exp.false()
 685            if is_null(a):
 686                return exp.false() if not_ else exp.true()
 687    elif isinstance(expression, NULL_OK):
 688        return None
 689    elif is_null(a) or is_null(b):
 690        return exp.null()
 691
 692    if a.is_number and b.is_number:
 693        num_a = int(a.name) if a.is_int else Decimal(a.name)
 694        num_b = int(b.name) if b.is_int else Decimal(b.name)
 695
 696        if isinstance(expression, exp.Add):
 697            return exp.Literal.number(num_a + num_b)
 698        if isinstance(expression, exp.Mul):
 699            return exp.Literal.number(num_a * num_b)
 700
 701        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 702        if isinstance(expression, exp.Sub):
 703            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 704        if isinstance(expression, exp.Div):
 705            # engines have differing int div behavior so intdiv is not safe
 706            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 707                return None
 708            return exp.Literal.number(num_a / num_b)
 709
 710        boolean = eval_boolean(expression, num_a, num_b)
 711
 712        if boolean:
 713            return boolean
 714    elif a.is_string and b.is_string:
 715        boolean = eval_boolean(expression, a.this, b.this)
 716
 717        if boolean:
 718            return boolean
 719    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 720        date, b = extract_date(a), extract_interval(b)
 721        if date and b:
 722            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 723                return date_literal(date + b, extract_type(a))
 724            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 725                return date_literal(date - b, extract_type(a))
 726    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 727        a, date = extract_interval(a), extract_date(b)
 728        # you cannot subtract a date from an interval
 729        if a and b and isinstance(expression, exp.Add):
 730            return date_literal(a + date, extract_type(b))
 731    elif _is_date_literal(a) and _is_date_literal(b):
 732        if isinstance(expression, exp.Predicate):
 733            a, b = extract_date(a), extract_date(b)
 734            boolean = eval_boolean(expression, a, b)
 735            if boolean:
 736                return boolean
 737
 738    return None
 739
 740
def simplify_parens(expression):
    """
    Drop parentheses that cannot change how `expression` is evaluated.

    Parentheses around a SELECT, or directly under a subquery predicate (e.g. IN/EXISTS
    style nodes), are always kept. Otherwise the Paren is replaced by its content when
    any of the following holds:
    - the parent is not a Condition/Binary (e.g. a projection), or is itself a Paren;
    - the content is not a Binary (except NOT/IS inside a predicate parent);
    - the content is a predicate and the parent is not;
    - (a + b) inside +, (a * b) inside *, or (a * b) inside + / -, where precedence or
      associativity makes the parentheses redundant.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            not isinstance(parent, (exp.Condition, exp.Binary))
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
 767
 768
 769def _is_nonnull_constant(expression: exp.Expression) -> bool:
 770    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 771
 772
 773def _is_constant(expression: exp.Expression) -> bool:
 774    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 775
 776
 777def simplify_coalesce(expression):
 778    # COALESCE(x) -> x
 779    if (
 780        isinstance(expression, exp.Coalesce)
 781        and (not expression.expressions or _is_nonnull_constant(expression.this))
 782        # COALESCE is also used as a Spark partitioning hint
 783        and not isinstance(expression.parent, exp.Hint)
 784    ):
 785        return expression.this
 786
 787    if not isinstance(expression, COMPARISONS):
 788        return expression
 789
 790    if isinstance(expression.left, exp.Coalesce):
 791        coalesce = expression.left
 792        other = expression.right
 793    elif isinstance(expression.right, exp.Coalesce):
 794        coalesce = expression.right
 795        other = expression.left
 796    else:
 797        return expression
 798
 799    # This transformation is valid for non-constants,
 800    # but it really only does anything if they are both constants.
 801    if not _is_constant(other):
 802        return expression
 803
 804    # Find the first constant arg
 805    for arg_index, arg in enumerate(coalesce.expressions):
 806        if _is_constant(arg):
 807            break
 808    else:
 809        return expression
 810
 811    coalesce.set("expressions", coalesce.expressions[:arg_index])
 812
 813    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 814    # since we already remove COALESCE at the top of this function.
 815    coalesce = coalesce if coalesce.expressions else coalesce.this
 816
 817    # This expression is more complex than when we started, but it will get simplified further
 818    return exp.paren(
 819        exp.or_(
 820            exp.and_(
 821                coalesce.is_(exp.null()).not_(copy=False),
 822                expression.copy(),
 823                copy=False,
 824            ),
 825            exp.and_(
 826                coalesce.is_(exp.null()),
 827                type(expression)(this=arg.copy(), expression=other.copy()),
 828                copy=False,
 829            ),
 830            copy=False,
 831        )
 832    )
 833
 834
# Node types handled by simplify_concat: CONCAT-style function calls and the `||` operator.
CONCATS = (exp.Concat, exp.DPipe)
 836
 837
 838def simplify_concat(expression):
 839    """Reduces all groups that contain string literals by concatenating them."""
 840    if not isinstance(expression, CONCATS) or (
 841        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 842        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 843    ):
 844        return expression
 845
 846    if isinstance(expression, exp.ConcatWs):
 847        sep_expr, *expressions = expression.expressions
 848        sep = sep_expr.name
 849        concat_type = exp.ConcatWs
 850        args = {}
 851    else:
 852        expressions = expression.expressions
 853        sep = ""
 854        concat_type = exp.Concat
 855        args = {
 856            "safe": expression.args.get("safe"),
 857            "coalesce": expression.args.get("coalesce"),
 858        }
 859
 860    new_args = []
 861    for is_string_group, group in itertools.groupby(
 862        expressions or expression.flatten(), lambda e: e.is_string
 863    ):
 864        if is_string_group:
 865            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 866        else:
 867            new_args.extend(group)
 868
 869    if len(new_args) == 1 and new_args[0].is_string:
 870        return new_args[0]
 871
 872    if concat_type is exp.ConcatWs:
 873        new_args = [sep_expr] + new_args
 874    elif isinstance(expression, exp.DPipe):
 875        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 876
 877    return concat_type(expressions=new_args, **args)
 878
 879
 880def simplify_conditionals(expression):
 881    """Simplifies expressions like IF, CASE if their condition is statically known."""
 882    if isinstance(expression, exp.Case):
 883        this = expression.this
 884        for case in expression.args["ifs"]:
 885            cond = case.this
 886            if this:
 887                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 888                cond = cond.replace(this.pop().eq(cond))
 889
 890            if always_true(cond):
 891                return case.args["true"]
 892
 893            if always_false(cond):
 894                case.pop()
 895                if not expression.args["ifs"]:
 896                    return expression.args.get("default") or exp.null()
 897    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 898        if always_true(expression.this):
 899            return expression.args["true"]
 900        if always_false(expression.this):
 901            return expression.args.get("false") or exp.null()
 902
 903    return expression
 904
 905
 906def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 907    """
 908    Reduces a prefix check to either TRUE or FALSE if both the string and the
 909    prefix are statically known.
 910
 911    Example:
 912        >>> from sqlglot import parse_one
 913        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 914        'TRUE'
 915    """
 916    if (
 917        isinstance(expression, exp.StartsWith)
 918        and expression.this.is_string
 919        and expression.expression.is_string
 920    ):
 921        return exp.convert(expression.name.startswith(expression.expression.name))
 922
 923    return expression
 924
 925
 926DateRange = t.Tuple[datetime.date, datetime.date]
 927
 928
 929def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 930    """
 931    Get the date range for a DATE_TRUNC equality comparison:
 932
 933    Example:
 934        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 935    Returns:
 936        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 937    """
 938    floor = date_floor(date, unit, dialect)
 939
 940    if date != floor:
 941        # This will always be False, except for NULL values.
 942        return None
 943
 944    return floor, floor + interval(unit)
 945
 946
 947def _datetrunc_eq_expression(
 948    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 949) -> exp.Expression:
 950    """Get the logical expression for a date range"""
 951    return exp.and_(
 952        left >= date_literal(drange[0], target_type),
 953        left < date_literal(drange[1], target_type),
 954        copy=False,
 955    )
 956
 957
 958def _datetrunc_eq(
 959    left: exp.Expression,
 960    date: datetime.date,
 961    unit: str,
 962    dialect: Dialect,
 963    target_type: t.Optional[exp.DataType],
 964) -> t.Optional[exp.Expression]:
 965    drange = _datetrunc_range(date, unit, dialect)
 966    if not drange:
 967        return None
 968
 969    return _datetrunc_eq_expression(left, drange, target_type)
 970
 971
 972def _datetrunc_neq(
 973    left: exp.Expression,
 974    date: datetime.date,
 975    unit: str,
 976    dialect: Dialect,
 977    target_type: t.Optional[exp.DataType],
 978) -> t.Optional[exp.Expression]:
 979    drange = _datetrunc_range(date, unit, dialect)
 980    if not drange:
 981        return None
 982
 983    return exp.and_(
 984        left < date_literal(drange[0], target_type),
 985        left >= date_literal(drange[1], target_type),
 986        copy=False,
 987    )
 988
 989
 990DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 991    exp.LT: lambda l, dt, u, d, t: l
 992    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 993    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 994    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 995    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 996    exp.EQ: _datetrunc_eq,
 997    exp.NEQ: _datetrunc_neq,
 998}
 999DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
1000DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
1001
1002
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Check whether `left` is a DATE_TRUNC/TIMESTAMP_TRUNC call and `right` a date literal."""
    if not isinstance(left, DATETRUNCS):
        return False
    return _is_date_literal(right)
1005
1006
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`.

    Three shapes are handled:
      * DATE_TRUNC(unit, <date literal>) is folded into a date literal.
      * DATE_TRUNC(unit, x) <cmp> <date literal> is rewritten into a predicate on `x`
        using DATETRUNC_BINARY_COMPARISONS.
      * DATE_TRUNC(unit, x) IN (<date literals>) becomes an OR of range predicates.

    ModuleNotFoundError and UnsupportedUnit leave the expression unchanged (see `catch`).
    """
    if isinstance(expression, DATETRUNCS):
        truncated = expression.this
        target_type = extract_type(truncated)
        literal_date = extract_date(truncated)
        if not (literal_date and expression.unit):
            return expression
        unit = expression.unit.name.lower()
        return date_literal(date_floor(literal_date, unit, dialect), target_type)

    comparison = expression.__class__
    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        left, right = expression.left, expression.right
        if not _is_datetrunc_predicate(left, right):
            return expression

        trunc = t.cast(exp.DateTrunc, left)
        trunc_arg = trunc.this
        unit = trunc.unit.name.lower()
        date = extract_date(right)
        if not date:
            return expression

        transform = DATETRUNC_BINARY_COMPARISONS[comparison]
        rewritten = transform(trunc_arg, date, unit, dialect, extract_type(trunc_arg, right))
        return rewritten or expression

    if isinstance(expression, exp.In):
        trunc = expression.this
        candidates = expression.expressions
        if not candidates or not all(_is_datetrunc_predicate(trunc, c) for c in candidates):
            return expression

        trunc = t.cast(exp.DateTrunc, trunc)
        unit = trunc.unit.name.lower()

        ranges = []
        for candidate in candidates:
            date = extract_date(candidate)
            if not date:
                return expression
            drange = _datetrunc_range(date, unit, dialect)
            if drange:
                ranges.append(drange)

        # Every candidate was unaligned: leave the IN as is.
        if not ranges:
            return expression

        merged = merge_ranges(ranges)
        target_type = extract_type(trunc, *candidates)
        return exp.or_(
            *[_datetrunc_eq_expression(trunc, drange, target_type) for drange in merged],
            copy=False,
        )

    return expression
1070
1071
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison in a canonical order.

    Columns go to the left and constants to the right. When neither rule decides,
    operands are ordered by their `gen` key. Swapping the operands maps the operator
    through INVERSE_COMPARISONS; operators missing from that mapping are kept.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    should_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not should_swap:
        return expression

    swapped_class = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return swapped_class(this=right, expression=left)
1087
1088
# Replacing `... JOIN x ON TRUE` by a CROSS JOIN is only sound when the join cannot
# produce NULL-padded rows. A CROSS JOIN yields no rows as soon as either side is
# empty, whereas an outer join keeps the rows of its preserved side:
#   LEFT JOIN x ON TRUE  != CROSS JOIN x   (differs when x is empty)
#   RIGHT JOIN x ON TRUE != CROSS JOIN x   (differs when the left side is empty)
# So only plain / inner joins are listed, as (join.side, join.kind) pairs.
JOINS = {
    ("", ""),
    ("", "INNER"),
}
1098
1099
def remove_where_true(expression):
    """
    Drop always-true WHERE clauses and simplify always-true join conditions.

    A join whose ON condition is always true becomes a CROSS JOIN, but only when it
    has no USING or method and its (side, kind) is listed in JOINS. Mutates
    `expression` in place and returns None.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.pop()

    for join in expression.find_all(exp.Join):
        on = join.args.get("on")
        if not always_true(on):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        on.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1114
1115
def always_true(expression):
    """
    Return whether `expression` is statically known to be truthy.

    This holds for the boolean literal TRUE and for any literal other than a numeric
    zero. `0` (or `0.0`, ...) is falsy when a number is used as a condition, so
    treating it as TRUE would e.g. drop a `WHERE 0` filter or turn `x AND 0` into `x`.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_string:
        return True

    try:
        return Decimal(expression.this) != 0
    except (ArithmeticError, TypeError, ValueError):
        # Not a plain decimal number (e.g. a hex literal): keep treating it as truthy.
        return True
1120
1121
def always_false(expression):
    """Return whether `expression` is the FALSE literal or a NULL literal."""
    if is_false(expression):
        return True
    return is_null(expression)
1124
1125
def is_complement(a, b):
    """Return whether `b` is exactly `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1128
1129
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is the FALSE boolean literal (exact type match, no subclasses)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1132
1133
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is a NULL literal node (exact type match, no subclasses)."""
    node_type = type(a)
    return node_type is exp.Null
1136
1137
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal, or None when `expression` is not one of the
    supported comparisons (=, IS, <>, >, >=, <, <=).
    """
    evaluators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in evaluators:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
1152
1153
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`, like a SQL CAST(... AS DATE).

    Datetimes are reduced to their date, dates are returned unchanged and strings are
    parsed as ISO-8601 (a time component is accepted and discarded).

    Returns:
        The date, or None when `value` cannot be interpreted as one: an unparsable
        string, or a non-string value such as None.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input and ValueError for strings
        # it cannot parse; both mean "not a date literal".
        return None
1163
1164
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`, like a SQL CAST(... AS TIMESTAMP).

    Datetimes are returned unchanged, dates become midnight of that day and strings
    are parsed as ISO-8601.

    Returns:
        The datetime, or None when `value` cannot be interpreted as one: an
        unparsable string, or a non-string value such as None.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input and ValueError for strings
        # it cannot parse; both mean "not a datetime literal".
        return None
1174
1175
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate CAST(value AS to) for temporal target types.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` for the other
        temporal types, and None for a falsy `value`, a non-temporal `to`, or a value
        that cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1184
1185
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a date/time literal expression.

    Supports `CAST(<literal> AS <type>)`, `TS_OR_DS_TO_DATE(<literal>)` without a
    format argument, and nestings of those two forms.

    Returns:
        The evaluated date/datetime, or None for any other expression or when the
        literal cannot be cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    inner = cast.this
    if isinstance(inner, exp.Literal):
        value: t.Any = inner.name
    elif isinstance(inner, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: evaluate the inner one first, then cast its result.
        value = extract_date(inner)
    else:
        return None
    return cast_value(value, to)
1201
1202
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether `expression` statically evaluates to a date/datetime (see extract_date)."""
    value = extract_date(expression)
    return value is not None
1205
1206
def extract_interval(expression):
    """
    Convert an interval-like expression into a `relativedelta`.

    The amount is read from `expression.name` and the unit from its `unit` arg.
    Returns None when the amount is not an integer, the unit is unsupported, or
    python-dateutil is not installed.
    """
    unit = expression.text("unit").lower()
    try:
        return interval(unit, int(expression.name))
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1214
1215
def extract_type(*expressions):
    """
    Return the data type of the first expression that has one.

    A CAST contributes its target type, any other expression its `type`. If no
    expression has a truthy type, the (falsy) type of the last one is returned, or
    None when called without arguments.
    """
    found = None
    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found
    return found
1224
1225
def date_literal(date, target_type=None):
    """
    Build a CAST of a string literal made from `date`.

    `target_type` is used when it is a temporal type. Otherwise the cast targets
    DATETIME for `datetime.datetime` values and DATE for anything else.
    """
    keep_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not keep_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1235
1236
def interval(unit: str, n: int = 1):
    """
    Return a `dateutil.relativedelta.relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.

    Raises:
        ModuleNotFoundError: if python-dateutil is not installed.
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword one unit spans)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
1258
1259
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` to the start of its `unit` period, like DATE_TRUNC(unit, d).

    Supported units are year, quarter, month, week and day. Weeks start on Monday
    shifted by `dialect.WEEK_OFFSET` (e.g. -1 makes weeks start on Sunday). When `d`
    is a `datetime.datetime`, its time of day is reset to midnight, because every
    supported unit spans at least a whole day.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() counts from Monday (0) to Sunday (6). The modulo keeps the step
        # back within [0, 6], so a date that already is the first day of the week
        # floors to itself. Without it, an offset of -1 moves a Sunday back a whole
        # week and a positive offset moves dates forward.
        days_back = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_back)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1281
1282
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the start of a `unit` period.

    A value already on a period boundary is returned unchanged; anything else moves
    to the start of the following period.
    """
    start = date_floor(d, unit, dialect)
    return d if start == d else start + interval(unit)
1290
1291
def boolean_literal(condition):
    """Return the TRUE literal when `condition` is truthy, otherwise the FALSE literal."""
    if condition:
        return exp.true()
    return exp.false()
1294
1295
1296def _flat_simplify(expression, simplifier, root=True):
1297    if root or not expression.same_parent:
1298        operands = []
1299        queue = deque(expression.flatten(unnest=False))
1300        size = len(queue)
1301
1302        while queue:
1303            a = queue.popleft()
1304
1305            for b in queue:
1306                result = simplifier(expression, a, b)
1307
1308                if result and result is not expression:
1309                    queue.remove(b)
1310                    queue.appendleft(result)
1311                    break
1312            else:
1313                operands.append(a)
1314
1315        if len(operands) < size:
1316            return functools.reduce(
1317                lambda a, b: expression.__class__(this=a, expression=b), operands
1318            )
1319    return expression
1320
1321
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    A fresh `Gen` instance is used for every call.
    """
    generator = Gen()
    return generator.gen(expression)
1329
1330
class Gen:
    """
    Minimal stack-based serializer that turns expressions into sort/dedup keys.

    `gen` pops items off `self.stack`:
      * expressions are dispatched to a `<key>_sql` handler, to `_function` for
        functions, or to a generic "KEY :arg,value,..." fallback;
      * lists are expanded into their non-None items separated by ",";
      * any other non-None value is appended to `self.sqls` via `str`.
    The stack is LIFO, so handlers push the pieces of their output in reverse order.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Serialize `expression` into its key string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                handler_name = f"{node.key}_sql"

                if hasattr(self, handler_name):
                    getattr(self, handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # _args pushes the arguments first, so the key is emitted before them.
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the non-None items in reverse order, each followed by a ","
                # separator, then drop the separator on top of the stack (it would
                # precede the first item). Only drop it if something was pushed:
                # otherwise the pop would discard an unrelated entry, e.g. the closing
                # ")" of a function call whose arguments are all None.
                pushed = False
                for item in reversed(node):
                    if item is not None:
                        self.stack.extend((item, ","))
                        pushed = True
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case like every other operator keyword (e.g. " ILIKE ").
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the non-None args of `node` as one list of `[":name", value]` pairs,
        skipping the first `arg_index` arg types. Returns whether anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])

        if not kvs:
            return False

        self.stack.append(kvs)
        return True
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by the date helpers (e.g. `interval`, `date_floor`) for an unsupported time unit."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, max_depth: Optional[int] = None):
 41def simplify(
 42    expression: exp.Expression,
 43    constant_propagation: bool = False,
 44    dialect: DialectType = None,
 45    max_depth: t.Optional[int] = None,
 46):
 47    """
 48    Rewrite sqlglot AST to simplify expressions.
 49
 50    Example:
 51        >>> import sqlglot
 52        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 53        >>> simplify(expression).sql()
 54        'TRUE'
 55
 56    Args:
 57        expression: expression to simplify
 58        constant_propagation: whether the constant propagation rule should be used
 59        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
 60    Returns:
 61        sqlglot.Expression: simplified expression
 62    """
 63
 64    dialect = Dialect.get_or_raise(dialect)
 65
 66    def _simplify(expression, root=True):
 67        if (
 68            max_depth
 69            and isinstance(expression, exp.Connector)
 70            and not isinstance(expression.parent, exp.Connector)
 71        ):
 72            depth = connector_depth(expression)
 73            if depth > max_depth:
 74                logger.info(
 75                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
 76                )
 77                return expression
 78
 79        if expression.meta.get(FINAL):
 80            return expression
 81
 82        # group by expressions cannot be simplified, for example
 83        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 84        # the projection must exactly match the group by key
 85        group = expression.args.get("group")
 86
 87        if group and hasattr(expression, "selects"):
 88            groups = set(group.expressions)
 89            group.meta[FINAL] = True
 90
 91            for e in expression.selects:
 92                for node in e.walk():
 93                    if node in groups:
 94                        e.meta[FINAL] = True
 95                        break
 96
 97            having = expression.args.get("having")
 98            if having:
 99                for node in having.walk():
100                    if node in groups:
101                        having.meta[FINAL] = True
102                        break
103
104        # Pre-order transformations
105        node = expression
106        node = rewrite_between(node)
107        node = uniq_sort(node, root)
108        node = absorb_and_eliminate(node, root)
109        node = simplify_concat(node)
110        node = simplify_conditionals(node)
111
112        if constant_propagation:
113            node = propagate_constants(node, root)
114
115        exp.replace_children(node, lambda e: _simplify(e, False))
116
117        # Post-order transformations
118        node = simplify_not(node)
119        node = flatten(node)
120        node = simplify_connectors(node, root)
121        node = remove_complements(node, root)
122        node = simplify_coalesce(node)
123        node.parent = expression.parent
124        node = simplify_literals(node, root)
125        node = simplify_equality(node)
126        node = simplify_parens(node)
127        node = simplify_datetrunc(node, dialect)
128        node = sort_comparison(node)
129        node = simplify_startswith(node)
130
131        if root:
132            expression.replace(node)
133        return node
134
135    expression = while_changing(expression, _simplify)
136    remove_where_true(expression)
137    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • dialect: the SQL dialect, used by dialect-dependent rules such as DATE_TRUNC simplification
  • max_depth: Chains of Connectors (AND, OR, etc) exceeding max_depth will be skipped
Returns:

sqlglot.Expression: simplified expression

def connector_depth(expression: sqlglot.expressions.Expression) -> int:
def connector_depth(expression: exp.Expression) -> int:
    """
    Determine the maximum depth of a tree of Connectors.

    Only directly nested Connector nodes are counted; any other node ends a branch.
    The walk uses an explicit stack instead of recursion.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    pending = [(expression, 1)]

    while pending:
        node, level = pending.pop()
        if not isinstance(node, exp.Connector):
            continue

        deepest = max(deepest, level)
        pending.append((node.left, level + 1))
        pending.append((node.right, level + 1))

    return deepest

Determine the maximum depth of a tree of Connectors.

For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that makes a simplification function return its input expression
    unchanged if any of `exceptions` is raised.

    The decorated function keeps the name and docstring of the original
    (via functools.wraps); other exceptions propagate.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # Best effort: skip this rule and keep the expression as it was.
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of exceptions is raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    When the BETWEEN is the operand of a NOT, the conjunction is parenthesized so the
    negation still covers both comparisons.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)
    subject = expression.this

    lower = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower, upper, copy=False)

    return exp.paren(rewritten, copy=False) if negated else rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also: NOT NULL -> NULL, NOT <comparison> -> its complement comparison (via
    COMPLEMENT_COMPARISONS), NOT <always true> -> FALSE, NOT FALSE -> TRUE and
    NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and swap the connector.
            swapped_connector = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                swapped_connector(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
    """
    Inline directly nested connectors of the same type into their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Operands are replaced in place; the same `expression` is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            # Drop the wrapping parentheses so the chain becomes flat.
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors that have constant or identical operands.

    - x <op> x -> x, except x XOR x -> FALSE
    - AND: any FALSE -> FALSE, then any NULL -> NULL; always-true operands drop out
    - OR: any always-true -> TRUE; FALSE operands drop out;
      NULL OR NULL, NULL OR FALSE and FALSE OR NULL -> NULL
    - otherwise defer to `_simplify_comparison`

    Connector chains are handed to `_flat_simplify` together with the pairwise
    folder defined below; non-connectors are returned unchanged.
    """

    def _fold(connector, left, right):
        if left == right:
            return exp.false() if isinstance(connector, exp.Xor) else left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_false = is_false(left)
            right_false = is_false(right)
            if left_false and right_false:
                return exp.false()
            # Both sides are NULL/FALSE (but not both FALSE, handled above) -> NULL
            if (is_null(left) or left_false) and (is_null(right) or right_false):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        # XOR with distinct operands: nothing to fold.
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _fold, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
    """
    Collapse an AND / OR chain that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Non-root chains for which `expression.same_parent` is true are left alone.
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Remove complements:

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by their `gen` key. XOR chains are sorted but never
    deduplicated, since A XOR A != A when A is TRUE. Non-root chains for which
    `expression.same_parent` is true are left alone.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, _) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                # Sort on the generated key only. XOR operands may share a key, and
                # sorting whole (key, Expression) tuples would then fall back to
                # comparing Expression objects; keying keeps ties stable instead.
                ordered = sorted(arr, key=lambda pair: pair[0])
                expression = result_func(*(e for _, e in ordered), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands of `expression` are rewritten in place (via `replace`) and the
    same `expression` object is returned. Only AND / OR nodes are processed;
    a non-root node is skipped when `expression.same_parent` is true.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: operands of this type are the
        # parenthesized sub-groups we look into.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Key each two-operand group by {X, Y} when one side is NOT Y, recording
            # the other (shared) operand X that both groups would collapse to.
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # NOT A inside a group, with A itself a top-level operand: replace the
            # NOT A with the neutral element of `kind` (TRUE for AND, FALSE for OR).
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # A strictly smaller operand set exists elsewhere, so this group is
            # redundant: replace it with the neutral element of the outer connector.
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # (X op Y) paired with (X op NOT Y): both collapse to the shared X.
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Every `column = literal` equality found in scope (not descending into IF
    nodes) is recorded. Other occurrences of that column are then replaced with
    a copy of the literal, except the defining occurrence itself and operands
    of `... IS NULL`.

    Reference: https://www.sqlite.org/optoverview.html
    """
    is_candidate = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not is_candidate:
        return expression

    # column -> (id of the column node that defined it, literal it equals)
    constants = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constants[lhs] = (id(lhs), rhs)

    if not constants:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        entry = constants.get(column)
        if entry is None:
            continue
        defining_id, literal = entry
        if id(column) == defining_id:
            # Keep the `column = literal` equality that introduced the constant.
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue
        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
171        def wrapped(expression, *args, **kwargs):
172            try:
173                return func(expression, *args, **kwargs)
174            except exceptions:
175                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """
    Fold literal arithmetic and related constructs.

    - Binary nodes other than connectors go through `_flat_simplify` with
      `_simplify_binary`.
    - Negation of a numeric literal becomes a single numeric literal
      (-5 -> '-5', -(-5) -> '5').
    - Nodes whose type is in INVERSE_DATE_OPS are passed to `_simplify_binary`
      with their interval; the node is kept if that returns nothing.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """
    Drop a Paren wrapper when the parentheses cannot change meaning.

    Returns the wrapped expression if one of the rules below applies, otherwise
    `expression` unchanged. Non-Paren input is returned as is.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Parenthesized SELECTs are always kept.
        not isinstance(this, exp.Select)
        # Parentheses directly under a SubqueryPredicate are always kept.
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # Parent is not an operator/condition node.
            not isinstance(parent, (exp.Condition, exp.Binary))
            # Doubled parentheses: ((x)) -> (x)
            or isinstance(parent, exp.Paren)
            # Inner node is not a binary operator (NOT / IS under a predicate excepted).
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # A predicate wrapped inside a non-predicate parent.
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # Add directly inside Add, Mul directly inside Mul.
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            # Mul inside Add / Sub: multiplication already binds tighter.
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against COALESCE.

    COALESCE(x) -> x, and COALESCE(<non-null constant>, ...) -> that constant,
    unless the COALESCE is a Spark partitioning hint.

    For `<comparison>(COALESCE(a, ..., c, ...), k)` where `k` is constant and `c`
    is the first constant argument, the COALESCE is truncated before `c` and the
    comparison becomes:

        (COALESCE(a, ...) IS NOT NULL AND <original>) OR
        (COALESCE(a, ...) IS NULL AND c <op> k)

    which is larger, but later passes can fold it further.
    """
    if isinstance(expression, exp.Coalesce):
        collapsible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if collapsible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    left, right = expression.left, expression.right
    if isinstance(left, exp.Coalesce):
        coalesce, other = left, right
    elif isinstance(right, exp.Coalesce):
        coalesce, other = right, left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    first_constant = next(
        ((index, candidate) for index, candidate in enumerate(coalesce.expressions) if _is_constant(candidate)),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
    """
    Reduces all groups that contain string literals by concatenating them.

    Consecutive string-literal arguments are merged into one literal (joined
    with the CONCAT_WS separator, or "" otherwise). A result that is a single
    string literal is returned directly. DPipe chains are rebuilt as DPipe;
    Concat keeps its `safe` / `coalesce` args.
    """
    if not isinstance(expression, CONCATS):
        return expression
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string:
        return expression

    if isinstance(expression, exp.ConcatWs):
        separator_node, *parts = expression.expressions
        separator = separator_node.name
        result_type = exp.ConcatWs
        extra_args = {}
    else:
        parts = expression.expressions
        separator = ""
        result_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(parts or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if not all_strings:
            merged.extend(run)
            continue
        merged.append(exp.Literal.string(separator.join(piece.name for piece in run)))

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if result_type is exp.ConcatWs:
        merged = [separator_node] + merged
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return result_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    CASE x WHEN v ... is first rewritten to CASE WHEN x = v .... Then the WHEN
    branches are scanned in order:

    - an always-false (FALSE / NULL) branch is dropped; if no branch remains,
      the ELSE value (or NULL) is returned;
    - an always-true branch with no undetermined branch before it replaces the
      whole CASE with its THEN value;
    - an always-true branch preceded by undetermined branches becomes the ELSE,
      and it and every later branch are dropped (earlier branches may still match).

    IF(cond, a, b) outside of a CASE folds to `a`, or to `b` (or NULL), when
    `cond` is statically known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: branches are popped from the live list below,
        # and mutating a list while iterating it would skip elements.
        branches = list(expression.args["ifs"])
        for position, case in enumerate(branches):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if expression.args["ifs"][0] is case:
                    # No undetermined branch precedes this one, so it always wins.
                    return case.args["true"]
                # An earlier, undetermined branch may still match: this branch's
                # value becomes the ELSE and the rest of the branches are unreachable.
                result = case.args["true"]
                for unreachable in branches[position:]:
                    unreachable.pop()
                expression.set("default", result)
                return expression

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack, prefix = expression.this, expression.expression
    if not (haystack.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc(expression, *args, **kwargs):
171        def wrapped(expression, *args, **kwargs):
172            try:
173                return func(expression, *args, **kwargs)
174            except exceptions:
175                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Canonicalize the operand order of a comparison.

    Columns go on the left and constants on the right. When neither rule
    decides, operands are ordered by their `gen` key. A swap uses
    INVERSE_COMPARISONS (or the same class if absent) so the comparison stays
    equivalent. Non-comparisons are returned unchanged.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return swapped_type(this=right, expression=left)
JOINS = {('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER'), ('', '')}
def remove_where_true(expression):
    """
    Remove trivially true filters from `expression`, in place.

    WHERE clauses whose condition is always true are popped. Joins in JOINS
    (by side/kind) whose ON is always true, and which have neither USING nor a
    join method, lose their ON and become CROSS joins. Returns None.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        join_args = join.args
        convertible = (
            always_true(join_args.get("on"))
            and not join_args.get("using")
            and not join_args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if not convertible:
            continue
        join_args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """
    Return a truthy value if `expression` is statically known to be true.

    TRUE counts, as does any literal except a numeric zero: `WHERE 0` or
    `x AND 0` must not be treated as always true. Numeric literals whose text
    can't be parsed are not claimed to be always true either.
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and _is_truthy_literal(expression)
    )


def _is_truthy_literal(literal):
    """Return whether a Literal is known to be truthy (string, or non-zero number)."""
    if not literal.is_number:
        # String literals keep their historical "always true" treatment.
        return True
    try:
        return Decimal(literal.name) != 0
    except ArithmeticError:
        # decimal.InvalidOperation: the value is not statically known.
        return False
def always_false(expression):
    """Return True if `expression` is a plain FALSE or NULL node."""
    return is_null(expression) or is_false(expression)
def is_complement(a, b):
    """Return True when `b` is `NOT a` (structural equality on the operand)."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True only for an exact `exp.Boolean` FALSE (subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True only for an exact `exp.Null` node (subclasses excluded)."""
    return exp.Null is type(a)
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison operator of `expression` on Python values `a` and `b`.

    Returns a TRUE/FALSE expression, or None when the node type isn't one of
    EQ, IS (both using ==), NEQ, GT, GTE, LT, LTE.
    """
    evaluators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in evaluators:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce `value` to a `datetime.date`.

    datetimes are truncated to their date part and dates pass through
    unchanged. Anything else is parsed with `datetime.fromisoformat`; a
    ValueError from parsing yields None.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce `value` to a `datetime.datetime`.

    datetimes pass through unchanged and a plain date becomes midnight of that
    day (naive). Anything else is parsed with `datetime.fromisoformat`; a
    ValueError from parsing yields None.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert a raw value into a date or datetime according to the target type `to`.

    Returns None for a falsy `value` and for non-temporal target types. A DATE
    target goes through `cast_as_date`, and any other temporal type through
    `cast_as_datetime`. Either may return None if the value can't be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a CAST / TS_OR_DS_TO_DATE of a literal to a Python date(time).

    Handles CAST(<literal> AS <type>) and TS_OR_DS_TO_DATE(<literal>) without a
    `format` (treated as a DATE cast), possibly nested. Returns None for
    anything else, or when the value can't be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested conversion: evaluate the inner one first.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """
    Convert an interval-like expression into the value returned by `interval`.

    The amount is `int(expression.name)` and the unit is the lower-cased text
    of its `unit` arg. Returns None if the amount isn't an integer, the unit is
    unsupported, or `dateutil` is not installed.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    A Cast contributes its target type (`to`); any other expression contributes
    its `.type`. If none is truthy, the last candidate examined is returned
    (None when no expressions are given).
    """
    found = None
    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found
    return found
def date_literal(date, target_type=None):
    """
    Build CAST('<date>' AS <type>) for a Python date or datetime.

    `target_type` is used when it is a temporal type. Otherwise DATETIME is used
    for datetimes and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
    """
    Return a `dateutil.relativedelta` spanning `n` of the given unit.

    Supported units: year, quarter (3 months), month, week, day, hour, minute,
    second. Raises UnsupportedUnit for anything else; importing dateutil
    happens first, so a missing package raises ModuleNotFoundError.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    field, factor = span
    return relativedelta(**{field: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` down to the start of its enclosing `unit`.

    Supported units: year, quarter, month, week, day. For "week", the week
    start is Monday shifted by `dialect.WEEK_OFFSET` days, so e.g. -1 makes
    weeks start on Sunday. The result is never later than `d` and is less
    than 7 days before it.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday(): Monday == 0 ... Sunday == 6. Days elapsed since the most recent
        # week start; the modulo keeps it in [0, 6] so a date that is itself the week
        # start stays put (negative offsets) and we never move forward (positive ones).
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary.

    A date already on a boundary (per `date_floor`) is returned unchanged;
    otherwise one `unit` is added to its floor.
    """
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """Convert a Python truth value into a TRUE / FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
    """
    Minimal, stack-based pseudo-SQL generator used to build sort/dedup keys.

    `gen` pops items off `stack` and appends string fragments to `sqls`.
    Handlers push their pieces in reverse order (the last item pushed is
    emitted first), e.g. a binary node pushes (right, op, left).
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: "KEY" followed by its non-None args, if any.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Emit the non-None items separated by ",".
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Drop the separator in front of the first item. Only do this if
                    # something was pushed, otherwise we'd discard an unrelated entry
                    # (e.g. a function's closing ")") when every item is None.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." in front of the first part.
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case like every other operator emitted here.
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." in front of the first part.
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Emits "<this><op><expression>".
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        # Emits "<op><this>".
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Emits "NAME(<all arg values, comma separated>)".
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push a list of [":key", value] pairs for the node's non-None args.

        Args are taken in `arg_types` order, skipping the first `arg_index` of
        them. Returns True if anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression) -> str:
    """Render ``expression`` into a compact string.

    The traversal is iterative. ``self.stack`` holds pending items and is
    processed LIFO, so handlers push their pieces in reverse order. Each
    popped item is handled by type:

    * ``exp.Expression``: dispatched to ``<key>_sql`` if this object defines
      it, otherwise to ``_function`` for ``exp.Func`` nodes, otherwise
      rendered as the upper-cased key followed by its ``:arg`` pairs.
    * exactly ``list``: non-None items are rendered comma-separated. Tuples
      and list subclasses fall through to the next case.
    * anything else: appended via ``str()`` unless it is ``None``.

    ``self.stack`` is reassigned and ``self.sqls`` cleared on entry, so a
    nested call from a handler would clobber the generation in progress.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            # Push "item," pairs right-to-left, then drop the separator that
            # would precede the first item. Only do so if something was
            # pushed. A non-empty list holding only None would otherwise pop
            # an unrelated entry, such as a function's closing ")".
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Queue ``<this> + <expression>`` for an addition node."""
    self._binary(e, op=" + ")
def alias_sql(self, e: exp.Alias) -> None:
    """Queue ``<this> AS <alias>``. A missing (None) part is skipped by ``gen``."""
    args = e.args
    alias, target = args.get("alias"), args.get("this")
    self.stack.extend([alias, " AS ", target])
def and_sql(self, e: exp.And) -> None:
    """Queue ``<this> AND <expression>`` for a conjunction node."""
    self._binary(e, op=" AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Queue ``NAME(<arg>,...)`` for an ``Anonymous`` function node.

    The name is upper-cased, unless it is a quoted ``Identifier``. A quoted
    name is kept verbatim inside double quotes.

    Raises:
        ValueError: if ``e.this`` is neither a ``str`` nor an ``Identifier``.
    """
    head = e.this
    if isinstance(head, exp.Identifier):
        raw_name = head.this
        name = f'"{raw_name}"' if head.quoted else raw_name.upper()
    elif isinstance(head, str):
        name = head.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{head.__class__.__name__}'."
        )

    self.stack.extend([")", e.expressions, "(", name])
def between_sql(self, e: exp.Between) -> None:
    """Queue ``<this> BETWEEN <low> AND <high>``."""
    high = e.args.get("high")
    low = e.args.get("low")
    self.stack.extend([high, " AND ", low, " BETWEEN ", e.this])
def boolean_sql(self, e: exp.Boolean) -> None:
    """Queue ``TRUE`` or ``FALSE`` according to the truthiness of ``e.this``."""
    if e.this:
        literal = "TRUE"
    else:
        literal = "FALSE"
    self.stack.append(literal)
def bracket_sql(self, e: exp.Bracket) -> None:
    """Queue ``<this>[<expr>,...]`` for a subscript node."""
    indices = e.expressions
    base = e.this
    self.stack.extend(["]", indices, "[", base])
def column_sql(self, e: exp.Column) -> None:
    """Queue the column's ``parts`` joined by ``.``."""
    parts = e.parts
    if parts:
        # Push "part." pairs right-to-left, then drop the "." that would
        # otherwise precede the first part. The guard keeps an empty
        # ``parts`` from popping an unrelated pending item off the stack.
        for p in reversed(parts):
            self.stack.extend((p, "."))
        self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Queue ``<TYPE NAME> `` followed by the remaining ``:arg`` pairs.

    ``_args`` skips the first declared arg type. The type name is pushed last,
    so it is emitted first.
    """
    self._args(e, 1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Queue ``<this> / <expression>`` for a division node."""
    self._binary(e, op=" / ")
def dot_sql(self, e: exp.Dot) -> None:
    """Queue ``<this>.<expression>`` for a dotted-access node."""
    self._binary(e, op=".")
def eq_sql(self, e: exp.EQ) -> None:
    """Queue ``<this> = <expression>`` for an equality node."""
    self._binary(e, op=" = ")
def from_sql(self, e: exp.From) -> None:
    """Queue ``FROM <this>``."""
    source = e.this
    self.stack.append(source)
    self.stack.append("FROM ")
def gt_sql(self, e: exp.GT) -> None:
    """Queue ``<this> > <expression>``."""
    self._binary(e, op=" > ")
def gte_sql(self, e: exp.GTE) -> None:
    """Queue ``<this> >= <expression>``."""
    self._binary(e, op=" >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
    """Queue an identifier: double-quoted when ``e.quoted``, bare otherwise.

    Embedded double quotes are doubled, so a quoted name can never render
    identically to a different expression. For example, the name ``a" = "b``
    would otherwise collide with the comparison ``"a" = "b"``.
    """
    if e.quoted:
        escaped = str(e.this).replace('"', '""')
        self.stack.append(f'"{escaped}"')
    else:
        self.stack.append(e.this)
def ilike_sql(self, e: exp.ILike) -> None:
    """Queue ``<this> ILIKE <expression>``."""
    self._binary(e, op=" ILIKE ")
def in_sql(self, e: exp.In) -> None:
    """Queue ``<this> IN (<:arg pairs>)``.

    ``_args`` skips the first declared arg type. The closing paren is pushed
    before the pairs so that it is emitted after them.
    """
    stack = self.stack
    stack.append(")")
    self._args(e, 1)
    stack.append("(")
    stack.append(" IN ")
    stack.append(e.this)
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Queue ``<this> DIV <expression>`` for an integer-division node."""
    self._binary(e, op=" DIV ")
def is_sql(self, e: exp.Is) -> None:
    """Queue ``<this> IS <expression>``."""
    self._binary(e, op=" IS ")
def like_sql(self, e: exp.Like) -> None:
    """Queue ``<this> LIKE <expression>``.

    The keyword is upper-cased to match every other operator keyword this
    generator emits (``ILIKE``, ``IS``, ``DIV``, ``AND``, ...).
    """
    self._binary(e, " LIKE ")
def literal_sql(self, e: exp.Literal) -> None:
    """Queue a literal: single-quoted when ``e.is_string``, as-is otherwise.

    Embedded single quotes are doubled, so a string literal can never render
    identically to a different expression. For example, the string
    ``x' = 'y`` would otherwise collide with the comparison ``'x' = 'y'``.
    """
    if e.is_string:
        escaped = str(e.this).replace("'", "''")
        self.stack.append(f"'{escaped}'")
    else:
        self.stack.append(e.this)
def lt_sql(self, e: exp.LT) -> None:
    """Queue ``<this> < <expression>``."""
    self._binary(e, op=" < ")
def lte_sql(self, e: exp.LTE) -> None:
    """Queue ``<this> <= <expression>``."""
    self._binary(e, op=" <= ")
def mod_sql(self, e: exp.Mod) -> None:
    """Queue ``<this> % <expression>`` for a modulo node."""
    self._binary(e, op=" % ")
def mul_sql(self, e: exp.Mul) -> None:
    """Queue ``<this> * <expression>`` for a multiplication node."""
    self._binary(e, op=" * ")
def neg_sql(self, e: exp.Neg) -> None:
    """Queue ``-<this>`` for a negation node."""
    self._unary(e, op="-")
def neq_sql(self, e: exp.NEQ) -> None:
    """Queue ``<this> <> <expression>`` for an inequality node."""
    self._binary(e, op=" <> ")
def not_sql(self, e: exp.Not) -> None:
    """Queue ``NOT <this>``."""
    self._unary(e, op="NOT ")
def null_sql(self, e: exp.Null) -> None:
    """Queue the ``NULL`` keyword. No field of ``e`` is read."""
    keyword = "NULL"
    self.stack.append(keyword)
def or_sql(self, e: exp.Or) -> None:
    """Queue ``<this> OR <expression>`` for a disjunction node."""
    self._binary(e, op=" OR ")
def paren_sql(self, e: exp.Paren) -> None:
    """Queue ``(<this>)``. Pieces are pushed right-to-left for the LIFO stack."""
    inner = e.this
    self.stack.extend([")", inner, "("])
def sub_sql(self, e: exp.Sub) -> None:
    """Queue ``<this> - <expression>`` for a subtraction node."""
    self._binary(e, op=" - ")
def subquery_sql(self, e: exp.Subquery) -> None:
    """Queue ``(<this>)``, then the alias, then the remaining ``:arg`` pairs.

    ``_args`` skips the first two declared arg types. Everything is pushed
    right-to-left because ``gen`` pops from the end of ``self.stack``.
    """
    self._args(e, 2)
    alias_node = e.args.get("alias")
    if alias_node:
        self.stack.append(alias_node)
    for token in (")", e.this, "("):
        self.stack.append(token)
def table_sql(self, e: exp.Table) -> None:
    """Queue ``<part>.<part>...``, then the alias, then the remaining ``:arg`` pairs.

    ``_args`` skips the first four declared arg types. Pieces are pushed
    right-to-left because ``gen`` pops from the end of ``self.stack``.
    """
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)

    parts = e.parts
    if parts:
        # Push "part." pairs right-to-left, then drop the "." that would
        # otherwise precede the first part. The guard keeps an empty
        # ``parts`` from popping the alias (or another pending item).
        for p in reversed(parts):
            self.stack.extend((p, "."))
        self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Queue `` AS <name>``, followed by ``(<columns>)`` when columns exist."""
    column_list = e.columns
    pieces = [e.this, " AS "]
    if column_list:
        # The column list is emitted after the name, so it goes deeper in the stack.
        pieces = [")", column_list, "(", *pieces]
    self.stack.extend(pieces)
def var_sql(self, e: exp.Var) -> None:
    """Queue the raw ``this`` value of a ``Var`` node."""
    text = e.this
    self.stack.append(text)