sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised when a date/interval unit is not supported by the date helpers."""

    pass


def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # later rules (e.g. simplify_parens) inspect the parent of the rewritten node
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.

    When the BETWEEN is negated (its parent is a NOT), the resulting AND is wrapped in
    parentheses, because `NOT x >= y AND x <= z` would otherwise bind as
    `(NOT x >= y) AND x <= z`.
    """
    if isinstance(expression, exp.Between):
        negate = isinstance(expression.parent, exp.Not)

        expression = exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )

        if negate:
            expression = exp.Paren(this=expression)

    return expression


def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    The rewritten connector is returned wrapped in parentheses so that it keeps its
    meaning when the NOT was nested inside another connector or negation
    (e.g. `a AND NOT (b AND c)` must not become `a AND NOT b OR NOT c`).
    Redundant parentheses are removed later by simplify_parens / flatten.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.Paren(
                    this=exp.or_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if isinstance(condition, exp.Or):
                return exp.Paren(
                    this=exp.and_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR chains pairwise:

    - identical operands collapse (A AND A -> A),
    - boolean / NULL constants are folded (FALSE AND x -> FALSE, TRUE OR x -> TRUE, ...),
    - comparisons on the same column are merged via _simplify_comparison.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons that share a column and compare it against number or string
    literals, e.g. `x > 1 AND x > 2 -> x > 2`, `x < 1 AND x > 3 -> FALSE`.

    Returns the merged expression, `expression` itself when no literal side can be
    isolated, or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        complement = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return complement
    return expression


def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {gen(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    l, r = r, l
                else:
                    continue

                # remember which column node defined the constant so it isn't replaced itself
                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}

# Operations whose operands may be swapped without changing the result.
COMMUTATIVE_OPS = (exp.Add, exp.DateAdd, exp.DatetimeAdd)


def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    If the arithmetic is on the right-hand side (e.g. `3 < x + 1`) the sides are
    swapped and the comparison is flipped accordingly (`x + 1 > 3`). Operands of
    non-commutative operations (subtraction) are never swapped, since
    `1 - x = 3` is not `x - 1 = 3`.
    """
    if not isinstance(expression, COMPARISONS):
        return expression

    comparison = expression.__class__
    l, r = expression.left, expression.right

    if l.__class__ in INVERSE_OPS:
        pass
    elif r.__class__ in INVERSE_OPS:
        l, r = r, l
        # moving the operands across the comparison flips its direction
        comparison = INVERSE_COMPARISONS.get(comparison, comparison)
    else:
        return expression

    if r.is_number:
        a_predicate = _is_number
        b_predicate = _is_number
    elif _is_date_literal(r):
        a_predicate = _is_date_literal
        b_predicate = _is_interval
    else:
        return expression

    if l.__class__ in INVERSE_DATE_OPS:
        l = t.cast(exp.IntervalOp, l)
        a = l.this
        b = l.interval()
    else:
        l = t.cast(exp.Binary, l)
        a, b = l.left, l.right

    if not a_predicate(a) and b_predicate(b):
        pass
    elif isinstance(l, COMMUTATIVE_OPS) and not a_predicate(b) and b_predicate(a):
        a, b = b, a
    else:
        return expression

    return comparison(this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b))


def simplify_literals(expression, root=True):
    """
    Fold operations on literals: binary arithmetic / comparisons (via _simplify_binary),
    negation of numeric literals, and date +/- interval arithmetic.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate `expression` (a binary operation) for the operands `a` and `b` when they
    are statically known. Returns the folded expression or None if it can't be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # division by zero is engine-specific (error or NULL) and would raise
            # decimal.DivisionByZero here, so leave it untouched
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Remove parentheses that don't affect evaluation order."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """
    COALESCE(x) -> x, COALESCE(<non-null constant>, ...) -> the constant, and

        COALESCE(x, ..., c, ...) <op> k
          -> (COALESCE(x, ...) IS NOT NULL AND COALESCE(x, ...) <op> k)
             OR (COALESCE(x, ...) IS NULL AND c <op> k)

    where `c` is the first non-null constant argument and `k` is a constant. The
    operand order of `<op>` is preserved in the fallback comparison.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_is_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_is_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first non-null constant arg: COALESCE never evaluates past it, while a
    # NULL constant is skipped over by COALESCE and so can't terminate the search
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_nonnull_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the original operand order, e.g. `2 < COALESCE(x, 1)` falls back to `2 < 1`
    if coalesce_is_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        # iterate over a snapshot: always-false branches are popped while looping
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) = date -> left >= start AND left < end."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) <> date -> left < start OR left >= end."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # outside of [start, end) means below start OR at/after end; AND would be always false
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # DATE_TRUNC(u, x) < d  <=>  x < the first truncation boundary at or after d
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


def _paren_connector(expression: exp.Expression) -> exp.Expression:
    """Wrap AND/OR results in parentheses so they keep their meaning under NOT / other connectors."""
    return exp.Paren(this=expression) if isinstance(expression, exp.Connector) else expression


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        result = DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit)
        if not result:
            return expression
        return _paren_connector(result)
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return _paren_connector(
                exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
            )

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
880# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 881JOINS = { 882 ("", ""), 883 ("", "INNER"), 884 ("RIGHT", ""), 885 ("RIGHT", "OUTER"), 886} 887 888 889def remove_where_true(expression): 890 for where in expression.find_all(exp.Where): 891 if always_true(where.this): 892 where.parent.set("where", None) 893 for join in expression.find_all(exp.Join): 894 if ( 895 always_true(join.args.get("on")) 896 and not join.args.get("using") 897 and not join.args.get("method") 898 and (join.side, join.kind) in JOINS 899 ): 900 join.set("on", None) 901 join.set("side", None) 902 join.set("kind", "CROSS") 903 904 905def always_true(expression): 906 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( 907 expression, exp.Literal 908 ) 909 910 911def always_false(expression): 912 return is_false(expression) or is_null(expression) 913 914 915def is_complement(a, b): 916 return isinstance(b, exp.Not) and b.this == a 917 918 919def is_false(a: exp.Expression) -> bool: 920 return type(a) is exp.Boolean and not a.this 921 922 923def is_null(a: exp.Expression) -> bool: 924 return type(a) is exp.Null 925 926 927def eval_boolean(expression, a, b): 928 if isinstance(expression, (exp.EQ, exp.Is)): 929 return boolean_literal(a == b) 930 if isinstance(expression, exp.NEQ): 931 return boolean_literal(a != b) 932 if isinstance(expression, exp.GT): 933 return boolean_literal(a > b) 934 if isinstance(expression, exp.GTE): 935 return boolean_literal(a >= b) 936 if isinstance(expression, exp.LT): 937 return boolean_literal(a < b) 938 if isinstance(expression, exp.LTE): 939 return boolean_literal(a <= b) 940 return None 941 942 943def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 944 if isinstance(value, datetime.datetime): 945 return value.date() 946 if isinstance(value, datetime.date): 947 return value 948 try: 949 return datetime.datetime.fromisoformat(value).date() 950 except ValueError: 951 return None 952 953 954def cast_as_datetime(value: 
t.Any) -> t.Optional[datetime.datetime]: 955 if isinstance(value, datetime.datetime): 956 return value 957 if isinstance(value, datetime.date): 958 return datetime.datetime(year=value.year, month=value.month, day=value.day) 959 try: 960 return datetime.datetime.fromisoformat(value) 961 except ValueError: 962 return None 963 964 965def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 966 if not value: 967 return None 968 if to.is_type(exp.DataType.Type.DATE): 969 return cast_as_date(value) 970 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 971 return cast_as_datetime(value) 972 return None 973 974 975def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 976 if isinstance(cast, exp.Cast): 977 to = cast.to 978 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 979 to = exp.DataType.build(exp.DataType.Type.DATE) 980 else: 981 return None 982 983 if isinstance(cast.this, exp.Literal): 984 value: t.Any = cast.this.name 985 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 986 value = extract_date(cast.this) 987 else: 988 return None 989 return cast_value(value, to) 990 991 992def _is_date_literal(expression: exp.Expression) -> bool: 993 return extract_date(expression) is not None 994 995 996def extract_interval(expression): 997 try: 998 n = int(expression.name) 999 unit = expression.text("unit").lower() 1000 return interval(unit, n) 1001 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1002 return None 1003 1004 1005def date_literal(date): 1006 return exp.cast( 1007 exp.Literal.string(date), 1008 exp.DataType.Type.DATETIME 1009 if isinstance(date, datetime.datetime) 1010 else exp.DataType.Type.DATE, 1011 ) 1012 1013 1014def interval(unit: str, n: int = 1): 1015 from dateutil.relativedelta import relativedelta 1016 1017 if unit == "year": 1018 return relativedelta(years=1 * n) 1019 if unit == "quarter": 1020 return relativedelta(months=3 * n) 1021 
if unit == "month": 1022 return relativedelta(months=1 * n) 1023 if unit == "week": 1024 return relativedelta(weeks=1 * n) 1025 if unit == "day": 1026 return relativedelta(days=1 * n) 1027 if unit == "hour": 1028 return relativedelta(hours=1 * n) 1029 if unit == "minute": 1030 return relativedelta(minutes=1 * n) 1031 if unit == "second": 1032 return relativedelta(seconds=1 * n) 1033 1034 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1035 1036 1037def date_floor(d: datetime.date, unit: str) -> datetime.date: 1038 if unit == "year": 1039 return d.replace(month=1, day=1) 1040 if unit == "quarter": 1041 if d.month <= 3: 1042 return d.replace(month=1, day=1) 1043 elif d.month <= 6: 1044 return d.replace(month=4, day=1) 1045 elif d.month <= 9: 1046 return d.replace(month=7, day=1) 1047 else: 1048 return d.replace(month=10, day=1) 1049 if unit == "month": 1050 return d.replace(month=d.month, day=1) 1051 if unit == "week": 1052 # Assuming week starts on Monday (0) and ends on Sunday (6) 1053 return d - datetime.timedelta(days=d.weekday()) 1054 if unit == "day": 1055 return d 1056 1057 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1058 1059 1060def date_ceil(d: datetime.date, unit: str) -> datetime.date: 1061 floor = date_floor(d, unit) 1062 1063 if floor == d: 1064 return d 1065 1066 return floor + interval(unit) 1067 1068 1069def boolean_literal(condition): 1070 return exp.true() if condition else exp.false() 1071 1072 1073def _flat_simplify(expression, simplifier, root=True): 1074 if root or not expression.same_parent: 1075 operands = [] 1076 queue = deque(expression.flatten(unnest=False)) 1077 size = len(queue) 1078 1079 while queue: 1080 a = queue.popleft() 1081 1082 for b in queue: 1083 result = simplifier(expression, a, b) 1084 1085 if result and result is not expression: 1086 queue.remove(b) 1087 queue.appendleft(result) 1088 break 1089 else: 1090 operands.append(a) 1091 1092 if len(operands) < size: 1093 return functools.reduce( 1094 lambda a, b: 
expression.__class__(this=a, expression=b), operands 1095 ) 1096 return expression 1097 1098 1099def gen(expression: t.Any) -> str: 1100 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1101 1102 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1103 generator is expensive so we have a bare minimum sql generator here. 1104 """ 1105 if expression is None: 1106 return "_" 1107 if is_iterable(expression): 1108 return ",".join(gen(e) for e in expression) 1109 if not isinstance(expression, exp.Expression): 1110 return str(expression) 1111 1112 etype = type(expression) 1113 if etype in GEN_MAP: 1114 return GEN_MAP[etype](expression) 1115 return f"{expression.key} {gen(expression.args.values())}" 1116 1117 1118GEN_MAP = { 1119 exp.Add: lambda e: _binary(e, "+"), 1120 exp.And: lambda e: _binary(e, "AND"), 1121 exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}", 1122 exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", 1123 exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", 1124 exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", 1125 exp.Column: lambda e: ".".join(gen(p) for p in e.parts), 1126 exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", 1127 exp.Div: lambda e: _binary(e, "/"), 1128 exp.Dot: lambda e: _binary(e, "."), 1129 exp.EQ: lambda e: _binary(e, "="), 1130 exp.GT: lambda e: _binary(e, ">"), 1131 exp.GTE: lambda e: _binary(e, ">="), 1132 exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name, 1133 exp.ILike: lambda e: _binary(e, "ILIKE"), 1134 exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})", 1135 exp.Is: lambda e: _binary(e, "IS"), 1136 exp.Like: lambda e: _binary(e, "LIKE"), 1137 exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name, 1138 exp.LT: lambda e: _binary(e, "<"), 1139 exp.LTE: lambda e: _binary(e, "<="), 1140 exp.Mod: lambda e: 
_binary(e, "%"), 1141 exp.Mul: lambda e: _binary(e, "*"), 1142 exp.Neg: lambda e: _unary(e, "-"), 1143 exp.NEQ: lambda e: _binary(e, "<>"), 1144 exp.Not: lambda e: _unary(e, "NOT"), 1145 exp.Null: lambda e: "NULL", 1146 exp.Or: lambda e: _binary(e, "OR"), 1147 exp.Paren: lambda e: f"({gen(e.this)})", 1148 exp.Sub: lambda e: _binary(e, "-"), 1149 exp.Subquery: lambda e: f"({gen(e.args.values())})", 1150 exp.Table: lambda e: gen(e.args.values()), 1151 exp.Var: lambda e: e.name, 1152} 1153 1154 1155def _binary(e: exp.Binary, op: str) -> str: 1156 return f"{gen(e.left)} {op} {gen(e.right)}" 1157 1158 1159def _unary(e: exp.Unary, op: str) -> str: 1160 return f"{op} {gen(e.this)}"
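Common base class for all non-exit exceptions. (Docstring inherited from `Exception`; in this module `UnsupportedUnit` is raised by `interval` and `date_floor` for unrecognized date/time units.)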
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains a GROUP BY key anywhere in its subtree.
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Apply the same protection to the HAVING clause.
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        """Run one simplification pass over `expression`; `root` is True only for the top call."""
        # Nodes marked FINAL are returned as-is, so their subtrees are never visited.
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children (as non-root nodes) before the post-order rules run.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may return a freshly built (detached) node; re-attach the original
        # parent because later rules such as simplify_parens read `expression.parent`.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the top-level call splices its result back in; children are spliced by
        # `exp.replace_children` in their parent's pass.
        if root:
            expression.replace(node)

        return node

    # `while_changing` (sqlglot.helper) re-runs `_simplify`; presumably until the tree stops changing.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of ``exceptions``, the wrapper returns its first
    positional argument (the input expression) unchanged, so that rewrite is skipped.

    Args:
        *exceptions: exception types to swallow.

    Returns:
        A decorator. The wrapper it produces keeps the decorated function's metadata
        (``__name__``, ``__doc__``, ``__wrapped__``) via ``functools.wraps``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # Best-effort rule: on failure leave the expression as it was.
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised: the wrapper then returns the input expression unchanged.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only understands lt/lte/gt/gte, so a BETWEEN is expanded
    into two inclusive bounds, each built on its own copy of the tested operand.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push NOT inward with De Morgan's laws and fold NOT over known values.

        NOT (x OR y)  -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
        NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE, NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # AND flips to OR and vice versa, negating both sides.
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # Double negation: NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Collapse (parenthesized) nested connectors of the same kind into their parent:

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR connectors using SQL three-valued logic.

    AND: a FALSE operand makes the result FALSE; NULL AND NULL is NULL; always-true
         operands are dropped.
    OR:  an always-true operand makes the result TRUE; FALSE operands are dropped;
         NULL OR NULL and NULL OR FALSE are NULL.
    Operand pairs not folded here are passed on to `_simplify_comparison`.

    Args:
        expression: node to simplify; only `exp.Connector` nodes are rewritten.
        root: whether `expression` is the root of the current simplification pass.

    Returns:
        The simplified expression, or `expression` itself when it is not a connector.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x is NULL only when x is TRUE or NULL; when x is FALSE it is FALSE.
            # A single NULL operand therefore cannot be folded for an unknown x.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only connectors are inspected, and only when `root` is true or the node's
    `same_parent` is false.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Removing complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector by their pseudo-SQL (`gen`).

    C AND A AND B AND B -> A AND B AND C

    A new connector is only built when the operands were out of order or contained
    duplicates; otherwise the input is returned untouched.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # A later duplicate overwrites the stored node but keeps the first occurrence's position.
    by_sql = {gen(operand): operand for operand in operands}
    keys = list(by_sql)

    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return combine(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already sorted, but duplicates still need to go.
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites are applied in place through `replace` on the operands; the (possibly
    mutated) input expression is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the dual connector: the nested operand type we look for inside this one.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # The complemented operand becomes the identity of `kind`
                    # (TRUE inside an AND, FALSE inside an OR).
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so `a` is redundant: replace it
                    # with the identity of the outer connector (FALSE for OR, TRUE for AND).
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # Both operands share one side and differ by a complement on the other:
                    # collapse both to the shared operand (the resulting duplicate is
                    # presumably removed by later rules).
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Only applies to an AND node (the root, or one whose `same_parent` is false) that
    `normalize.normalized(..., dnf=True)` accepts. Columns are replaced in place and
    the input expression is returned.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node inside its `column = literal` equality, literal).
        # A later equality on an equal column overwrites an earlier one.
        constant_mapping = {}
        # Subtrees under an IF are pruned from the walk; presumably because an equality
        # inside a conditional branch does not hold for the whole conjunction.
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # Normalize `literal = column` so `l` is always the column.
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                # The lookup matches any Column equal to a key, not only the same object;
                # the id comparison below tells the defining node apart.
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the column of the defining equality itself and `column IS NULL` tests.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    """
    Inner wrapper produced by the `catch` decorator.

    `func` and `exceptions` are closure variables of the enclosing `catch`/`decorator`
    functions, not module-level names.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Treat the failed rewrite as a no-op: hand back the input expression unchanged.
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
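There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below: in `x + 1 = 3`, `l` is the left operand of `=` (the sum `x + 1`) and `r` is its right operand (`3`); within `l`, `a` is `x` and `b` is `1`.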
def simplify_literals(expression, root=True):
    """
    Simplify literal-bearing nodes.

    Non-connector binary nodes are reduced pairwise by `_simplify_binary` (through
    `_flat_simplify`); a negated number literal has its sign flipped; nodes whose type is
    in INVERSE_DATE_OPS are evaluated with `_simplify_binary` against their interval.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            text = operand.name
            # -(-5) -> 5, -(5) -> -5
            return exp.Literal.number(text[1:] if text[0] == "-" else f"-{text}")

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """
    Drop a Paren wrapper when the listed conditions say it is not needed, returning the
    wrapped node; otherwise return the Paren unchanged. A parenthesized SELECT is always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        # Not an operand of a condition/binary operator.
        not isinstance(parent, (exp.Condition, exp.Binary))
        # Double parentheses: ((x)) -> (x)
        or isinstance(parent, exp.Paren)
        # The wrapped node is not a binary operator.
        or not isinstance(this, exp.Binary)
        # A predicate whose parent is not itself a predicate.
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        # Addition / multiplication nested in the same operator.
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # Multiplication inside + or - (it binds tighter anyway).
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    * COALESCE(x) -> x. A COALESCE whose first argument is a non-null constant also
      reduces to that argument. Neither rewrite is applied inside a Hint, where COALESCE
      is a Spark partitioning hint.
    * A comparison between a constant k and COALESCE(x, ..., c, ...), where c is the first
      non-null constant argument, becomes
          (COALESCE(x, ...) IS NOT NULL AND <comparison with COALESCE(x, ...)>)
          OR (COALESCE(x, ...) IS NULL AND <comparison with c in place of the COALESCE>)
      with operands kept on their original sides. Later passes can then fold `c <op> k`.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first non-null constant argument. COALESCE never looks past it, so it is the
    # result whenever all preceding arguments are NULL. A NULL literal does not qualify,
    # because COALESCE skips over it.
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_nonnull_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Compare the fallback constant with `other`, keeping the operands on the same sides as
    # in the original comparison (this matters for <, <=, >, >=).
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """
    Merge each run of adjacent string literals in a CONCAT / CONCAT_WS into one literal.

    For CONCAT_WS the literals are joined with the separator, which must itself be a string
    literal. If everything collapses into a single string literal, that literal is returned.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        separator = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), key=lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(separator.join(s.name for s in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[sep_expr, *merged], **extra_args)
    return exp.Concat(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE: a simple `CASE x WHEN v ...` is rewritten to `CASE WHEN x = v ...`; the first WHEN
    whose condition is `always_true` replaces the whole CASE; WHENs whose condition is
    `always_false` are removed, and if none remain the CASE becomes its ELSE value (or NULL).
    IF (not directly under a CASE): returns its true branch, or its false branch (or NULL).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                # NOTE(review): `this.pop()` runs on every WHEN; after the first iteration
                # `this` is presumably already detached from the CASE.
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # NOTE(review): removes `case` while `expression.args["ifs"]` is being iterated;
                # only safe if popping rebinds the list rather than mutating it in place — confirm.
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # If nodes directly under a CASE are skipped here; presumably they are the CASE's
        # WHEN branches handled above.
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def wrapped(expression, *args, **kwargs):
    """
    Inner wrapper produced by the `catch` decorator.

    `func` and `exceptions` are closure variables of the enclosing `catch`/`decorator`
    functions, not module-level names.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Treat the failed rewrite as a no-op: hand back the input expression unchanged.
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """
    Drop WHERE clauses whose condition is always true, and turn joins whose ON is always
    true (with no USING, no join method, and a (side, kind) listed in JOINS) into CROSS joins.
    The tree is modified in place.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        convertible = (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if not convertible:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison node `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal via `boolean_literal`, or None when the node type is
    not one of EQ/IS, NEQ, GT, GTE, LT, LTE.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in checks:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce a datetime, date or ISO-format string to a datetime.

    A datetime is returned as-is, a date becomes midnight of that day, and a string is
    parsed with `datetime.fromisoformat`. Returns None if the string cannot be parsed.
    """
    # datetime is a subclass of date, so it has to be tested first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None
    return parsed
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a cast of `value` to the temporal type `to`.

    Args:
        value: the value being cast, e.g. a literal's text or an already-evaluated date.
        to: the target data type.

    Returns:
        `cast_as_date(value)` for DATE, `cast_as_datetime(value)` for the other temporal
        types, or None for a falsy value or a non-temporal target.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a date-producing cast.

    Handles `exp.Cast` and `exp.TsOrDsToDate` without a `format` argument (treated as a cast
    to DATE) whose operand is a literal or another such cast; nested casts are evaluated
    recursively.

    Returns:
        The resulting date/datetime, or None if `cast` is not such an expression or its
        value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """
    Build a `dateutil.relativedelta` spanning `n` of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.
    Raises UnsupportedUnit for anything else.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spec
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Truncate `d` to the start of its `unit` (day, week, month, quarter or year).

    Weeks start on Monday. Raises UnsupportedUnit for any other unit.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    # Types with a dedicated renderer; anything else falls back to "key args".
    renderer = GEN_MAP.get(type(expression))
    if renderer is not None:
        return renderer(expression)
    return f"{expression.key} {gen(expression.args.values())}"
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.