sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25# Value ranges for byte-sized signed/unsigned integers 26TINYINT_MIN = -128 27TINYINT_MAX = 127 28UTINYINT_MIN = 0 29UTINYINT_MAX = 255 30 31 32class UnsupportedUnit(Exception): 33 pass 34 35 36def simplify( 37 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 38): 39 """ 40 Rewrite sqlglot AST to simplify expressions. 
41 42 Example: 43 >>> import sqlglot 44 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 45 >>> simplify(expression).sql() 46 'TRUE' 47 48 Args: 49 expression (sqlglot.Expression): expression to simplify 50 constant_propagation: whether the constant propagation rule should be used 51 52 Returns: 53 sqlglot.Expression: simplified expression 54 """ 55 56 dialect = Dialect.get_or_raise(dialect) 57 58 def _simplify(expression, root=True): 59 if expression.meta.get(FINAL): 60 return expression 61 62 # group by expressions cannot be simplified, for example 63 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 64 # the projection must exactly match the group by key 65 group = expression.args.get("group") 66 67 if group and hasattr(expression, "selects"): 68 groups = set(group.expressions) 69 group.meta[FINAL] = True 70 71 for e in expression.selects: 72 for node in e.walk(): 73 if node in groups: 74 e.meta[FINAL] = True 75 break 76 77 having = expression.args.get("having") 78 if having: 79 for node in having.walk(): 80 if node in groups: 81 having.meta[FINAL] = True 82 break 83 84 # Pre-order transformations 85 node = expression 86 node = rewrite_between(node) 87 node = uniq_sort(node, root) 88 node = absorb_and_eliminate(node, root) 89 node = simplify_concat(node) 90 node = simplify_conditionals(node) 91 92 if constant_propagation: 93 node = propagate_constants(node, root) 94 95 exp.replace_children(node, lambda e: _simplify(e, False)) 96 97 # Post-order transformations 98 node = simplify_not(node) 99 node = flatten(node) 100 node = simplify_connectors(node, root) 101 node = remove_complements(node, root) 102 node = simplify_coalesce(node) 103 node.parent = expression.parent 104 node = simplify_literals(node, root) 105 node = simplify_equality(node) 106 node = simplify_parens(node) 107 node = simplify_datetrunc(node, dialect) 108 node = sort_comparison(node) 109 node = simplify_startswith(node) 110 111 if root: 112 expression.replace(node) 113 return node 114 115 expression 
= while_changing(expression, _simplify) 116 remove_where_true(expression) 117 return expression 118 119 120def catch(*exceptions): 121 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 122 123 def decorator(func): 124 def wrapped(expression, *args, **kwargs): 125 try: 126 return func(expression, *args, **kwargs) 127 except exceptions: 128 return expression 129 130 return wrapped 131 132 return decorator 133 134 135def rewrite_between(expression: exp.Expression) -> exp.Expression: 136 """Rewrite x between y and z to x >= y AND x <= z. 137 138 This is done because comparison simplification is only done on lt/lte/gt/gte. 139 """ 140 if isinstance(expression, exp.Between): 141 negate = isinstance(expression.parent, exp.Not) 142 143 expression = exp.and_( 144 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 145 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 146 copy=False, 147 ) 148 149 if negate: 150 expression = exp.paren(expression, copy=False) 151 152 return expression 153 154 155COMPLEMENT_COMPARISONS = { 156 exp.LT: exp.GTE, 157 exp.GT: exp.LTE, 158 exp.LTE: exp.GT, 159 exp.GTE: exp.LT, 160 exp.EQ: exp.NEQ, 161 exp.NEQ: exp.EQ, 162} 163 164 165def simplify_not(expression): 166 """ 167 Demorgan's Law 168 NOT (x OR y) -> NOT x AND NOT y 169 NOT (x AND y) -> NOT x OR NOT y 170 """ 171 if isinstance(expression, exp.Not): 172 this = expression.this 173 if is_null(this): 174 return exp.null() 175 if this.__class__ in COMPLEMENT_COMPARISONS: 176 return COMPLEMENT_COMPARISONS[this.__class__]( 177 this=this.this, expression=this.expression 178 ) 179 if isinstance(this, exp.Paren): 180 condition = this.unnest() 181 if isinstance(condition, exp.And): 182 return exp.paren( 183 exp.or_( 184 exp.not_(condition.left, copy=False), 185 exp.not_(condition.right, copy=False), 186 copy=False, 187 ) 188 ) 189 if isinstance(condition, exp.Or): 190 return exp.paren( 191 exp.and_( 192 
exp.not_(condition.left, copy=False), 193 exp.not_(condition.right, copy=False), 194 copy=False, 195 ) 196 ) 197 if is_null(condition): 198 return exp.null() 199 if always_true(this): 200 return exp.false() 201 if is_false(this): 202 return exp.true() 203 if isinstance(this, exp.Not): 204 # double negation 205 # NOT NOT x -> x 206 return this.this 207 return expression 208 209 210def flatten(expression): 211 """ 212 A AND (B AND C) -> A AND B AND C 213 A OR (B OR C) -> A OR B OR C 214 """ 215 if isinstance(expression, exp.Connector): 216 for node in expression.args.values(): 217 child = node.unnest() 218 if isinstance(child, expression.__class__): 219 node.replace(child) 220 return expression 221 222 223def simplify_connectors(expression, root=True): 224 def _simplify_connectors(expression, left, right): 225 if left == right: 226 return left 227 if isinstance(expression, exp.And): 228 if is_false(left) or is_false(right): 229 return exp.false() 230 if is_null(left) or is_null(right): 231 return exp.null() 232 if always_true(left) and always_true(right): 233 return exp.true() 234 if always_true(left): 235 return right 236 if always_true(right): 237 return left 238 return _simplify_comparison(expression, left, right) 239 elif isinstance(expression, exp.Or): 240 if always_true(left) or always_true(right): 241 return exp.true() 242 if is_false(left) and is_false(right): 243 return exp.false() 244 if ( 245 (is_null(left) and is_null(right)) 246 or (is_null(left) and is_false(right)) 247 or (is_false(left) and is_null(right)) 248 ): 249 return exp.null() 250 if is_false(left): 251 return right 252 if is_false(right): 253 return left 254 return _simplify_comparison(expression, left, right, or_=True) 255 256 if isinstance(expression, exp.Connector): 257 return _flat_simplify(expression, _simplify_connectors, root) 258 return expression 259 260 261LT_LTE = (exp.LT, exp.LTE) 262GT_GTE = (exp.GT, exp.GTE) 263 264COMPARISONS = ( 265 *LT_LTE, 266 *GT_GTE, 267 exp.EQ, 268 
exp.NEQ, 269 exp.Is, 270) 271 272INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 273 exp.LT: exp.GT, 274 exp.GT: exp.LT, 275 exp.LTE: exp.GTE, 276 exp.GTE: exp.LTE, 277} 278 279NONDETERMINISTIC = (exp.Rand, exp.Randn) 280 281 282def _simplify_comparison(expression, left, right, or_=False): 283 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 284 ll, lr = left.args.values() 285 rl, rr = right.args.values() 286 287 largs = {ll, lr} 288 rargs = {rl, rr} 289 290 matching = largs & rargs 291 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 292 293 if matching and columns: 294 try: 295 l = first(largs - columns) 296 r = first(rargs - columns) 297 except StopIteration: 298 return expression 299 300 if l.is_number and r.is_number: 301 l = float(l.name) 302 r = float(r.name) 303 elif l.is_string and r.is_string: 304 l = l.name 305 r = r.name 306 else: 307 l = extract_date(l) 308 if not l: 309 return None 310 r = extract_date(r) 311 if not r: 312 return None 313 # python won't compare date and datetime, but many engines will upcast 314 l, r = cast_as_datetime(l), cast_as_datetime(r) 315 316 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 317 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 318 return left if (av > bv if or_ else av <= bv) else right 319 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 320 return left if (av < bv if or_ else av >= bv) else right 321 322 # we can't ever shortcut to true because the column could be null 323 if not or_: 324 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 325 if av <= bv: 326 return exp.false() 327 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 328 if av >= bv: 329 return exp.false() 330 elif isinstance(a, exp.EQ): 331 if isinstance(b, exp.LT): 332 return exp.false() if av >= bv else a 333 if isinstance(b, exp.LTE): 334 return exp.false() if av > bv else a 335 if isinstance(b, exp.GT): 336 return 
exp.false() if av <= bv else a 337 if isinstance(b, exp.GTE): 338 return exp.false() if av < bv else a 339 if isinstance(b, exp.NEQ): 340 return exp.false() if av == bv else a 341 return None 342 343 344def remove_complements(expression, root=True): 345 """ 346 Removing complements. 347 348 A AND NOT A -> FALSE 349 A OR NOT A -> TRUE 350 """ 351 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 352 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 353 354 for a, b in itertools.permutations(expression.flatten(), 2): 355 if is_complement(a, b): 356 return complement 357 return expression 358 359 360def uniq_sort(expression, root=True): 361 """ 362 Uniq and sort a connector. 363 364 C AND A AND B AND B -> A AND B AND C 365 """ 366 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 367 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 368 flattened = tuple(expression.flatten()) 369 deduped = {gen(e): e for e in flattened} 370 arr = tuple(deduped.items()) 371 372 # check if the operands are already sorted, if not sort them 373 # A AND C AND B -> A AND B AND C 374 for i, (sql, e) in enumerate(arr[1:]): 375 if sql < arr[i][0]: 376 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 377 break 378 else: 379 # we didn't have to sort but maybe we need to dedup 380 if len(deduped) < len(flattened): 381 expression = result_func(*deduped.values(), copy=False) 382 383 return expression 384 385 386def absorb_and_eliminate(expression, root=True): 387 """ 388 absorption: 389 A AND (A OR B) -> A 390 A OR (A AND B) -> A 391 A AND (NOT A OR B) -> A AND B 392 A OR (NOT A AND B) -> A OR B 393 elimination: 394 (A AND B) OR (A AND NOT B) -> A 395 (A OR B) AND (A OR NOT B) -> A 396 """ 397 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 398 kind = exp.Or if isinstance(expression, exp.And) else exp.And 399 400 for a, b in 
itertools.permutations(expression.flatten(), 2): 401 if isinstance(a, kind): 402 aa, ab = a.unnest_operands() 403 404 # absorb 405 if is_complement(b, aa): 406 aa.replace(exp.true() if kind == exp.And else exp.false()) 407 elif is_complement(b, ab): 408 ab.replace(exp.true() if kind == exp.And else exp.false()) 409 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 410 a.replace(exp.false() if kind == exp.And else exp.true()) 411 elif isinstance(b, kind): 412 # eliminate 413 rhs = b.unnest_operands() 414 ba, bb = rhs 415 416 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 417 a.replace(aa) 418 b.replace(aa) 419 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 420 a.replace(ab) 421 b.replace(ab) 422 423 return expression 424 425 426def propagate_constants(expression, root=True): 427 """ 428 Propagate constants for conjunctions in DNF: 429 430 SELECT * FROM t WHERE a = b AND b = 5 becomes 431 SELECT * FROM t WHERE a = 5 AND b = 5 432 433 Reference: https://www.sqlite.org/optoverview.html 434 """ 435 436 if ( 437 isinstance(expression, exp.And) 438 and (root or not expression.same_parent) 439 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 440 ): 441 constant_mapping = {} 442 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 443 if isinstance(expr, exp.EQ): 444 l, r = expr.left, expr.right 445 446 # TODO: create a helper that can be used to detect nested literal expressions such 447 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 448 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 449 constant_mapping[l] = (id(l), r) 450 451 if constant_mapping: 452 for column in find_all_in_scope(expression, exp.Column): 453 parent = column.parent 454 column_id, constant = constant_mapping.get(column) or (None, None) 455 if ( 456 column_id is not None 457 and id(column) != column_id 458 and not (isinstance(parent, exp.Is) 
and isinstance(parent.expression, exp.Null)) 459 ): 460 column.replace(constant.copy()) 461 462 return expression 463 464 465INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 466 exp.DateAdd: exp.Sub, 467 exp.DateSub: exp.Add, 468 exp.DatetimeAdd: exp.Sub, 469 exp.DatetimeSub: exp.Add, 470} 471 472INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 473 **INVERSE_DATE_OPS, 474 exp.Add: exp.Sub, 475 exp.Sub: exp.Add, 476} 477 478 479def _is_number(expression: exp.Expression) -> bool: 480 return expression.is_number 481 482 483def _is_interval(expression: exp.Expression) -> bool: 484 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 485 486 487@catch(ModuleNotFoundError, UnsupportedUnit) 488def simplify_equality(expression: exp.Expression) -> exp.Expression: 489 """ 490 Use the subtraction and addition properties of equality to simplify expressions: 491 492 x + 1 = 3 becomes x = 2 493 494 There are two binary operations in the above expression: + and = 495 Here's how we reference all the operands in the code below: 496 497 l r 498 x + 1 = 3 499 a b 500 """ 501 if isinstance(expression, COMPARISONS): 502 l, r = expression.left, expression.right 503 504 if l.__class__ not in INVERSE_OPS: 505 return expression 506 507 if r.is_number: 508 a_predicate = _is_number 509 b_predicate = _is_number 510 elif _is_date_literal(r): 511 a_predicate = _is_date_literal 512 b_predicate = _is_interval 513 else: 514 return expression 515 516 if l.__class__ in INVERSE_DATE_OPS: 517 l = t.cast(exp.IntervalOp, l) 518 a = l.this 519 b = l.interval() 520 else: 521 l = t.cast(exp.Binary, l) 522 a, b = l.left, l.right 523 524 if not a_predicate(a) and b_predicate(b): 525 pass 526 elif not a_predicate(b) and b_predicate(a): 527 a, b = b, a 528 else: 529 return expression 530 531 return expression.__class__( 532 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 533 ) 534 return expression 535 536 
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binaries, negations and date add/sub operations."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)


def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """Strip a cast around a byte-sized integer literal when the target is an integer type."""
    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
        this = _simplify_integer_cast(expr.this)
    else:
        this = expr.this

    if isinstance(expr, exp.Cast) and this.is_int:
        num = int(this.name)

        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
        # engine-dependent
        if (
            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
        ) or (
            UTINYINT_MIN <= num <= UTINYINT_MAX
            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
        ):
            return this

    return expr


def _simplify_binary(expression, a, b):
    """Evaluate `expression` over operands `a` and `b` when both are statically known.

    Returns the folded expression, or None when nothing can be simplified.
    """
    if isinstance(expression, COMPARISONS):
        a = _simplify_integer_cast(a)
        b = _simplify_integer_cast(b)

    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Drop a Paren node when one of the conditions below holds for it and its parent."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """Simplify COALESCE calls, including comparisons between a COALESCE and a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check, or None if not possible."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as the complement of the matching range.

    Returns None when `date` is not aligned to `unit`, i.e. no range can be derived.
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # A value differs from the truncated date when it lies outside [min, max), so the two
    # bounds are joined with OR; joining them with AND could never be satisfied. The
    # result is parenthesized so it stays grouped when placed under another connector.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        ),
        copy=False,
    )


DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Reorder comparison operands so columns come first and constants last."""
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            # EQ / NEQ have no entry in INVERSE_COMPARISONS and keep their own class
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Remove always-true WHERE clauses and turn always-true ON joins into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a TRUE boolean or for any literal."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def always_false(expression):
    """Return True for a FALSE boolean or for NULL."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` over python values `a` and `b`.

    Returns a boolean literal expression, or None for unsupported comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a date, returning None if an ISO-format string cannot be parsed."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime, returning None if an ISO-format string cannot be parsed."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to a date for DATE targets, or to a datetime for other temporal targets.

    Returns None for falsy values and for non-temporal target types.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the python date / datetime from a (possibly nested) cast of a literal.

    Returns None when `cast` is not a Cast or a format-less TsOrDsToDate, when its operand
    is neither a literal nor another such cast, or when the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an interval expression into the value returned by `interval`, or None on failure."""
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build a cast of `date` as a string literal to DATETIME (for datetimes) or DATE."""
    return exp.cast(
        exp.Literal.string(date),
        (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        ),
    )

1109 1110 1111def interval(unit: str, n: int = 1): 1112 from dateutil.relativedelta import relativedelta 1113 1114 if unit == "year": 1115 return relativedelta(years=1 * n) 1116 if unit == "quarter": 1117 return relativedelta(months=3 * n) 1118 if unit == "month": 1119 return relativedelta(months=1 * n) 1120 if unit == "week": 1121 return relativedelta(weeks=1 * n) 1122 if unit == "day": 1123 return relativedelta(days=1 * n) 1124 if unit == "hour": 1125 return relativedelta(hours=1 * n) 1126 if unit == "minute": 1127 return relativedelta(minutes=1 * n) 1128 if unit == "second": 1129 return relativedelta(seconds=1 * n) 1130 1131 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1132 1133 1134def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1135 if unit == "year": 1136 return d.replace(month=1, day=1) 1137 if unit == "quarter": 1138 if d.month <= 3: 1139 return d.replace(month=1, day=1) 1140 elif d.month <= 6: 1141 return d.replace(month=4, day=1) 1142 elif d.month <= 9: 1143 return d.replace(month=7, day=1) 1144 else: 1145 return d.replace(month=10, day=1) 1146 if unit == "month": 1147 return d.replace(month=d.month, day=1) 1148 if unit == "week": 1149 # Assuming week starts on Monday (0) and ends on Sunday (6) 1150 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1151 if unit == "day": 1152 return d 1153 1154 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1155 1156 1157def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1158 floor = date_floor(d, unit, dialect) 1159 1160 if floor == d: 1161 return d 1162 1163 return floor + interval(unit) 1164 1165 1166def boolean_literal(condition): 1167 return exp.true() if condition else exp.false() 1168 1169 1170def _flat_simplify(expression, simplifier, root=True): 1171 if root or not expression.same_parent: 1172 operands = [] 1173 queue = deque(expression.flatten(unnest=False)) 1174 size = len(queue) 1175 1176 while queue: 1177 a = 
queue.popleft() 1178 1179 for b in queue: 1180 result = simplifier(expression, a, b) 1181 1182 if result and result is not expression: 1183 queue.remove(b) 1184 queue.appendleft(result) 1185 break 1186 else: 1187 operands.append(a) 1188 1189 if len(operands) < size: 1190 return functools.reduce( 1191 lambda a, b: expression.__class__(this=a, expression=b), operands 1192 ) 1193 return expression 1194 1195 1196def gen(expression: t.Any) -> str: 1197 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1198 1199 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1200 generator is expensive so we have a bare minimum sql generator here. 1201 """ 1202 return Gen().gen(expression) 1203 1204 1205class Gen: 1206 def __init__(self): 1207 self.stack = [] 1208 self.sqls = [] 1209 1210 def gen(self, expression: exp.Expression) -> str: 1211 self.stack = [expression] 1212 self.sqls.clear() 1213 1214 while self.stack: 1215 node = self.stack.pop() 1216 1217 if isinstance(node, exp.Expression): 1218 exp_handler_name = f"{node.key}_sql" 1219 1220 if hasattr(self, exp_handler_name): 1221 getattr(self, exp_handler_name)(node) 1222 elif isinstance(node, exp.Func): 1223 self._function(node) 1224 else: 1225 key = node.key.upper() 1226 self.stack.append(f"{key} " if self._args(node) else key) 1227 elif type(node) is list: 1228 for n in reversed(node): 1229 if n is not None: 1230 self.stack.extend((n, ",")) 1231 if node: 1232 self.stack.pop() 1233 else: 1234 if node is not None: 1235 self.sqls.append(str(node)) 1236 1237 return "".join(self.sqls) 1238 1239 def add_sql(self, e: exp.Add) -> None: 1240 self._binary(e, " + ") 1241 1242 def alias_sql(self, e: exp.Alias) -> None: 1243 self.stack.extend( 1244 ( 1245 e.args.get("alias"), 1246 " AS ", 1247 e.args.get("this"), 1248 ) 1249 ) 1250 1251 def and_sql(self, e: exp.And) -> None: 1252 self._binary(e, " AND ") 1253 1254 def anonymous_sql(self, e: exp.Anonymous) -> None: 1255 this = 
e.this 1256 if isinstance(this, str): 1257 name = this.upper() 1258 elif isinstance(this, exp.Identifier): 1259 name = this.this 1260 name = f'"{name}"' if this.quoted else name.upper() 1261 else: 1262 raise ValueError( 1263 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1264 ) 1265 1266 self.stack.extend( 1267 ( 1268 ")", 1269 e.expressions, 1270 "(", 1271 name, 1272 ) 1273 ) 1274 1275 def between_sql(self, e: exp.Between) -> None: 1276 self.stack.extend( 1277 ( 1278 e.args.get("high"), 1279 " AND ", 1280 e.args.get("low"), 1281 " BETWEEN ", 1282 e.this, 1283 ) 1284 ) 1285 1286 def boolean_sql(self, e: exp.Boolean) -> None: 1287 self.stack.append("TRUE" if e.this else "FALSE") 1288 1289 def bracket_sql(self, e: exp.Bracket) -> None: 1290 self.stack.extend( 1291 ( 1292 "]", 1293 e.expressions, 1294 "[", 1295 e.this, 1296 ) 1297 ) 1298 1299 def column_sql(self, e: exp.Column) -> None: 1300 for p in reversed(e.parts): 1301 self.stack.extend((p, ".")) 1302 self.stack.pop() 1303 1304 def datatype_sql(self, e: exp.DataType) -> None: 1305 self._args(e, 1) 1306 self.stack.append(f"{e.this.name} ") 1307 1308 def div_sql(self, e: exp.Div) -> None: 1309 self._binary(e, " / ") 1310 1311 def dot_sql(self, e: exp.Dot) -> None: 1312 self._binary(e, ".") 1313 1314 def eq_sql(self, e: exp.EQ) -> None: 1315 self._binary(e, " = ") 1316 1317 def from_sql(self, e: exp.From) -> None: 1318 self.stack.extend((e.this, "FROM ")) 1319 1320 def gt_sql(self, e: exp.GT) -> None: 1321 self._binary(e, " > ") 1322 1323 def gte_sql(self, e: exp.GTE) -> None: 1324 self._binary(e, " >= ") 1325 1326 def identifier_sql(self, e: exp.Identifier) -> None: 1327 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1328 1329 def ilike_sql(self, e: exp.ILike) -> None: 1330 self._binary(e, " ILIKE ") 1331 1332 def in_sql(self, e: exp.In) -> None: 1333 self.stack.append(")") 1334 self._args(e, 1) 1335 self.stack.extend( 1336 ( 1337 "(", 1338 " IN ", 1339 e.this, 1340 ) 
1341 ) 1342 1343 def intdiv_sql(self, e: exp.IntDiv) -> None: 1344 self._binary(e, " DIV ") 1345 1346 def is_sql(self, e: exp.Is) -> None: 1347 self._binary(e, " IS ") 1348 1349 def like_sql(self, e: exp.Like) -> None: 1350 self._binary(e, " Like ") 1351 1352 def literal_sql(self, e: exp.Literal) -> None: 1353 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1354 1355 def lt_sql(self, e: exp.LT) -> None: 1356 self._binary(e, " < ") 1357 1358 def lte_sql(self, e: exp.LTE) -> None: 1359 self._binary(e, " <= ") 1360 1361 def mod_sql(self, e: exp.Mod) -> None: 1362 self._binary(e, " % ") 1363 1364 def mul_sql(self, e: exp.Mul) -> None: 1365 self._binary(e, " * ") 1366 1367 def neg_sql(self, e: exp.Neg) -> None: 1368 self._unary(e, "-") 1369 1370 def neq_sql(self, e: exp.NEQ) -> None: 1371 self._binary(e, " <> ") 1372 1373 def not_sql(self, e: exp.Not) -> None: 1374 self._unary(e, "NOT ") 1375 1376 def null_sql(self, e: exp.Null) -> None: 1377 self.stack.append("NULL") 1378 1379 def or_sql(self, e: exp.Or) -> None: 1380 self._binary(e, " OR ") 1381 1382 def paren_sql(self, e: exp.Paren) -> None: 1383 self.stack.extend( 1384 ( 1385 ")", 1386 e.this, 1387 "(", 1388 ) 1389 ) 1390 1391 def sub_sql(self, e: exp.Sub) -> None: 1392 self._binary(e, " - ") 1393 1394 def subquery_sql(self, e: exp.Subquery) -> None: 1395 self._args(e, 2) 1396 alias = e.args.get("alias") 1397 if alias: 1398 self.stack.append(alias) 1399 self.stack.extend((")", e.this, "(")) 1400 1401 def table_sql(self, e: exp.Table) -> None: 1402 self._args(e, 4) 1403 alias = e.args.get("alias") 1404 if alias: 1405 self.stack.append(alias) 1406 for p in reversed(e.parts): 1407 self.stack.extend((p, ".")) 1408 self.stack.pop() 1409 1410 def tablealias_sql(self, e: exp.TableAlias) -> None: 1411 columns = e.columns 1412 1413 if columns: 1414 self.stack.extend((")", columns, "(")) 1415 1416 self.stack.extend((e.this, " AS ")) 1417 1418 def var_sql(self, e: exp.Var) -> None: 1419 self.stack.append(e.this) 
1420 1421 def _binary(self, e: exp.Binary, op: str) -> None: 1422 self.stack.extend((e.expression, op, e.this)) 1423 1424 def _unary(self, e: exp.Unary, op: str) -> None: 1425 self.stack.extend((e.this, op)) 1426 1427 def _function(self, e: exp.Func) -> None: 1428 self.stack.extend( 1429 ( 1430 ")", 1431 list(e.args.values()), 1432 "(", 1433 e.sql_name(), 1434 ) 1435 ) 1436 1437 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1438 kvs = [] 1439 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1440 1441 for k in arg_types or arg_types: 1442 v = node.args.get(k) 1443 1444 if v is not None: 1445 kvs.append([f":{k}", v]) 1446 if kvs: 1447 self.stack.append(kvs) 1448 return True 1449 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect name or instance; it is resolved with `Dialect.get_or_raise`
            and passed to the DATE_TRUNC simplification

    Returns:
        sqlglot.Expression: simplified expression
    """

    # Resolve the dialect once so that every pass below uses the same object
    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING: freeze it if it references a group by key
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a new node; give it the original's parent
        # because several of the rules below inspect `parent`
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call splices the result back into the surrounding tree
        if root:
            expression.replace(node)
        return node

    # NOTE(review): `while_changing` presumably re-applies `_simplify` until the
    # expression stops changing - confirm against sqlglot.helper
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: the exception classes to swallow.

    Returns:
        A decorator. The decorated function returns its first argument
        (`expression`) unchanged when the wrapped function raises one of
        `exceptions`; any other exception propagates to the caller.
    """

    def decorator(func):
        # Keep the wrapped function's name, docstring and other metadata so that
        # decorated simplifiers stay identifiable in tracebacks and generated docs.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that skips a simplification function, returning the expression unchanged, if any of `exceptions` is raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    rewritten into that form first. When the BETWEEN sits directly under a NOT,
    the resulting conjunction is parenthesized.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if negated else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a negation.

    De Morgan's laws are applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT NULL, negated comparisons, negated boolean constants and double
    negations are folded as well.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        complement_type = COMPLEMENT_COMPARISONS[operand_type]
        return complement_type(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # Negate both operands and swap the connector
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop parentheses around an operand that uses the same connector as its parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold AND / OR connectors whose operands are equal or statically known.

    Non-connector expressions are returned unchanged; connectors are handed to
    `_flat_simplify` together with the pairwise reduction rule defined below.
    """

    def _reduce_pair(connector, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()

            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or is_false(right))) or (
                is_false(left) and right_null
            ):
                return exp.null()

            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    rebuild = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Key each operand by its pseudo SQL; a repeated key keeps its first position
    # and the last operand that produced it
    by_sql = {gen(operand): operand for operand in operands}
    keys = list(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        return rebuild(*(by_sql[key] for key in sorted(keys)), copy=False)

    # we didn't have to sort but maybe we need to dedup
    if len(by_sql) < len(operands):
        return rebuild(*by_sql.values(), copy=False)

    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    The operands are rewritten in place and the (same) connector is returned.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The laws apply to operands built with the *other* connector
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair is visited, so each operand plays both roles
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: replace it with the identity element of `kind`
                    # (TRUE for AND, FALSE for OR) so that only ab remains
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's: a is redundant, so it
                    # becomes the identity element of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and disagree only on the sign of the
                    # other one: both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B
elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    The columns are replaced in place and the (same) expression is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node that defines it, literal value)
        constant_mapping = {}
        # IF nodes are pruned from the walk
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # leave the defining `column = literal` occurrence alone
                    and id(column) != column_id
                    # `column IS NULL` is never rewritten
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5
becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Wrapper created by a decorator: `func` and `exceptions` are free variables
    # bound in the enclosing scope, which is not part of this fragment.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of `exceptions` was raised: fall back to the input expression
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """Fold operations whose operands are literals.

    Non-connector binary expressions are delegated to `_flat_simplify`, a negated
    number literal gets its sign flipped, and expressions whose type is listed in
    INVERSE_DATE_OPS are passed to `_simplify_binary` with their interval.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # -(-5) -> 5, -(5) -> -5
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove a pair of parentheses when the checks below deem it redundant.

    Returns the wrapped expression if the parentheses can go, otherwise the
    Paren itself. Anything that is not a Paren is returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # A parenthesized SELECT always keeps its parentheses. Otherwise they are
    # dropped as soon as any one of the following conditions holds.
    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary operation
        not isinstance(parent, (exp.Condition, exp.Binary))
        # nested parentheses
        or isinstance(parent, exp.Paren)
        # the content is not a binary operation (except NOT / IS under a predicate)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate whose parent is not a predicate
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # x + (y + z), x * (y * z)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # a multiplication inside an addition or subtraction
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    Two rewrites are performed:

    - A COALESCE with a single argument, or whose first argument is a non-null
      constant, is replaced by that first argument (unless it is a hint).
    - A comparison between a COALESCE and a constant is split on whether the
      arguments that precede the first constant argument are NULL, e.g.

          COALESCE(x, 1) = 2 -> (NOT x IS NULL AND x = 2) OR (x IS NULL AND 1 = 2)

      The operand order of the original comparison is preserved, so that
      `2 < COALESCE(x, 1)` falls back to `2 < 1` rather than `1 < 2`.

    Anything else is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments that precede the constant; this also updates the
    # COALESCE inside `expression`, which is copied below
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The comparison that applies when the COALESCE falls through to its constant.
    # The constant takes the side the COALESCE was on, which matters for <, <=, > and >=.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        separator = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        extra_args = {key: expression.args.get(key) for key in ("safe", "coalesce")}

    reduced = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_string_run, run in runs:
        if is_string_run:
            merged = separator.join(literal.name for literal in run)
            reduced.append(exp.Literal.string(merged))
        else:
            reduced.extend(run)

    # Everything collapsed into a single literal: the call itself is unnecessary
    if len(reduced) == 1 and reduced[0].is_string:
        return reduced[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[sep_expr] + reduced)
    return exp.Concat(expressions=reduced, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    Returns the selected branch when a condition is always true, drops CASE
    branches whose condition is always false, and otherwise returns `expression`.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # NOTE(review): `case.pop()` below presumably removes the branch from the very
        # list being iterated, so the branch that follows a removed one would only be
        # visited on a later call - confirm
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                # No branch is left: the CASE reduces to its default (or NULL)
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # IF nodes directly under a CASE are skipped here
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into TRUE or FALSE when both the string and the prefix
    are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not expression.this.is_string:
        return expression

    prefix = expression.expression
    if not prefix.is_string:
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def wrapped(expression, *args, **kwargs):
    # Wrapper created by a decorator: `func` and `exceptions` are free variables
    # bound in the enclosing scope, which is not part of this fragment.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of `exceptions` was raised: fall back to the input expression
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; when neither rule
    decides, the operands are ordered by their pseudo SQL. Swapping the operands
    also replaces the operator with its INVERSE_COMPARISONS entry, if it has one.
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if already_ordered:
        return expression

    must_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if must_swap:
        mirrored_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
        return mirrored_type(this=right, expression=left)

    return expression
def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn trivially true joins into CROSS joins.

    The expression is modified in place. A join is only converted when it has an
    always-true ON condition, no USING clause, no join method, and its
    (side, kind) pair is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison that `expression` represents on the Python values `a` and `b`.

    Returns a boolean literal for EQ / IS, NEQ, GT, GTE, LT and LTE, and None for
    any other kind of expression.
    """
    # Checked in order; the comparison itself only runs for the matching entry
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for comparison_types, compare in comparisons:
        if isinstance(expression, comparison_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Args:
        value: a `datetime.datetime` (its date part is used), a `datetime.date`
            (returned as is) or an ISO 8601 string.

    Returns:
        The date, or None if `value` is not an ISO 8601 string or is of an
        unsupported type.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: `value` is not a string; ValueError: it is not in ISO 8601 format
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Args:
        value: a `datetime.datetime` (returned as is), a `datetime.date`
            (converted to midnight of that day) or an ISO 8601 string.

    Returns:
        The datetime, or None if `value` is not an ISO 8601 string or is of an
        unsupported type.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: `value` is not a string; ValueError: it is not in ISO 8601 format
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast a Python value to the date or datetime that corresponds to the type `to`.

    Args:
        value: the value to cast; falsy values (None, "") yield None.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is one
        of the other temporal types, and None when `value` is falsy, cannot be
        converted, or `to` is not a temporal type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a (possibly nested) cast of a literal to a date or datetime.

    Args:
        cast: a Cast, or a TsOrDsToDate without a format, whose operand is a
            literal or another such cast.

    Returns:
        The resulting `datetime.date` / `datetime.datetime`, or None if `cast`
        does not have that shape or the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # TS_OR_DS_TO_DATE without a format is treated as a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Build a `relativedelta` that spans `n` times the given unit.

    Supported units are "year", "quarter", "month", "week", "day", "hour",
    "minute" and "second"; a quarter is expressed as three months.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, number of keyword units per unit)
    conversions = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in conversions:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the first day of the `unit` that contains it.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only read for "week"; its `WEEK_OFFSET` is subtracted from the
            weekday number. NOTE(review): this presumably shifts the first day of
            the week relative to Monday (e.g. -1 for Sunday) - confirm against the
            Dialect definition.

    Returns:
        The floored date; `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # date.weekday() numbers Monday as 0 and Sunday as 6. The modulo keeps the
        # distance to the start of the week within 0..6 days, so the result is
        # never after `d` and never a full week before it.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Render `expression` as a pseudo SQL string suitable for sorting and deduplication.

    Optimization rules need a cheap, stable key per expression. The real SQL
    generator is expensive to call, so a bare minimum generator is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo-SQL generator for quickly producing sortable and unique strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """Bare minimum, non-recursive pseudo SQL generator.

    Rendering is driven by an explicit LIFO work stack. Handlers push the pieces
    of a node in *reverse* order, so that popping them emits the text left to
    right. Stack items are either expressions (dispatched to a handler), lists
    (rendered comma separated, skipping ``None``), ``None`` (ignored), or any
    other value (emitted via ``str``).
    """

    def __init__(self):
        # Pending work items, consumed from the end.
        self.stack: t.List[t.Any] = []
        # Text fragments emitted so far; joined when the stack is drained.
        self.sqls: t.List[str] = []

    def gen(self, expression: exp.Expression) -> str:
        """Render ``expression`` and return the resulting pseudo SQL string.

        Expressions are dispatched to a ``<key>_sql`` method when one exists,
        to the generic function renderer when they are ``exp.Func`` instances,
        and are otherwise rendered as their upper-cased key followed by their
        set args.

        Args:
            expression: the root item to render.

        Returns:
            The concatenation of every fragment emitted while draining the stack.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if node is None:
                continue

            if isinstance(node, str):
                # Keywords, separators and names pushed by the handlers.
                self.sqls.append(str(node))
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would otherwise precede the first
                # element. Only do so when something was actually pushed: for
                # a list holding nothing but None, popping here would discard
                # an unrelated pending item (or fail on an empty stack).
                if pushed:
                    self.stack.pop()
            elif isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # _args pushes first, so the key ends up on top of the
                    # stack and is emitted before the args.
                    self.stack.append(f"{key} " if self._args(node) else key)
            else:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emitted as: this AS alias
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Push an anonymous function call, emitted as ``NAME(expressions)``.

        A str name is upper-cased. An Identifier name is wrapped in double
        quotes when quoted and upper-cased otherwise.

        Raises:
            ValueError: if ``e.this`` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Emitted as: this BETWEEN low AND high
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Emitted as: this[expressions]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Parts are emitted dot separated; the final pop removes the separator
        # that would otherwise precede the first part.
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # The first arg type is skipped by _args and rendered here through its
        # name, followed by a space.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emitted as: this IN (<args after the first arg type>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): the casing differs from the other operators. It is kept
        # as-is because this text is part of the string gen() returns.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        # Emitted as: (this)
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Emitted as: (this), then the alias if set, then the args that follow
        # the first two arg types.
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Emitted as: dot separated parts, then the alias if set, then the
        # args that follow the first four arg types.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Emitted as: " AS " this, followed by (columns) when there are any.
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push a binary node, emitted as ``this`` + ``op`` + ``expression``."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push a unary node, emitted as ``op`` + ``this``."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push a function node, emitted as ``sql_name(arg values)``."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of ``node`` as ``:name`` / value pairs.

        Args:
            node: the node whose ``arg_types`` and ``args`` are read.
            arg_index: number of leading ``arg_types`` entries to skip.

        Returns:
            True if at least one arg had a value other than None and was
            pushed, False otherwise (nothing is pushed in that case).
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression) -> str:
    """Render ``expression`` and return the resulting pseudo SQL string.

    The work stack is consumed from the end. Expressions are dispatched to a
    ``<key>_sql`` method when one exists, to ``_function`` when they are
    ``exp.Func`` instances, and are otherwise rendered as their upper-cased
    key followed by their set args. Lists are rendered comma separated,
    skipping ``None`` elements; any other non-None item is emitted via ``str``.

    Args:
        expression: the root item to render.

    Returns:
        The concatenation of every fragment emitted while draining the stack.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if node is None:
            continue

        if isinstance(node, str):
            # Keywords, separators and names pushed by the handlers.
            self.sqls.append(str(node))
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the separator that would otherwise precede the first
            # element. Only do so when something was actually pushed: for a
            # list holding nothing but None, popping here would discard an
            # unrelated pending item (or fail on an empty stack).
            if pushed:
                self.stack.pop()
        elif isinstance(node, exp.Expression):
            handler = getattr(self, f"{node.key}_sql", None)

            if handler is not None:
                handler(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                # _args pushes first, so the key ends up on top of the stack
                # and is emitted before the args.
                self.stack.append(f"{key} " if self._args(node) else key)
        else:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Push an anonymous function call, emitted as ``NAME(expressions)``.

    A str name is upper-cased. An Identifier name is wrapped in double quotes
    when it is quoted and upper-cased otherwise.

    Raises:
        ValueError: if ``e.this`` is neither a str nor an Identifier.
    """
    this = e.this

    if isinstance(this, str):
        func_name = this.upper()
    elif isinstance(this, exp.Identifier):
        raw_name = this.this
        if this.quoted:
            func_name = f'"{raw_name}"'
        else:
            func_name = raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    # Items are pushed in reverse of the order in which they should be emitted.
    for item in (")", e.expressions, "(", func_name):
        self.stack.append(item)