sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23# Final means that an expression should not be simplified 24FINAL = "final" 25 26# Value ranges for byte-sized signed/unsigned integers 27TINYINT_MIN = -128 28TINYINT_MAX = 127 29UTINYINT_MIN = 0 30UTINYINT_MAX = 255 31 32 33class UnsupportedUnit(Exception): 34 pass 35 36 37def simplify( 38 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 39): 40 """ 41 Rewrite sqlglot AST to simplify expressions. 
42 43 Example: 44 >>> import sqlglot 45 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 46 >>> simplify(expression).sql() 47 'TRUE' 48 49 Args: 50 expression (sqlglot.Expression): expression to simplify 51 constant_propagation: whether the constant propagation rule should be used 52 53 Returns: 54 sqlglot.Expression: simplified expression 55 """ 56 57 dialect = Dialect.get_or_raise(dialect) 58 59 def _simplify(expression, root=True): 60 if expression.meta.get(FINAL): 61 return expression 62 63 # group by expressions cannot be simplified, for example 64 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 65 # the projection must exactly match the group by key 66 group = expression.args.get("group") 67 68 if group and hasattr(expression, "selects"): 69 groups = set(group.expressions) 70 group.meta[FINAL] = True 71 72 for e in expression.selects: 73 for node in e.walk(): 74 if node in groups: 75 e.meta[FINAL] = True 76 break 77 78 having = expression.args.get("having") 79 if having: 80 for node in having.walk(): 81 if node in groups: 82 having.meta[FINAL] = True 83 break 84 85 # Pre-order transformations 86 node = expression 87 node = rewrite_between(node) 88 node = uniq_sort(node, root) 89 node = absorb_and_eliminate(node, root) 90 node = simplify_concat(node) 91 node = simplify_conditionals(node) 92 93 if constant_propagation: 94 node = propagate_constants(node, root) 95 96 exp.replace_children(node, lambda e: _simplify(e, False)) 97 98 # Post-order transformations 99 node = simplify_not(node) 100 node = flatten(node) 101 node = simplify_connectors(node, root) 102 node = remove_complements(node, root) 103 node = simplify_coalesce(node) 104 node.parent = expression.parent 105 node = simplify_literals(node, root) 106 node = simplify_equality(node) 107 node = simplify_parens(node) 108 node = simplify_datetrunc(node, dialect) 109 node = sort_comparison(node) 110 node = simplify_startswith(node) 111 112 if root: 113 expression.replace(node) 114 return node 115 116 
expression = while_changing(expression, _simplify) 117 remove_where_true(expression) 118 return expression 119 120 121def catch(*exceptions): 122 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 123 124 def decorator(func): 125 def wrapped(expression, *args, **kwargs): 126 try: 127 return func(expression, *args, **kwargs) 128 except exceptions: 129 return expression 130 131 return wrapped 132 133 return decorator 134 135 136def rewrite_between(expression: exp.Expression) -> exp.Expression: 137 """Rewrite x between y and z to x >= y AND x <= z. 138 139 This is done because comparison simplification is only done on lt/lte/gt/gte. 140 """ 141 if isinstance(expression, exp.Between): 142 negate = isinstance(expression.parent, exp.Not) 143 144 expression = exp.and_( 145 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 146 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 147 copy=False, 148 ) 149 150 if negate: 151 expression = exp.paren(expression, copy=False) 152 153 return expression 154 155 156COMPLEMENT_COMPARISONS = { 157 exp.LT: exp.GTE, 158 exp.GT: exp.LTE, 159 exp.LTE: exp.GT, 160 exp.GTE: exp.LT, 161 exp.EQ: exp.NEQ, 162 exp.NEQ: exp.EQ, 163} 164 165 166def simplify_not(expression): 167 """ 168 Demorgan's Law 169 NOT (x OR y) -> NOT x AND NOT y 170 NOT (x AND y) -> NOT x OR NOT y 171 """ 172 if isinstance(expression, exp.Not): 173 this = expression.this 174 if is_null(this): 175 return exp.null() 176 if this.__class__ in COMPLEMENT_COMPARISONS: 177 return COMPLEMENT_COMPARISONS[this.__class__]( 178 this=this.this, expression=this.expression 179 ) 180 if isinstance(this, exp.Paren): 181 condition = this.unnest() 182 if isinstance(condition, exp.And): 183 return exp.paren( 184 exp.or_( 185 exp.not_(condition.left, copy=False), 186 exp.not_(condition.right, copy=False), 187 copy=False, 188 ) 189 ) 190 if isinstance(condition, exp.Or): 191 return exp.paren( 192 exp.and_( 193 
exp.not_(condition.left, copy=False), 194 exp.not_(condition.right, copy=False), 195 copy=False, 196 ) 197 ) 198 if is_null(condition): 199 return exp.null() 200 if always_true(this): 201 return exp.false() 202 if is_false(this): 203 return exp.true() 204 if isinstance(this, exp.Not): 205 # double negation 206 # NOT NOT x -> x 207 return this.this 208 return expression 209 210 211def flatten(expression): 212 """ 213 A AND (B AND C) -> A AND B AND C 214 A OR (B OR C) -> A OR B OR C 215 """ 216 if isinstance(expression, exp.Connector): 217 for node in expression.args.values(): 218 child = node.unnest() 219 if isinstance(child, expression.__class__): 220 node.replace(child) 221 return expression 222 223 224def simplify_connectors(expression, root=True): 225 def _simplify_connectors(expression, left, right): 226 if left == right: 227 return left 228 if isinstance(expression, exp.And): 229 if is_false(left) or is_false(right): 230 return exp.false() 231 if is_null(left) or is_null(right): 232 return exp.null() 233 if always_true(left) and always_true(right): 234 return exp.true() 235 if always_true(left): 236 return right 237 if always_true(right): 238 return left 239 return _simplify_comparison(expression, left, right) 240 elif isinstance(expression, exp.Or): 241 if always_true(left) or always_true(right): 242 return exp.true() 243 if is_false(left) and is_false(right): 244 return exp.false() 245 if ( 246 (is_null(left) and is_null(right)) 247 or (is_null(left) and is_false(right)) 248 or (is_false(left) and is_null(right)) 249 ): 250 return exp.null() 251 if is_false(left): 252 return right 253 if is_false(right): 254 return left 255 return _simplify_comparison(expression, left, right, or_=True) 256 257 if isinstance(expression, exp.Connector): 258 return _flat_simplify(expression, _simplify_connectors, root) 259 return expression 260 261 262LT_LTE = (exp.LT, exp.LTE) 263GT_GTE = (exp.GT, exp.GTE) 264 265COMPARISONS = ( 266 *LT_LTE, 267 *GT_GTE, 268 exp.EQ, 269 
exp.NEQ, 270 exp.Is, 271) 272 273INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 274 exp.LT: exp.GT, 275 exp.GT: exp.LT, 276 exp.LTE: exp.GTE, 277 exp.GTE: exp.LTE, 278} 279 280NONDETERMINISTIC = (exp.Rand, exp.Randn) 281 282 283def _simplify_comparison(expression, left, right, or_=False): 284 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 285 ll, lr = left.args.values() 286 rl, rr = right.args.values() 287 288 largs = {ll, lr} 289 rargs = {rl, rr} 290 291 matching = largs & rargs 292 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 293 294 if matching and columns: 295 try: 296 l = first(largs - columns) 297 r = first(rargs - columns) 298 except StopIteration: 299 return expression 300 301 if l.is_number and r.is_number: 302 l = float(l.name) 303 r = float(r.name) 304 elif l.is_string and r.is_string: 305 l = l.name 306 r = r.name 307 else: 308 l = extract_date(l) 309 if not l: 310 return None 311 r = extract_date(r) 312 if not r: 313 return None 314 # python won't compare date and datetime, but many engines will upcast 315 l, r = cast_as_datetime(l), cast_as_datetime(r) 316 317 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 318 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 319 return left if (av > bv if or_ else av <= bv) else right 320 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 321 return left if (av < bv if or_ else av >= bv) else right 322 323 # we can't ever shortcut to true because the column could be null 324 if not or_: 325 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 326 if av <= bv: 327 return exp.false() 328 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 329 if av >= bv: 330 return exp.false() 331 elif isinstance(a, exp.EQ): 332 if isinstance(b, exp.LT): 333 return exp.false() if av >= bv else a 334 if isinstance(b, exp.LTE): 335 return exp.false() if av > bv else a 336 if isinstance(b, exp.GT): 337 return 
exp.false() if av <= bv else a 338 if isinstance(b, exp.GTE): 339 return exp.false() if av < bv else a 340 if isinstance(b, exp.NEQ): 341 return exp.false() if av == bv else a 342 return None 343 344 345def remove_complements(expression, root=True): 346 """ 347 Removing complements. 348 349 A AND NOT A -> FALSE 350 A OR NOT A -> TRUE 351 """ 352 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 353 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 354 355 for a, b in itertools.permutations(expression.flatten(), 2): 356 if is_complement(a, b): 357 return complement 358 return expression 359 360 361def uniq_sort(expression, root=True): 362 """ 363 Uniq and sort a connector. 364 365 C AND A AND B AND B -> A AND B AND C 366 """ 367 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 368 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 369 flattened = tuple(expression.flatten()) 370 deduped = {gen(e): e for e in flattened} 371 arr = tuple(deduped.items()) 372 373 # check if the operands are already sorted, if not sort them 374 # A AND C AND B -> A AND B AND C 375 for i, (sql, e) in enumerate(arr[1:]): 376 if sql < arr[i][0]: 377 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 378 break 379 else: 380 # we didn't have to sort but maybe we need to dedup 381 if len(deduped) < len(flattened): 382 expression = result_func(*deduped.values(), copy=False) 383 384 return expression 385 386 387def absorb_and_eliminate(expression, root=True): 388 """ 389 absorption: 390 A AND (A OR B) -> A 391 A OR (A AND B) -> A 392 A AND (NOT A OR B) -> A AND B 393 A OR (NOT A AND B) -> A OR B 394 elimination: 395 (A AND B) OR (A AND NOT B) -> A 396 (A OR B) AND (A OR NOT B) -> A 397 """ 398 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 399 kind = exp.Or if isinstance(expression, exp.And) else exp.And 400 401 for a, b in 
itertools.permutations(expression.flatten(), 2): 402 if isinstance(a, kind): 403 aa, ab = a.unnest_operands() 404 405 # absorb 406 if is_complement(b, aa): 407 aa.replace(exp.true() if kind == exp.And else exp.false()) 408 elif is_complement(b, ab): 409 ab.replace(exp.true() if kind == exp.And else exp.false()) 410 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 411 a.replace(exp.false() if kind == exp.And else exp.true()) 412 elif isinstance(b, kind): 413 # eliminate 414 rhs = b.unnest_operands() 415 ba, bb = rhs 416 417 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 418 a.replace(aa) 419 b.replace(aa) 420 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 421 a.replace(ab) 422 b.replace(ab) 423 424 return expression 425 426 427def propagate_constants(expression, root=True): 428 """ 429 Propagate constants for conjunctions in DNF: 430 431 SELECT * FROM t WHERE a = b AND b = 5 becomes 432 SELECT * FROM t WHERE a = 5 AND b = 5 433 434 Reference: https://www.sqlite.org/optoverview.html 435 """ 436 437 if ( 438 isinstance(expression, exp.And) 439 and (root or not expression.same_parent) 440 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 441 ): 442 constant_mapping = {} 443 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 444 if isinstance(expr, exp.EQ): 445 l, r = expr.left, expr.right 446 447 # TODO: create a helper that can be used to detect nested literal expressions such 448 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 449 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 450 constant_mapping[l] = (id(l), r) 451 452 if constant_mapping: 453 for column in find_all_in_scope(expression, exp.Column): 454 parent = column.parent 455 column_id, constant = constant_mapping.get(column) or (None, None) 456 if ( 457 column_id is not None 458 and id(column) != column_id 459 and not (isinstance(parent, exp.Is) 
and isinstance(parent.expression, exp.Null)) 460 ): 461 column.replace(constant.copy()) 462 463 return expression 464 465 466INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 467 exp.DateAdd: exp.Sub, 468 exp.DateSub: exp.Add, 469 exp.DatetimeAdd: exp.Sub, 470 exp.DatetimeSub: exp.Add, 471} 472 473INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 474 **INVERSE_DATE_OPS, 475 exp.Add: exp.Sub, 476 exp.Sub: exp.Add, 477} 478 479 480def _is_number(expression: exp.Expression) -> bool: 481 return expression.is_number 482 483 484def _is_interval(expression: exp.Expression) -> bool: 485 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 486 487 488@catch(ModuleNotFoundError, UnsupportedUnit) 489def simplify_equality(expression: exp.Expression) -> exp.Expression: 490 """ 491 Use the subtraction and addition properties of equality to simplify expressions: 492 493 x + 1 = 3 becomes x = 2 494 495 There are two binary operations in the above expression: + and = 496 Here's how we reference all the operands in the code below: 497 498 l r 499 x + 1 = 3 500 a b 501 """ 502 if isinstance(expression, COMPARISONS): 503 l, r = expression.left, expression.right 504 505 if l.__class__ not in INVERSE_OPS: 506 return expression 507 508 if r.is_number: 509 a_predicate = _is_number 510 b_predicate = _is_number 511 elif _is_date_literal(r): 512 a_predicate = _is_date_literal 513 b_predicate = _is_interval 514 else: 515 return expression 516 517 if l.__class__ in INVERSE_DATE_OPS: 518 l = t.cast(exp.IntervalOp, l) 519 a = l.this 520 b = l.interval() 521 else: 522 l = t.cast(exp.Binary, l) 523 a, b = l.left, l.right 524 525 if not a_predicate(a) and b_predicate(b): 526 pass 527 elif not a_predicate(b) and b_predicate(a): 528 a, b = b, a 529 else: 530 return expression 531 532 return expression.__class__( 533 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 534 ) 535 return expression 536 537 
def simplify_literals(expression, root=True):
    """Fold operations whose operands are literals.

    Non-connector binaries are reduced pairwise with `_simplify_binary`, a negated
    numeric literal is folded into a signed literal, and date add/sub nodes are
    evaluated when `_simplify_binary` can compute them.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            # -(-5) -> 5, otherwise prepend the sign
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


# Operators for which a NULL operand does not make the result NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)


def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """Strip a cast around a small integer literal when the target integer type can hold it."""
    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
        this = _simplify_integer_cast(expr.this)
    else:
        this = expr.this

    if isinstance(expr, exp.Cast) and this.is_int:
        num = int(this.name)

        # Remove the (up)cast from small (byte-sized) integers in predicates which is
        # side-effect free. Downcasts on any integer type might cause overflow, thus the
        # cast cannot be eliminated and the behavior is engine-dependent
        if (
            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
        ) or (
            UTINYINT_MIN <= num <= UTINYINT_MAX
            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
        ):
            return this

    return expr


def _simplify_binary(expression, a, b):
    """Evaluate the binary `expression` over operands `a` and `b` when both are static.

    Returns:
        The folded literal / boolean / NULL expression, or None when nothing can be
        computed.
    """
    if isinstance(expression, COMPARISONS):
        a = _simplify_integer_cast(a)
        b = _simplify_integer_cast(b)

    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        date, b = extract_date(a), extract_interval(b)
        if date and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(date + b, extract_type(a))
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(date - b, extract_type(a))
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        # Guard on the extracted values (`a`, `date`); `b` is still the raw expression here
        if a and date and isinstance(expression, exp.Add):
            return date_literal(a + date, extract_type(b))
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Return the content of a Paren node when the conditions below allow dropping it."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            not isinstance(parent, (exp.Condition, exp.Binary))
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for `exp.NONNULL_CONSTANTS` instances and date literals."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Return True for `exp.CONSTANTS` instances and date literals."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """Remove redundant COALESCE calls and expand comparisons of COALESCE with a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant fallback and everything after it
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Consecutive string literals are merged; everything else is kept as is
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(
    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0], target_type),
        left < date_literal(drange[1], target_type),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as `left >= min AND left < max`.

    Returns None when `date` is not aligned to `unit` (no range exists).
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange, target_type)


def _datetrunc_neq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as `(left < min OR left >= max)`.

    Returns None when `date` is not aligned to `unit` (no range exists).
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # The value must fall outside [min, max): that is a disjunction, since no value can
    # be both below `min` and at or above `max`. The parentheses keep the OR grouped
    # when the rewritten node ends up under another connector.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0], target_type),
            left >= date_literal(drange[1], target_type),
            copy=False,
        ),
        copy=False,
    )


# Comparison class -> function(truncated operand, date, unit, dialect, target type)
# producing the equivalent predicate on the un-truncated operand
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True when `left` is a truncation node and `right` is a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncation of a date literal: compute the result directly
        this = expression.this
        trunc_type = extract_type(this)
        date = extract_date(this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        trunc_arg = l.this
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return (
            DATETRUNC_BINARY_COMPARISONS[comparison](
                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
            )
            or expression
        )

    if isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(l, *rs)

            return exp.or_(
                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
            )

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put comparison operands in a canonical order, flipping the operator when swapping.

    A column is kept on the left and a constant on the right; otherwise the operands
    are ordered by their `gen` string.
    """
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop `WHERE <always true>` clauses and turn eligible `JOIN ... ON <always true>` into CROSS joins.

    The expression is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a TRUE boolean or for any `exp.Literal` node."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def always_false(expression):
    """Return True for a FALSE boolean or a NULL node."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Boolean` whose value is falsy."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Null` node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` over the Python values `a` and `b`.

    Returns:
        A TRUE / FALSE expression, or None when `expression` is not one of the
        handled comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, datetime or ISO-format string to a date; None if the string is invalid."""
    # datetime must be checked first: it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a date, datetime or ISO-format string to a datetime; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        # A plain date is promoted to midnight of that day
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to a date (DATE target) or datetime (other temporal targets).

    Returns None for a falsy `value`, a non-temporal target type, or a value that
    cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the Python date/datetime from a (possibly nested) cast of a literal.

    Handles `exp.Cast` nodes and format-less `exp.TsOrDsToDate` nodes; returns None
    for anything else.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True when `extract_date` can derive a value from `expression`."""
    return extract_date(expression) is not None
1119 1120def extract_interval(expression): 1121 try: 1122 n = int(expression.name) 1123 unit = expression.text("unit").lower() 1124 return interval(unit, n) 1125 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1126 return None 1127 1128 1129def extract_type(*expressions): 1130 target_type = None 1131 for expression in expressions: 1132 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1133 if target_type: 1134 break 1135 1136 return target_type 1137 1138 1139def date_literal(date, target_type=None): 1140 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1141 target_type = ( 1142 exp.DataType.Type.DATETIME 1143 if isinstance(date, datetime.datetime) 1144 else exp.DataType.Type.DATE 1145 ) 1146 1147 return exp.cast(exp.Literal.string(date), target_type) 1148 1149 1150def interval(unit: str, n: int = 1): 1151 from dateutil.relativedelta import relativedelta 1152 1153 if unit == "year": 1154 return relativedelta(years=1 * n) 1155 if unit == "quarter": 1156 return relativedelta(months=3 * n) 1157 if unit == "month": 1158 return relativedelta(months=1 * n) 1159 if unit == "week": 1160 return relativedelta(weeks=1 * n) 1161 if unit == "day": 1162 return relativedelta(days=1 * n) 1163 if unit == "hour": 1164 return relativedelta(hours=1 * n) 1165 if unit == "minute": 1166 return relativedelta(minutes=1 * n) 1167 if unit == "second": 1168 return relativedelta(seconds=1 * n) 1169 1170 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1171 1172 1173def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1174 if unit == "year": 1175 return d.replace(month=1, day=1) 1176 if unit == "quarter": 1177 if d.month <= 3: 1178 return d.replace(month=1, day=1) 1179 elif d.month <= 6: 1180 return d.replace(month=4, day=1) 1181 elif d.month <= 9: 1182 return d.replace(month=7, day=1) 1183 else: 1184 return d.replace(month=10, day=1) 1185 if unit == "month": 1186 return 
d.replace(month=d.month, day=1) 1187 if unit == "week": 1188 # Assuming week starts on Monday (0) and ends on Sunday (6) 1189 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1190 if unit == "day": 1191 return d 1192 1193 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1194 1195 1196def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1197 floor = date_floor(d, unit, dialect) 1198 1199 if floor == d: 1200 return d 1201 1202 return floor + interval(unit) 1203 1204 1205def boolean_literal(condition): 1206 return exp.true() if condition else exp.false() 1207 1208 1209def _flat_simplify(expression, simplifier, root=True): 1210 if root or not expression.same_parent: 1211 operands = [] 1212 queue = deque(expression.flatten(unnest=False)) 1213 size = len(queue) 1214 1215 while queue: 1216 a = queue.popleft() 1217 1218 for b in queue: 1219 result = simplifier(expression, a, b) 1220 1221 if result and result is not expression: 1222 queue.remove(b) 1223 queue.appendleft(result) 1224 break 1225 else: 1226 operands.append(a) 1227 1228 if len(operands) < size: 1229 return functools.reduce( 1230 lambda a, b: expression.__class__(this=a, expression=b), operands 1231 ) 1232 return expression 1233 1234 1235def gen(expression: t.Any) -> str: 1236 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1237 1238 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1239 generator is expensive so we have a bare minimum sql generator here. 
1240 """ 1241 return Gen().gen(expression) 1242 1243 1244class Gen: 1245 def __init__(self): 1246 self.stack = [] 1247 self.sqls = [] 1248 1249 def gen(self, expression: exp.Expression) -> str: 1250 self.stack = [expression] 1251 self.sqls.clear() 1252 1253 while self.stack: 1254 node = self.stack.pop() 1255 1256 if isinstance(node, exp.Expression): 1257 exp_handler_name = f"{node.key}_sql" 1258 1259 if hasattr(self, exp_handler_name): 1260 getattr(self, exp_handler_name)(node) 1261 elif isinstance(node, exp.Func): 1262 self._function(node) 1263 else: 1264 key = node.key.upper() 1265 self.stack.append(f"{key} " if self._args(node) else key) 1266 elif type(node) is list: 1267 for n in reversed(node): 1268 if n is not None: 1269 self.stack.extend((n, ",")) 1270 if node: 1271 self.stack.pop() 1272 else: 1273 if node is not None: 1274 self.sqls.append(str(node)) 1275 1276 return "".join(self.sqls) 1277 1278 def add_sql(self, e: exp.Add) -> None: 1279 self._binary(e, " + ") 1280 1281 def alias_sql(self, e: exp.Alias) -> None: 1282 self.stack.extend( 1283 ( 1284 e.args.get("alias"), 1285 " AS ", 1286 e.args.get("this"), 1287 ) 1288 ) 1289 1290 def and_sql(self, e: exp.And) -> None: 1291 self._binary(e, " AND ") 1292 1293 def anonymous_sql(self, e: exp.Anonymous) -> None: 1294 this = e.this 1295 if isinstance(this, str): 1296 name = this.upper() 1297 elif isinstance(this, exp.Identifier): 1298 name = this.this 1299 name = f'"{name}"' if this.quoted else name.upper() 1300 else: 1301 raise ValueError( 1302 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
1303 ) 1304 1305 self.stack.extend( 1306 ( 1307 ")", 1308 e.expressions, 1309 "(", 1310 name, 1311 ) 1312 ) 1313 1314 def between_sql(self, e: exp.Between) -> None: 1315 self.stack.extend( 1316 ( 1317 e.args.get("high"), 1318 " AND ", 1319 e.args.get("low"), 1320 " BETWEEN ", 1321 e.this, 1322 ) 1323 ) 1324 1325 def boolean_sql(self, e: exp.Boolean) -> None: 1326 self.stack.append("TRUE" if e.this else "FALSE") 1327 1328 def bracket_sql(self, e: exp.Bracket) -> None: 1329 self.stack.extend( 1330 ( 1331 "]", 1332 e.expressions, 1333 "[", 1334 e.this, 1335 ) 1336 ) 1337 1338 def column_sql(self, e: exp.Column) -> None: 1339 for p in reversed(e.parts): 1340 self.stack.extend((p, ".")) 1341 self.stack.pop() 1342 1343 def datatype_sql(self, e: exp.DataType) -> None: 1344 self._args(e, 1) 1345 self.stack.append(f"{e.this.name} ") 1346 1347 def div_sql(self, e: exp.Div) -> None: 1348 self._binary(e, " / ") 1349 1350 def dot_sql(self, e: exp.Dot) -> None: 1351 self._binary(e, ".") 1352 1353 def eq_sql(self, e: exp.EQ) -> None: 1354 self._binary(e, " = ") 1355 1356 def from_sql(self, e: exp.From) -> None: 1357 self.stack.extend((e.this, "FROM ")) 1358 1359 def gt_sql(self, e: exp.GT) -> None: 1360 self._binary(e, " > ") 1361 1362 def gte_sql(self, e: exp.GTE) -> None: 1363 self._binary(e, " >= ") 1364 1365 def identifier_sql(self, e: exp.Identifier) -> None: 1366 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1367 1368 def ilike_sql(self, e: exp.ILike) -> None: 1369 self._binary(e, " ILIKE ") 1370 1371 def in_sql(self, e: exp.In) -> None: 1372 self.stack.append(")") 1373 self._args(e, 1) 1374 self.stack.extend( 1375 ( 1376 "(", 1377 " IN ", 1378 e.this, 1379 ) 1380 ) 1381 1382 def intdiv_sql(self, e: exp.IntDiv) -> None: 1383 self._binary(e, " DIV ") 1384 1385 def is_sql(self, e: exp.Is) -> None: 1386 self._binary(e, " IS ") 1387 1388 def like_sql(self, e: exp.Like) -> None: 1389 self._binary(e, " Like ") 1390 1391 def literal_sql(self, e: exp.Literal) -> None: 
1392 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1393 1394 def lt_sql(self, e: exp.LT) -> None: 1395 self._binary(e, " < ") 1396 1397 def lte_sql(self, e: exp.LTE) -> None: 1398 self._binary(e, " <= ") 1399 1400 def mod_sql(self, e: exp.Mod) -> None: 1401 self._binary(e, " % ") 1402 1403 def mul_sql(self, e: exp.Mul) -> None: 1404 self._binary(e, " * ") 1405 1406 def neg_sql(self, e: exp.Neg) -> None: 1407 self._unary(e, "-") 1408 1409 def neq_sql(self, e: exp.NEQ) -> None: 1410 self._binary(e, " <> ") 1411 1412 def not_sql(self, e: exp.Not) -> None: 1413 self._unary(e, "NOT ") 1414 1415 def null_sql(self, e: exp.Null) -> None: 1416 self.stack.append("NULL") 1417 1418 def or_sql(self, e: exp.Or) -> None: 1419 self._binary(e, " OR ") 1420 1421 def paren_sql(self, e: exp.Paren) -> None: 1422 self.stack.extend( 1423 ( 1424 ")", 1425 e.this, 1426 "(", 1427 ) 1428 ) 1429 1430 def sub_sql(self, e: exp.Sub) -> None: 1431 self._binary(e, " - ") 1432 1433 def subquery_sql(self, e: exp.Subquery) -> None: 1434 self._args(e, 2) 1435 alias = e.args.get("alias") 1436 if alias: 1437 self.stack.append(alias) 1438 self.stack.extend((")", e.this, "(")) 1439 1440 def table_sql(self, e: exp.Table) -> None: 1441 self._args(e, 4) 1442 alias = e.args.get("alias") 1443 if alias: 1444 self.stack.append(alias) 1445 for p in reversed(e.parts): 1446 self.stack.extend((p, ".")) 1447 self.stack.pop() 1448 1449 def tablealias_sql(self, e: exp.TableAlias) -> None: 1450 columns = e.columns 1451 1452 if columns: 1453 self.stack.extend((")", columns, "(")) 1454 1455 self.stack.extend((e.this, " AS ")) 1456 1457 def var_sql(self, e: exp.Var) -> None: 1458 self.stack.append(e.this) 1459 1460 def _binary(self, e: exp.Binary, op: str) -> None: 1461 self.stack.extend((e.expression, op, e.this)) 1462 1463 def _unary(self, e: exp.Unary, op: str) -> None: 1464 self.stack.extend((e.this, op)) 1465 1466 def _function(self, e: exp.Func) -> None: 1467 self.stack.extend( 1468 ( 1469 ")", 1470 
list(e.args.values()), 1471 "(", 1472 e.sql_name(), 1473 ) 1474 ) 1475 1476 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1477 kvs = [] 1478 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1479 1480 for k in arg_types or arg_types: 1481 v = node.args.get(k) 1482 1483 if v is not None: 1484 kvs.append([f":{k}", v]) 1485 if kvs: 1486 self.stack.append(kvs) 1487 return True 1488 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved through `Dialect.get_or_raise` and handed to the
            DATE_TRUNC simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged FINAL are returned untouched (and so are their children,
        # because the recursion below is never reached for them).
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a brand new node; give it the original
        # parent so that the parent-sensitive rules below can inspect it.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call splices the result back into the tree; for children
        # this is done by `exp.replace_children` using the returned node.
        if root:
            expression.replace(node)
        return node

    # NOTE(review): presumably `while_changing` re-applies `_simplify` until the
    # expression stops changing -- confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    Args:
        *exceptions: exception classes that should be swallowed.

    Returns:
        A decorator. The decorated function behaves like the original one, except
        that when one of `exceptions` is raised its first positional argument
        (the expression) is returned unchanged instead.
    """

    def decorator(func):
        # Preserve the name, docstring and signature metadata of the decorated
        # function; otherwise every decorated rule shows up as `wrapped` in
        # tracebacks, introspection and generated documentation.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN has to be
    spelled out in terms of those operators first. Any other expression is
    returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Must be read before the BETWEEN node is replaced by the new conjunction
    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    # Keep the conjunction grouped when it sits directly under a NOT
    if under_not:
        return exp.paren(conjunction, copy=False)
    return conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify NOT expressions.

    De Morgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Additionally folds NOT NULL to NULL, NOT <comparison> to the complementary
    comparison, NOT <always true> to FALSE, NOT FALSE to TRUE and removes
    double negation. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Not):
        this = expression.this
        if is_null(this):
            return exp.null()
        # NOT <comparison> -> complementary comparison over the same operands
        if this.__class__ in COMPLEMENT_COMPARISONS:
            return COMPLEMENT_COMPARISONS[this.__class__](
                this=this.this, expression=this.expression
            )
        if isinstance(this, exp.Paren):
            condition = this.unnest()
            # NOT (x AND y) -> (NOT x OR NOT y)
            if isinstance(condition, exp.And):
                return exp.paren(
                    exp.or_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            # NOT (x OR y) -> (NOT x AND NOT y)
            if isinstance(condition, exp.Or):
                return exp.paren(
                    exp.and_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if is_null(condition):
                return exp.null()
        if always_true(this):
            return exp.false()
        if is_false(this):
            return exp.true()
        if isinstance(this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return this.this
    return expression
De Morgan's Law: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop redundant grouping around nested connectors of the same kind.

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only unwrap when the inner node is the same connector as the outer one
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold AND / OR connectors whose operands are statically known.

    Non-connector expressions are returned unchanged. For connectors, operand
    pairs are combined through `_flat_simplify` using the rules below.
    """

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left
        if isinstance(expression, exp.And):
            # FALSE is checked before NULL: FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE is the identity of AND
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE is checked first: TRUE OR NULL -> TRUE
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE is the identity of OR
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its complement.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE

    Only the outermost connector of a chain is inspected (the root, or a node
    whose parent is of a different type); everything else is returned as is.
    """
    is_chain_head = isinstance(expression, exp.Connector) and (
        root or not expression.same_parent
    )
    if not is_chain_head:
        return expression

    collapsed = exp.false() if isinstance(expression, exp.And) else exp.true()

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(first, second) for first, second in operand_pairs):
        return collapsed

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

        C AND A AND B AND B -> A AND B AND C

    Operands are compared and deduplicated by the string produced by `gen`.
    Only the head of a connector chain is processed (the root, or a node whose
    parent is of a different type); other expressions are returned unchanged.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # Keyed by generated sql: duplicates collapse into a single entry
        deduped = {gen(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        # (`arr[i]` is the element preceding `sql`, since we enumerate `arr[1:]`)
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place on the operands of `expression`, which is
    then returned. Redundant operands are replaced by TRUE / FALSE literals
    rather than being removed here.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector we look for is the opposite of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # `<` is a proper-subset test: every operand of b also appears in a
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only `column = literal` equalities are used as a source of constants. The
    replacement happens in place and `expression` is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node in the defining equality, literal)
        constant_mapping = {}
        # Subtrees rooted at an IF node are pruned from the walk
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # leave the defining occurrence (`b` in `b = 5`) alone
                    and id(column) != column_id
                    # `col IS NULL` is not rewritten
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # `func` and `exceptions` are not defined here; they come from the enclosing
    # decorator scope. If the call raises one of `exceptions`, the input
    # expression is returned unchanged.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """Fold operations on literals.

    Non-connector binary expressions are folded pairwise through `_flat_simplify`,
    a negated numeric literal becomes a signed literal, and inverse date
    operations are folded against their interval. Returns `expression` itself
    when nothing applies.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # -(-5) -> 5, -(5) -> -5
            signed = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(signed)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove a Paren node when its position makes the grouping redundant.

    Returns the wrapped expression when the parentheses can be dropped, and
    `expression` itself otherwise (including when it is not a Paren at all).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # never unwrap a parenthesized SELECT or the operand of a subquery predicate
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        # ... and at least one of the following must hold:
        and (
            not isinstance(parent, (exp.Condition, exp.Binary))
            # ((x)) -> (x)
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # x + (y + z), x * (y * z)
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            # multiplication nested under + / -
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons that involve them.

    A COALESCE with a single argument, or whose first argument is a non-null
    constant, is replaced by that argument. A comparison between a COALESCE and
    a constant is expanded into an OR of two explicit NULL-check branches, after
    truncating the COALESCE arguments at its first constant.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # The COALESCE may be on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it (mutates the COALESCE in place)
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Applies to the expression types listed in `CONCATS`. Runs of adjacent string
    literals are merged into a single literal (joined with the separator for
    CONCAT_WS). If everything collapses into one string literal, that literal is
    returned; otherwise a new concat expression of the same flavour is built.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # When there is no `expressions` list, fall back to the flattened operands
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild a left-nested chain of DPipe nodes
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are visited in order: the first branch whose condition is
    always true yields its result, and branches whose condition is always false
    are removed; once no branch is left, the default (or NULL) is returned.
    A standalone IF is replaced by the matching branch. Anything else is
    returned unchanged.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # NOTE(review): `case.pop()` below runs while `expression.args["ifs"]` is
        # being iterated -- confirm that removing a branch cannot cause the
        # following branch to be skipped.
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    # IF nodes that are branches of a CASE are handled by the loop above
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH call into TRUE or FALSE when both the string and the
    prefix are string literals; any other expression is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    prefix = expression.expression
    if expression.this.is_string and prefix.is_string:
        has_prefix = expression.name.startswith(prefix.name)
        return exp.convert(has_prefix)

    return expression
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # `func` and `exceptions` are not defined here; they come from the enclosing
    # decorator scope. If the call raises one of `exceptions`, the input
    # expression is returned unchanged.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; when neither rule
    decides, the operands are ordered by their `gen` string. Swapping the
    operands also swaps the operator for its inverse when one is registered in
    `INVERSE_COMPARISONS`. Non-comparison expressions are returned unchanged.
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # Already in canonical order
    if (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const):
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or gen(lhs) > gen(rhs)
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return swapped_type(this=rhs, expression=lhs)
def remove_where_true(expression):
    """Drop conditions that are always true, modifying `expression` in place.

    WHERE clauses whose condition is always true are removed. Joins whose ON
    condition is always true, that have no USING / method and whose
    (side, kind) pair is listed in `JOINS`, are turned into CROSS joins.
    Returns None.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            # Remove the ON condition and rewrite the join as a CROSS join
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the python values `a` and `b`.

    Returns a boolean literal expression holding the outcome, or None when
    `expression` is not one of the supported comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, if possible.

    Args:
        value: a `datetime.datetime`, a `datetime.date` or an ISO-8601 string.

    Returns:
        The corresponding date (the time part of datetimes / datetime strings is
        dropped), or None if `value` cannot be interpreted as a date.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO-8601 string; TypeError: not a string at all
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a `datetime.datetime`, a `datetime.date` or an ISO-8601 string.

    Returns:
        The corresponding datetime (dates are converted to midnight of that
        day), or None if `value` cannot be interpreted as a datetime.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO-8601 string; TypeError: not a string at all
        return None
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the python date / datetime matching the data type `to`.

    Returns None when `value` is falsy, when `to` is not a temporal type or when
    the conversion fails.
    """
    if not value:
        return None
    # DATE is tested before the broader set of temporal types
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal to a python date / datetime.

    Handles `Cast` nodes and `TsOrDsToDate` nodes without a format, wrapped
    around either a literal or another such node (evaluated recursively).
    Returns None for anything else, or when the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Treated like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: evaluate the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """Build a cast of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type; otherwise the type is
    derived from the python value: DATETIME for datetimes, DATE for anything else.
    """
    has_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not has_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times the given `unit`.

    Supported units are year, quarter, month, week, day, hour, minute and second.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    supported_units = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in supported_units:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the given `unit`.

    Supported units are year, quarter, month, week and day. `dialect` is only
    consulted for weeks, where `dialect.WEEK_OFFSET` shifts the start of the week.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        days_since_week_start = d.weekday() - dialect.WEEK_OFFSET
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        first_month_of_quarter = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month_of_quarter, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Render an expression with the lightweight `Gen` generator.

    The resulting string is meant for sorting and deduplicating expressions during
    optimization, where invoking the full SQL generator would be too expensive.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo-SQL generator for quickly generating sortable and unique strings.
Sorting and deduplicating SQL is a necessary step for optimization. Calling the actual generator is expensive, so we have a bare-minimum SQL generator here.
class Gen:
    """
    Minimal stack-based pseudo SQL generator.

    Nodes are rendered iteratively: handlers push the pieces of a node onto `self.stack`
    in reverse output order (the stack is LIFO), and `gen` pops them one by one, expanding
    expressions and lists and appending everything else to `self.sqls`.
    """

    def __init__(self):
        # Pending work items: expressions, lists, strings or None
        self.stack: t.List[t.Any] = []
        # Output fragments, in final order
        self.sqls: t.List[str] = []

    def gen(self, expression: exp.Expression) -> str:
        """
        Render `expression` into a string.

        Expressions are dispatched to a `<key>_sql` handler when one exists, to `_function`
        when they are `exp.Func` instances, and otherwise rendered as their upper-cased key
        followed by their arguments. Lists are rendered comma-separated, skipping None
        entries. Any other non-None value is appended via `str()`.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False

                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True

                # Drop the separator that would otherwise precede the first item. This must
                # only happen when something was actually pushed: a list made up solely of
                # None values pushes nothing, and popping then would discard an unrelated
                # pending item (e.g. a closing parenthesis).
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render an anonymous function call; unquoted names are upper-cased.

        Raises:
            ValueError: if `e.this` is neither a str nor an `exp.Identifier`.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        self._dotted(e.parts)

    def datatype_sql(self, e: exp.DataType) -> None:
        # Arguments after the first one are rendered after the type name
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): the mixed-case operator differs from the other handlers; it is left
        # unchanged here so that the generated strings stay exactly as they were.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Output order: "(" this ")" alias, then the remaining arguments
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Output order: dotted parts, alias, then the remaining arguments
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self._dotted(e.parts)

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression` (in reverse, since the stack is LIFO)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this` (in reverse, since the stack is LIFO)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(arg values)` using the function's `sql_name()`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _dotted(self, parts: t.Sequence[t.Any]) -> None:
        """Push `parts` joined by dots; an empty sequence pushes nothing."""
        for p in reversed(parts):
            self.stack.extend((p, "."))

        # Remove the dot that would precede the first part. Guarded so that an empty
        # sequence does not pop an unrelated pending item.
        if parts:
            self.stack.pop()

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the non-None arguments of `node` as `:name,value` pairs.

        Args:
            node: the expression whose arguments are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression) -> str:
    """
    Render `expression` into a string.

    Work items are popped from `self.stack`: expressions are dispatched to a `<key>_sql`
    handler when one exists, to `_function` when they are `exp.Func` instances, and
    otherwise rendered as their upper-cased key followed by their arguments. Lists are
    rendered comma-separated, skipping None entries. Any other non-None value is
    appended to `self.sqls` via `str()`.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False

            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True

            # Drop the separator that would otherwise precede the first item. This must
            # only happen when something was actually pushed: a list made up solely of
            # None values pushes nothing, and popping then would discard an unrelated
            # pending item (e.g. a closing parenthesis).
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Push the pieces of an anonymous function call onto the stack.

    A plain string name is upper-cased; an identifier is wrapped in double quotes when it
    is quoted and upper-cased otherwise.

    Raises:
        ValueError: if `e.this` is neither a str nor an `exp.Identifier`.
    """
    this = e.this

    # Reject unsupported name types up front
    if not isinstance(this, (str, exp.Identifier)):
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    if isinstance(this, str):
        rendered = this.upper()
    else:
        raw = this.this
        if this.quoted:
            rendered = f'"{raw}"'
        else:
            rendered = raw.upper()

    # Pushed in reverse output order: NAME ( expressions )
    pieces = [")", e.expressions, "(", rendered]
    self.stack.extend(pieces)