# sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, is_iterable, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25 26class UnsupportedUnit(Exception): 27 pass 28 29 30def simplify( 31 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 32): 33 """ 34 Rewrite sqlglot AST to simplify expressions. 35 36 Example: 37 >>> import sqlglot 38 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 39 >>> simplify(expression).sql() 40 'TRUE' 41 42 Args: 43 expression (sqlglot.Expression): expression to simplify 44 constant_propagation: whether or not the constant propagation rule should be used 45 46 Returns: 47 sqlglot.Expression: simplified expression 48 """ 49 50 dialect = Dialect.get_or_raise(dialect) 51 52 def _simplify(expression, root=True): 53 if expression.meta.get(FINAL): 54 return expression 55 56 # group by expressions cannot be simplified, for example 57 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 58 # the projection must exactly match the group by key 59 group = expression.args.get("group") 60 61 if group and hasattr(expression, "selects"): 62 groups = set(group.expressions) 63 group.meta[FINAL] = True 64 65 for e in expression.selects: 66 for node, *_ in e.walk(): 67 if node in groups: 68 e.meta[FINAL] = True 69 break 70 71 having = expression.args.get("having") 72 if having: 73 for node, *_ in having.walk(): 74 if node in groups: 75 having.meta[FINAL] = True 76 break 77 78 # 
Pre-order transformations 79 node = expression 80 node = rewrite_between(node) 81 node = uniq_sort(node, root) 82 node = absorb_and_eliminate(node, root) 83 node = simplify_concat(node) 84 node = simplify_conditionals(node) 85 86 if constant_propagation: 87 node = propagate_constants(node, root) 88 89 exp.replace_children(node, lambda e: _simplify(e, False)) 90 91 # Post-order transformations 92 node = simplify_not(node) 93 node = flatten(node) 94 node = simplify_connectors(node, root) 95 node = remove_complements(node, root) 96 node = simplify_coalesce(node) 97 node.parent = expression.parent 98 node = simplify_literals(node, root) 99 node = simplify_equality(node) 100 node = simplify_parens(node) 101 node = simplify_datetrunc(node, dialect) 102 node = sort_comparison(node) 103 node = simplify_startswith(node) 104 105 if root: 106 expression.replace(node) 107 108 return node 109 110 expression = while_changing(expression, _simplify) 111 remove_where_true(expression) 112 return expression 113 114 115def catch(*exceptions): 116 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 117 118 def decorator(func): 119 def wrapped(expression, *args, **kwargs): 120 try: 121 return func(expression, *args, **kwargs) 122 except exceptions: 123 return expression 124 125 return wrapped 126 127 return decorator 128 129 130def rewrite_between(expression: exp.Expression) -> exp.Expression: 131 """Rewrite x between y and z to x >= y AND x <= z. 132 133 This is done because comparison simplification is only done on lt/lte/gt/gte. 
134 """ 135 if isinstance(expression, exp.Between): 136 negate = isinstance(expression.parent, exp.Not) 137 138 expression = exp.and_( 139 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 140 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 141 copy=False, 142 ) 143 144 if negate: 145 expression = exp.paren(expression, copy=False) 146 147 return expression 148 149 150COMPLEMENT_COMPARISONS = { 151 exp.LT: exp.GTE, 152 exp.GT: exp.LTE, 153 exp.LTE: exp.GT, 154 exp.GTE: exp.LT, 155 exp.EQ: exp.NEQ, 156 exp.NEQ: exp.EQ, 157} 158 159 160def simplify_not(expression): 161 """ 162 Demorgan's Law 163 NOT (x OR y) -> NOT x AND NOT y 164 NOT (x AND y) -> NOT x OR NOT y 165 """ 166 if isinstance(expression, exp.Not): 167 this = expression.this 168 if is_null(this): 169 return exp.null() 170 if this.__class__ in COMPLEMENT_COMPARISONS: 171 return COMPLEMENT_COMPARISONS[this.__class__]( 172 this=this.this, expression=this.expression 173 ) 174 if isinstance(this, exp.Paren): 175 condition = this.unnest() 176 if isinstance(condition, exp.And): 177 return exp.or_( 178 exp.not_(condition.left, copy=False), 179 exp.not_(condition.right, copy=False), 180 copy=False, 181 ) 182 if isinstance(condition, exp.Or): 183 return exp.and_( 184 exp.not_(condition.left, copy=False), 185 exp.not_(condition.right, copy=False), 186 copy=False, 187 ) 188 if is_null(condition): 189 return exp.null() 190 if always_true(this): 191 return exp.false() 192 if is_false(this): 193 return exp.true() 194 if isinstance(this, exp.Not): 195 # double negation 196 # NOT NOT x -> x 197 return this.this 198 return expression 199 200 201def flatten(expression): 202 """ 203 A AND (B AND C) -> A AND B AND C 204 A OR (B OR C) -> A OR B OR C 205 """ 206 if isinstance(expression, exp.Connector): 207 for node in expression.args.values(): 208 child = node.unnest() 209 if isinstance(child, expression.__class__): 210 node.replace(child) 211 return expression 212 213 214def 
simplify_connectors(expression, root=True): 215 def _simplify_connectors(expression, left, right): 216 if left == right: 217 return left 218 if isinstance(expression, exp.And): 219 if is_false(left) or is_false(right): 220 return exp.false() 221 if is_null(left) or is_null(right): 222 return exp.null() 223 if always_true(left) and always_true(right): 224 return exp.true() 225 if always_true(left): 226 return right 227 if always_true(right): 228 return left 229 return _simplify_comparison(expression, left, right) 230 elif isinstance(expression, exp.Or): 231 if always_true(left) or always_true(right): 232 return exp.true() 233 if is_false(left) and is_false(right): 234 return exp.false() 235 if ( 236 (is_null(left) and is_null(right)) 237 or (is_null(left) and is_false(right)) 238 or (is_false(left) and is_null(right)) 239 ): 240 return exp.null() 241 if is_false(left): 242 return right 243 if is_false(right): 244 return left 245 return _simplify_comparison(expression, left, right, or_=True) 246 247 if isinstance(expression, exp.Connector): 248 return _flat_simplify(expression, _simplify_connectors, root) 249 return expression 250 251 252LT_LTE = (exp.LT, exp.LTE) 253GT_GTE = (exp.GT, exp.GTE) 254 255COMPARISONS = ( 256 *LT_LTE, 257 *GT_GTE, 258 exp.EQ, 259 exp.NEQ, 260 exp.Is, 261) 262 263INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 264 exp.LT: exp.GT, 265 exp.GT: exp.LT, 266 exp.LTE: exp.GTE, 267 exp.GTE: exp.LTE, 268} 269 270NONDETERMINISTIC = (exp.Rand, exp.Randn) 271 272 273def _simplify_comparison(expression, left, right, or_=False): 274 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 275 ll, lr = left.args.values() 276 rl, rr = right.args.values() 277 278 largs = {ll, lr} 279 rargs = {rl, rr} 280 281 matching = largs & rargs 282 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 283 284 if matching and columns: 285 try: 286 l = first(largs - columns) 287 r = 
first(rargs - columns) 288 except StopIteration: 289 return expression 290 291 if l.is_number and r.is_number: 292 l = float(l.name) 293 r = float(r.name) 294 elif l.is_string and r.is_string: 295 l = l.name 296 r = r.name 297 else: 298 l = extract_date(l) 299 if not l: 300 return None 301 r = extract_date(r) 302 if not r: 303 return None 304 305 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 306 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 307 return left if (av > bv if or_ else av <= bv) else right 308 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 309 return left if (av < bv if or_ else av >= bv) else right 310 311 # we can't ever shortcut to true because the column could be null 312 if not or_: 313 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 314 if av <= bv: 315 return exp.false() 316 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 317 if av >= bv: 318 return exp.false() 319 elif isinstance(a, exp.EQ): 320 if isinstance(b, exp.LT): 321 return exp.false() if av >= bv else a 322 if isinstance(b, exp.LTE): 323 return exp.false() if av > bv else a 324 if isinstance(b, exp.GT): 325 return exp.false() if av <= bv else a 326 if isinstance(b, exp.GTE): 327 return exp.false() if av < bv else a 328 if isinstance(b, exp.NEQ): 329 return exp.false() if av == bv else a 330 return None 331 332 333def remove_complements(expression, root=True): 334 """ 335 Removing complements. 336 337 A AND NOT A -> FALSE 338 A OR NOT A -> TRUE 339 """ 340 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 341 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 342 343 for a, b in itertools.permutations(expression.flatten(), 2): 344 if is_complement(a, b): 345 return complement 346 return expression 347 348 349def uniq_sort(expression, root=True): 350 """ 351 Uniq and sort a connector. 
352 353 C AND A AND B AND B -> A AND B AND C 354 """ 355 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 356 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 357 flattened = tuple(expression.flatten()) 358 deduped = {gen(e): e for e in flattened} 359 arr = tuple(deduped.items()) 360 361 # check if the operands are already sorted, if not sort them 362 # A AND C AND B -> A AND B AND C 363 for i, (sql, e) in enumerate(arr[1:]): 364 if sql < arr[i][0]: 365 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 366 break 367 else: 368 # we didn't have to sort but maybe we need to dedup 369 if len(deduped) < len(flattened): 370 expression = result_func(*deduped.values(), copy=False) 371 372 return expression 373 374 375def absorb_and_eliminate(expression, root=True): 376 """ 377 absorption: 378 A AND (A OR B) -> A 379 A OR (A AND B) -> A 380 A AND (NOT A OR B) -> A AND B 381 A OR (NOT A AND B) -> A OR B 382 elimination: 383 (A AND B) OR (A AND NOT B) -> A 384 (A OR B) AND (A OR NOT B) -> A 385 """ 386 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 387 kind = exp.Or if isinstance(expression, exp.And) else exp.And 388 389 for a, b in itertools.permutations(expression.flatten(), 2): 390 if isinstance(a, kind): 391 aa, ab = a.unnest_operands() 392 393 # absorb 394 if is_complement(b, aa): 395 aa.replace(exp.true() if kind == exp.And else exp.false()) 396 elif is_complement(b, ab): 397 ab.replace(exp.true() if kind == exp.And else exp.false()) 398 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 399 a.replace(exp.false() if kind == exp.And else exp.true()) 400 elif isinstance(b, kind): 401 # eliminate 402 rhs = b.unnest_operands() 403 ba, bb = rhs 404 405 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 406 a.replace(aa) 407 b.replace(aa) 408 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 409 a.replace(ab) 
410 b.replace(ab) 411 412 return expression 413 414 415def propagate_constants(expression, root=True): 416 """ 417 Propagate constants for conjunctions in DNF: 418 419 SELECT * FROM t WHERE a = b AND b = 5 becomes 420 SELECT * FROM t WHERE a = 5 AND b = 5 421 422 Reference: https://www.sqlite.org/optoverview.html 423 """ 424 425 if ( 426 isinstance(expression, exp.And) 427 and (root or not expression.same_parent) 428 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 429 ): 430 constant_mapping = {} 431 for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): 432 if isinstance(expr, exp.EQ): 433 l, r = expr.left, expr.right 434 435 # TODO: create a helper that can be used to detect nested literal expressions such 436 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 437 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 438 constant_mapping[l] = (id(l), r) 439 440 if constant_mapping: 441 for column in find_all_in_scope(expression, exp.Column): 442 parent = column.parent 443 column_id, constant = constant_mapping.get(column) or (None, None) 444 if ( 445 column_id is not None 446 and id(column) != column_id 447 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 448 ): 449 column.replace(constant.copy()) 450 451 return expression 452 453 454INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 455 exp.DateAdd: exp.Sub, 456 exp.DateSub: exp.Add, 457 exp.DatetimeAdd: exp.Sub, 458 exp.DatetimeSub: exp.Add, 459} 460 461INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 462 **INVERSE_DATE_OPS, 463 exp.Add: exp.Sub, 464 exp.Sub: exp.Add, 465} 466 467 468def _is_number(expression: exp.Expression) -> bool: 469 return expression.is_number 470 471 472def _is_interval(expression: exp.Expression) -> bool: 473 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 474 475 
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # Only plain addition is commutative: 1 + x = 3 -> x = 3 - 1,
            # whereas 1 - x = 3 must NOT become x = 3 + 1
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations on literals, negated numbers and date +/- interval literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate `a <op> b` where `op` is the type of `expression`, or return None."""
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Division by zero is engine specific (error, NULL, ...) and would raise here
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            # Python cannot order a date against a datetime and reports them as unequal,
            # while SQL engines usually coerce; leave mixed comparisons untouched
            if type(a) is type(b):
                boolean = eval_boolean(expression, a, b)
                if boolean:
                    return boolean

    return None


def simplify_parens(expression):
    """Remove parentheses that do not change how the expression parses."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            # (NOT a) IS NULL must not become NOT a IS NULL, which means NOT (a IS NULL)
            and not (isinstance(this, exp.Not) and isinstance(parent, exp.Predicate))
        )
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """Remove trivial COALESCE calls and expand comparisons of COALESCE with a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_is_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_is_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Compare the fallback constant with `other` keeping the original operand order,
    # so 5 < COALESCE(x, 1) falls back to 5 < 1 (not 1 < 5)
    if coalesce_is_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and (not expression.expressions or not expression.expressions[0].is_string)
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Only the binary || operator has to be flattened; function calls already list their
    # operands (and CONCAT_WS(sep) with no values must not collapse into its separator)
    operands = expression.flatten() if isinstance(expression, exp.DPipe) else expressions

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot because branches may be removed from the node below
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                ifs = expression.args["ifs"]
                position = next(i for i, branch in enumerate(ifs) if branch is case)
                if position == 0:
                    return case.args["true"]
                # An earlier WHEN could not be decided statically and may still match first,
                # so this branch becomes the ELSE and every later branch is unreachable
                expression.set("ifs", ifs[:position])
                expression.set("default", case.args["true"])
                return expression

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) = date  ->  left in [floor, floor + unit)."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) <> date  ->  left outside [floor, floor + unit)."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # The value lies outside the range when it is below the start OR at/after the end;
    # the parentheses keep the OR intact when this replaces an operand of an AND
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        ),
        copy=False,
    )


DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l < date_literal(date_ceil(dt, u, d)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            # merge_ranges orders the ranges, and Python cannot order dates against datetimes
            if len({type(start) for start, _ in ranges}) > 1:
                return expression

            ranges = merge_ranges(ranges)

            # Compare the truncated operand directly (as the binary branch does) and keep
            # the OR parenthesized in case it replaces an operand of an AND
            return exp.paren(
                exp.or_(*[_datetrunc_eq_expression(l.this, drange) for drange in ranges], copy=False),
                copy=False,
            )

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Normalize comparisons so columns come first and constants last."""
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x, and symmetrically
# RIGHT JOIN x ON TRUE != CROSS JOIN x: a RIGHT JOIN keeps the rows of x (padded with NULLs)
# even when the left side is empty, while the CROSS JOIN would return no rows.
JOINS = {
    ("", ""),
    ("", "INNER"),
}


def remove_where_true(expression):
    """Drop WHERE TRUE clauses and turn inner joins ON TRUE into CROSS joins (in place)."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero(expression) -> bool:
    """Return True if `expression` is a numeric literal equal to zero (e.g. 0, 0.0)."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:
        # decimal.InvalidOperation: not a plain decimal number, so don't treat it as zero
        return False


def always_true(expression):
    """Truthy if `expression` is TRUE or any literal other than a numeric zero."""
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def always_false(expression):
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate comparison `expression` on Python values a and b; None if not a comparison."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to a date (DATE) or datetime (other temporal types), else None."""
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Return the Python date/datetime of a CAST / TS_OR_DS_TO_DATE of a literal, else None."""
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    return extract_date(expression) is not None


def extract_interval(expression):
    """Return a relativedelta for `INTERVAL 'n' unit`, or None if it can't be determined."""
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """Return a dateutil relativedelta of `n` `unit`s; raises UnsupportedUnit otherwise."""
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` to the start of its `unit` (year, quarter, month, week, day).

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        # Truncating to a calendar unit of a day or larger also drops the time of day
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Weeks start on Monday (weekday 0) shifted by the dialect's WEEK_OFFSET (e.g. -1 for
        # Sunday); the modulo keeps the result within the 7 days ending at `d`
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Return `d` if it is aligned to `unit`, else the start of the next `unit`."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Repeatedly combine pairs of operands of a flattened binary chain with `simplifier`.

    `simplifier(expression, a, b)` returns a replacement for the pair, or a falsy value /
    `expression` itself when the pair cannot be combined.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(gen(e) for e in expression)
    if not isinstance(expression, exp.Expression):
        return str(expression)

    etype = type(expression)
    if etype in GEN_MAP:
        return GEN_MAP[etype](expression)
    return f"{expression.key} {gen(expression.args.values())}"


GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(arg) for arg in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}


def _binary(e: exp.Binary, op: str) -> str:
    return f"{gen(e.left)} {op} {gen(e.right)}"


def _unary(e: exp.Unary, op: str) -> str:
    return f"{op} {gen(e.this)}"
Raised by the date helpers in this module (such as `interval` and `date_floor`) when they are given a date/time unit they do not support.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used
        dialect: dialect resolved through `Dialect.get_or_raise` and handed to the
            DATE_TRUNC simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """
    dialect = Dialect.get_or_raise(dialect)

    def _freeze_if_grouped(target, group_keys):
        # Mark `target` FINAL as soon as any node inside it is one of the GROUP BY keys.
        if any(node in group_keys for node, *_ in target.walk()):
            target.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # GROUP BY keys cannot be simplified: in `SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1`
        # the projection must keep matching the grouping key exactly, so the GROUP BY and any
        # projection / HAVING clause that references one of its keys is frozen.
        group = expression.args.get("group")
        if group and hasattr(expression, "selects"):
            group_keys = set(group.expressions)
            group.meta[FINAL] = True

            for projection in expression.selects:
                _freeze_if_grouped(projection, group_keys)

            having = expression.args.get("having")
            if having:
                _freeze_if_grouped(having, group_keys)

        # Pre-order rules: run before the children are visited.
        node = rewrite_between(expression)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)
        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order rules: run once the children have been simplified.
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Rules below (e.g. simplify_parens) read `node.parent`, so reattach it first.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
- dialect: the SQL dialect (defaults to None) passed to the DATE_TRUNC simplification rule
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called normally. If it raises one of `exceptions`, its first
    positional argument (`expression`) is returned unchanged instead.

    Args:
        *exceptions: exception classes that should be swallowed.

    Returns:
        A decorator wrapping a simplification function.
    """

    def decorator(func):
        # functools.wraps preserves the rule's __name__, __doc__ and __wrapped__, so
        # tracebacks, debuggers and documentation tools see the real function rather
        # than this anonymous wrapper.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised, returning the input expression unchanged instead.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` as `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded
    into the equivalent pair of range predicates. Non-BETWEEN nodes are returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    subject = expression.this
    lower = expression.args["low"]
    upper = expression.args["high"]

    conjunction = exp.and_(
        exp.GTE(this=subject.copy(), expression=lower),
        exp.LTE(this=subject.copy(), expression=upper),
        copy=False,
    )

    # When negated, parenthesize so the NOT covers the whole conjunction.
    return exp.paren(conjunction, copy=False) if under_not else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT node, applying De Morgan's laws:

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over NULL, complementable comparisons, constant booleans and a
    double negation. Any other node is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    # NOT a < b -> a >= b, and so on for every comparison with a known complement.
    complement = COMPLEMENT_COMPARISONS.get(operand.__class__)
    if complement is not None:
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return dual(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws:

NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Lift a nested connector of the same kind into its parent:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold AND/OR operands that are duplicates, constant booleans or NULL.

    Operand pairs are combined by `_flat_simplify`; pairs that are not decided by
    constants are handed to `_simplify_comparison`. Non-connectors are returned as is.
    """

    def _merge_pair(connector, left, right):
        if left == right:
            return left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_false = is_false(left)
            right_false = is_false(right)
            if left_false and right_false:
                return exp.false()
            # Any remaining mix of NULL and FALSE operands yields NULL.
            if (left_false or is_null(left)) and (right_false or is_null(right)):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _merge_pair, root)
def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = expression.flatten()
    if any(is_complement(a, b) for a, b in itertools.permutations(operands, 2)):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are compared by their `gen` pseudo-SQL text. The connector is only rebuilt
    when it is out of order or contains duplicates; otherwise it is returned untouched.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Duplicates (same pseudo-SQL text) collapse onto a single dict entry.
    by_sql = {gen(operand): operand for operand in operands}
    keys = list(by_sql)

    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        # Already in order, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites happen in place by replacing sub-nodes of `expression`, which is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # The operands of interest are connectors of the opposite kind.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = inner_kind == exp.And

    for candidate, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(candidate, inner_kind):
            continue

        first, second = candidate.unnest_operands()

        # absorb
        if is_complement(other, first):
            first.replace(exp.true() if inner_is_and else exp.false())
        elif is_complement(other, second):
            second.replace(exp.true() if inner_is_and else exp.false())
        elif (set(other.flatten()) if isinstance(other, inner_kind) else {other}) < set(
            candidate.flatten()
        ):
            candidate.replace(exp.false() if inner_is_and else exp.true())
        elif isinstance(other, inner_kind):
            # eliminate
            other_operands = other.unnest_operands()
            other_first, other_second = other_operands

            if first in other_operands and (
                is_complement(second, other_first) or is_complement(second, other_second)
            ):
                candidate.replace(first)
                other.replace(first)
            elif second in other_operands and (
                is_complement(first, other_first) or is_complement(first, other_second)
            ):
                candidate.replace(second)
                other.replace(second)

    return expression
absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B
elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    eligible = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not eligible:
        return expression

    # column -> (id of the column node in the defining `col = literal`, that literal)
    constants = {}
    for node, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        lhs, rhs = node.left, node.right
        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constants[lhs] = (id(lhs), rhs)

    if not constants:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        entry = constants.get(column)
        if entry is None:
            continue
        defining_id, literal = entry
        if id(column) == defining_id:
            # Keep the defining equality itself intact.
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            # `col IS NULL` keeps referring to the column.
            continue
        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5
becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
120 def wrapped(expression, *args, **kwargs): 121 try: 122 return func(expression, *args, **kwargs) 123 except exceptions: 124 return expression  # closure built by `catch`: calls `func` and returns `expression` unchanged if one of `exceptions` is raised
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
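There are two binary operations in the above expression: + and =.
Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b

Here l is the left side of the equality (x + 1), r is its right side (3), and a and b are the operands of the addition (x and 1).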
def simplify_literals(expression, root=True):
    """Fold literal arithmetic/comparisons, numeric negation and inverse date operations.

    - Non-connector binary nodes are folded pairwise via `_flat_simplify`/`_simplify_binary`.
    - `-<number>` becomes a single number literal, and a double minus cancels out.
    - Node types listed in INVERSE_DATE_OPS are folded with `_simplify_binary` when possible.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Return the wrapped node of a Paren when the parentheses are not needed.

    Parenthesized SELECTs are always kept. Otherwise the wrapper is dropped when the
    parent/child combination below shows the grouping does not change meaning.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons against them.

    - COALESCE(x) (or COALESCE whose first argument is a non-null constant) -> x, except
      when used as a Spark partitioning hint.
    - `COALESCE(x, ..., <const>, ...) <op> <const>` is expanded into an OR of two branches
      (x is not null / x is null) that later rules can fold further.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Locate the first constant argument of the COALESCE.
    found = next(
        ((idx, candidate) for idx, candidate in enumerate(coalesce.expressions) if _is_constant(candidate)),
        None,
    )
    if found is None:
        return expression
    arg_index, arg = found

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Drop the COALESCE wrapper right away when no extra arguments remain, saving a pass.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # More complex than the input, but subsequent rules simplify it further.
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged (joined with the separator for CONCAT_WS). If
    everything merges into one string literal, that literal is returned.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_ws = isinstance(expression, exp.ConcatWs)
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_ws and not expression.expressions[0].is_string:
        return expression

    if is_ws:
        sep_expr, *operands = expression.expressions
        separator = sep_expr.name
        build_kwargs = {}
    else:
        operands = expression.expressions
        separator = ""
        build_kwargs = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    reduced = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_literals, run in runs:
        if all_literals:
            reduced.append(exp.Literal.string(separator.join(lit.name for lit in run)))
        else:
            reduced.extend(run)

    if len(reduced) == 1 and reduced[0].is_string:
        return reduced[0]

    if is_ws:
        return exp.ConcatWs(expressions=[sep_expr, *reduced], **build_kwargs)
    return exp.Concat(expressions=reduced, **build_kwargs)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE: the first WHEN that is always true yields its THEN value, always-false WHENs
    are removed, and a CASE left without WHENs becomes its ELSE value (or NULL).
    IF (not nested directly in a CASE): collapses to the taken branch (or NULL).
    """
    if isinstance(expression, exp.Case):
        subject = expression.this
        for branch in expression.args["ifs"]:
            condition = branch.this
            if subject:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                condition = condition.replace(subject.pop().eq(condition))

            if always_true(condition):
                return branch.args["true"]

            if always_false(condition):
                branch.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
        return expression

    if isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        condition = expression.this
        if always_true(condition):
            return expression.args["true"]
        if always_false(condition):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack, prefix = expression.this, expression.expression
    if not (haystack.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
120 def wrapped(expression, *args, **kwargs): 121 try: 122 return func(expression, *args, **kwargs) 123 except exceptions: 124 return expression  # closure built by `catch`: calls `func` and returns `expression` unchanged if one of `exceptions` is raised
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns go on the left and constants on the right; otherwise operands are ordered
    by their `gen` text. When swapping, the operator is flipped through
    INVERSE_COMPARISONS (e.g. a < b -> b > a).
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    must_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not must_swap:
        return expression

    flipped = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return flipped(this=right, expression=left)
def remove_where_true(expression):
    """Drop always-true WHERE clauses and always-true join conditions, in place.

    A join whose ON is always true, with no USING and no method, and whose
    (side, kind) is listed in JOINS, is rewritten as a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison node `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal for EQ/IS, NEQ, GT, GTE, LT and LTE nodes, and None
    for any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, or return None when that is not possible.

    datetime values are truncated to their date part, dates pass through unchanged, and
    anything else is parsed with `datetime.datetime.fromisoformat`.
    """
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: non-string input such as None or a number; ValueError: unparsable text.
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, or return None when that is not possible.

    datetime values pass through unchanged, dates become midnight of that day, and
    anything else is parsed with `datetime.datetime.fromisoformat`.
    """
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: non-string input such as None or a number; ValueError: unparsable text.
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate CAST(value AS to) for date/time target types.

    Returns a `datetime.date` when `to` is DATE, a `datetime.datetime` for any other
    temporal type, and None for empty values, non-temporal targets or unparsable input.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a (possibly nested) cast of a literal to a date/time value.

    Handles `CAST(<literal> AS <type>)` and a format-less `TsOrDsToDate(<literal>)`,
    recursing through nested casts of either kind. Returns None for anything else or
    when the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format, TsOrDsToDate is evaluated as a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Return a dateutil `relativedelta` spanning `n` periods of `unit`.

    Supported units: year, quarter (three months), month, week, day, hour, minute and
    second. Any other unit raises UnsupportedUnit. dateutil is imported lazily, so a
    missing dependency surfaces as ModuleNotFoundError at call time.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, number of those per unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the first day of the `unit` period that contains it.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, the weekday on which weeks start, counted from
            Monday = 0 (taken modulo 7, so e.g. -1 denotes the same day as 6). Only read
            for the "week" unit.

    Returns:
        The start of the period; never later than `d`.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start. date.weekday() is 0 for Monday.
        # The modulo keeps this in [0, 6]. Without it, a negative WEEK_OFFSET moved a
        # week-start day back a whole week, and a positive one could land after `d`.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    None renders as "_", iterables as comma-joined renderings of their items, and
    non-Expression values via str(). Expression types with an entry in GEN_MAP use it;
    all others render as their key followed by their rendered args.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    renderer = GEN_MAP.get(type(expression))
    if renderer is not None:
        return renderer(expression)
    return f"{expression.key} {gen(expression.args.values())}"
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.