sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by the date helpers when they are given a time unit they cannot handle."""


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When one of `exceptions` escapes, the wrapped function returns its first
    argument (the input expression) unchanged. The wrapped function keeps the
    original's name and docstring.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT applied to NULL / TRUE / FALSE and removes double negation.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold constant operands of AND/OR chains and merge comparisons on the same column."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # FALSE AND NULL is FALSE, so `x AND NULL` only collapses to NULL when
            # both operands are NULL; otherwise the result depends on x.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _merge_bounds(left, lv, right, rv, *, upper, or_):
    """Pick which of two same-direction bounds on one column survives AND/OR.

    `upper` is True for `<`/`<=` bounds and False for `>`/`>=` bounds. AND keeps
    the tighter bound and OR keeps the looser one. When both values are equal,
    the strict operator (`<`, `>`) is the tighter of the two.
    """
    if lv == rv:
        left_strict = isinstance(left, (exp.LT, exp.GT))
        right_strict = isinstance(right, (exp.LT, exp.GT))
        if left_strict == right_strict:
            return left
        # AND wants the strict bound, OR wants the non-strict one
        return left if left_strict != or_ else right
    # Upper bounds tighten downwards and lower bounds upwards; OR flips the preference.
    keep_smaller = upper != or_
    return left if (lv < rv) == keep_smaller else right


def _simplify_comparison(expression, left, right, or_=False):
    """Combine two comparisons of the same column against literals of the same kind.

    Returns the merged expression, `expression` itself when the non-column operands
    cannot be isolated, or None when no simplification applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Two bounds in the same direction collapse into one of them. Ties are
            # resolved by strictness, not by operand order.
            if isinstance(left, LT_LTE) and isinstance(right, LT_LTE):
                return _merge_bounds(left, l, right, r, upper=True, or_=or_)
            if isinstance(left, GT_GTE) and isinstance(right, GT_GTE):
                return _merge_bounds(left, l, right, r, upper=False, or_=or_)

            # we can't ever shortcut to true because the column could be null
            if not or_:
                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    NOTE(review): these identities assume two-valued logic; in SQL, when A is NULL
    both A AND NOT A and A OR NOT A evaluate to NULL.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are deduplicated and ordered by their generated SQL text.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the nested groups we can absorb/eliminate
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}

# Operators whose operands may be swapped without changing the result
COMMUTATIVE_OPS = (exp.Add, exp.DateAdd, exp.DatetimeAdd)


def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_date(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Cast) and extract_date(expression) is not None


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    Operands are only swapped for commutative operators: `1 - x = 3` is left
    alone, because it is not equivalent to `x - 1 = 3`.
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date(r):
            a_predicate = _is_date
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            unit = l.unit
            if unit is None:
                # Without a unit the operand cannot be turned into an interval
                return expression
            a = l.this
            b = exp.Interval(
                this=l.expression.copy(),
                unit=unit.copy(),
            )
        else:
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, COMMUTATIVE_OPS) and not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Constant-fold non-connector binary operators and negated numeric literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Fold binary operator `expression` applied to operands `a` and `b`.

    Returns the folded expression, or None when nothing can be computed.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Division by zero raises in Python (decimal.DivisionByZero) and is
            # engine-dependent in SQL (error or NULL), so leave it unfolded.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Drop a Paren wrapper when its parent context makes it redundant.

    A parenthesized SELECT (subquery) is never unwrapped.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Remove single-argument COALESCE and expand `COALESCE(...) <cmp> constant`."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021, 1, 1), 'year') == (date(2021, 1, 1), date(2022, 1, 1))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as `left` being inside [start, end)."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) != date` as `left < start OR left >= end`.

    The OR is parenthesized so it keeps its meaning when spliced into an AND chain.
    """
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # `left` must fall outside [start, end); AND-ing both bounds would always be false.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        )
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # DATE_TRUNC(x) < d: the truncated values below d are those <= floor(d) when d is
    # unaligned, so the bound is the first boundary at or after d.
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True if `left` is a DATE/TIMESTAMP truncation and `right` a literal date cast."""
    return (
        isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
        and isinstance(right, exp.Cast)
        and right.is_type(*exp.DataType.TEMPORAL_TYPES)
        # Only literal casts such as CAST('2021-01-01' AS DATE) carry a usable date;
        # casting a column (or to an unsupported type) yields None downstream.
        and extract_date(right) is not None
    )


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x.
# The same holds for RIGHT JOIN x ON TRUE when the left side is empty: the outer
# join still returns the preserved side's rows (padded with NULLs), while a CROSS
# join returns nothing. Only inner joins are therefore safe to rewrite.
JOINS = {
    ("", ""),
    ("", "INNER"),
}


def remove_where_true(expression):
    """Drop `WHERE <always true>` and turn inner joins `ON <always true>` into CROSS joins.

    Mutates `expression` in place and returns None.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value if `expression` is a constant that is true in every engine.

    Only `TRUE` and non-zero numeric literals qualify. `0` and string literals are
    excluded: engines such as MySQL coerce them to 0 (false), so treating them as
    true would change results (e.g. silently dropping `WHERE 0`).
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal) and expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:  # decimal.InvalidOperation: not a parseable number
            return False
    return False


def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly the boolean literal FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly the NULL literal."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal, or None if `expression` is not a supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    """Return the value of `CAST('<iso>' AS DATE|DATETIME)` as date/datetime, else None."""
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None


def extract_interval(expression):
    """Return the `relativedelta` for an `INTERVAL '<int>' <unit>` node, else None.

    None is returned when the amount is not an integer literal (e.g. a column or
    '1.5'), the unit is unsupported, or `dateutil` is not installed.
    """
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build `CAST('<date>' AS DATE)`, or `AS DATETIME` for datetime values."""
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` of `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled names.
        ModuleNotFoundError: if `dateutil` is not installed (imported lazily).
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of its `unit` (year, quarter, month, week, day).

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; values already on a boundary are kept."""
    floor = date_floor(d, unit)
    return d if floor == d else floor + interval(unit)


def boolean_literal(condition):
    """Turn a Python truth value into a TRUE/FALSE literal expression."""
    factory = exp.true if condition else exp.false
    return factory()


def _flat_simplify(expression, simplifier, root=True):
    """Apply a pairwise `simplifier` across a flattened chain of same-type binary nodes.

    `simplifier(expression, a, b)` returns a replacement for the pair, or a falsy
    value / `expression` itself when the pair cannot be combined. A replacement is
    pushed back to the front of the work queue so it can combine again. If any pair
    was merged, the surviving operands are rebuilt left-deep with the same node type;
    otherwise `expression` is returned unchanged.
    """
    # Only the topmost node of a same-type chain does the work.
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    kept = []

    while pending:
        current = pending.popleft()
        merged = None

        for other in pending:
            candidate = simplifier(expression, current, other)
            if candidate and candidate is not expression:
                merged = (other, candidate)
                break

        if merged is None:
            kept.append(current)
        else:
            partner, replacement = merged
            pending.remove(partner)
            pending.appendleft(replacement)

    if len(kept) >= original_count:
        return expression

    node_type = expression.__class__
    return functools.reduce(lambda acc, nxt: node_type(this=acc, expression=nxt), kept)
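Raised by the date helpers (`interval`, `date_floor`) when given a time unit they do not support. Inherits from `Exception`, the common base class for all non-exit exceptions.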
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    to_sql = cached_generator()

    # A projection (or HAVING clause) that references a GROUP BY key must keep matching
    # that key verbatim, e.g. SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1, so such nodes
    # are frozen with the FINAL marker before any rewriting starts.
    for group_clause in expression.find_all(exp.Group):
        query = group_clause.parent
        group_keys = set(group_clause.expressions)
        group_clause.meta[FINAL] = True

        for projection in query.selects:
            if any(sub in group_keys for sub, *_ in projection.walk()):
                projection.meta[FINAL] = True

        having_clause = query.args.get("having")
        if having_clause and any(sub in group_keys for sub, *_ in having_clause.walk()):
            having_clause.meta[FINAL] = True

    def _simplify(node_in, root=True):
        if node_in.meta.get(FINAL):
            return node_in

        # Pre-order passes: run before the children are visited.
        current = rewrite_between(node_in)
        current = uniq_sort(current, to_sql, root)
        current = absorb_and_eliminate(current, root)
        current = simplify_concat(current)

        exp.replace_children(current, lambda child: _simplify(child, False))

        # Post-order passes: run on the node after its children were rewritten.
        current = simplify_not(current)
        current = flatten(current)
        current = simplify_connectors(current, root)
        current = remove_compliments(current, root)
        current = simplify_coalesce(current)
        current.parent = node_in.parent
        current = simplify_literals(current, root)
        current = simplify_equality(current)
        current = simplify_parens(current)
        current = simplify_datetrunc_predicate(current)

        if root:
            node_in.replace(current)
        return current

    result = while_changing(expression, _simplify)
    remove_where_true(result)
    return result
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When one of `exceptions` escapes, the wrapped function returns its first
    argument (the input expression) unchanged. `functools.wraps` keeps the
    original's `__name__`, `__doc__` and `__wrapped__` for introspection.

    Args:
        *exceptions: exception types to swallow.
    Returns:
        A decorator for functions whose first positional argument is the expression.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # Best effort: a failed rewrite leaves the expression untouched.
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised: the wrapped function then returns its input expression unchanged.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    Later comparison merging only understands `<`, `<=`, `>` and `>=`, so BETWEEN
    is split into those two bounds. Any other node is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push NOT inwards and fold it where possible.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Also: NOT NULL -> NULL, NOT <true> -> FALSE, NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # swap the connector and negate each side
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation cancels out
        return operand.this
    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Splice parenthesized operands of the same connector type into the parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Works in place on the direct operands and returns `expression`.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold constant operands of AND/OR chains and merge comparisons on the same column.

    Uses SQL three-valued logic: FALSE AND NULL is FALSE and TRUE OR NULL is TRUE,
    so a NULL operand only absorbs the other side where that is actually sound.

    Args:
        expression: any node; only connectors (AND/OR) are rewritten.
        root: whether `expression` is the root of the current simplify pass.
    Returns:
        The simplified node, or `expression` unchanged.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # `x AND NULL` is FALSE when x is FALSE and NULL otherwise, so it only
            # collapses to NULL when both operands are NULL.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_compliments(expression, root=True):
    """
    Remove complements from an AND / OR chain.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only a root connector, or one whose parent is not the same connector type, is
    inspected. Every ordered pair of flattened operands is checked with `is_complement`.

    NOTE(review): under SQL three-valued logic `A AND NOT A` is NULL (not FALSE) when A is
    NULL; this rewrite assumes that distinction does not matter to callers.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = expression.flatten()
    if any(is_complement(a, b) for a, b in itertools.permutations(operands, 2)):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Remove complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of an AND / OR chain.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by the SQL produced by `generate`, and operands with identical SQL
    collapse to one. A new connector is built only when the operands are out of order or
    contain duplicates; otherwise `expression` is returned as-is.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    unique = {generate(operand): operand for operand in operands}
    keyed = tuple(unique.items())

    in_order = all(not later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))
    if not in_order:
        # A AND C AND B -> A AND B AND C
        return combine(*(operand for _, operand in sorted(keyed)), copy=False)

    if len(unique) < len(operands):
        # already sorted, but duplicates were dropped
        return combine(*unique.values(), copy=False)

    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND / OR chain.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place with `replace`, substituting constants or shared
    terms, and the (possibly mutated) `expression` is returned. The leftover TRUE / FALSE /
    duplicate operands are left for other simplification steps to clean up.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the type of nested operand we look for: the opposite connector
        # (an outer AND looks for OR operands, an outer OR looks for AND operands).
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # If one inner operand is the complement of b (presumably `NOT b` — see
                # is_complement), replace it with the identity of the inner connector
                # (TRUE for AND, FALSE for OR) so it drops out of `a`.
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # If b's operands are a strict subset of a's, `a` is absorbed: replace it
                # with the identity of the outer connector (FALSE under OR, TRUE under AND).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one term and differ by a complemented term:
                    # both collapse to the shared term.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def wrapped(expression, *args, **kwargs):
    """
    Call `func` and fall back to the unchanged `expression` if it raises.

    `func` and `exceptions` are closure variables from the enclosing decorator, which is
    outside this view. `exceptions` is presumably a tuple of exception classes; confirm
    against the decorator.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Best-effort: a failing step leaves the input untouched instead of propagating.
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
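There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b

That is, l is the left operand of = (x + 1) and r is its right operand (3); a and b are the left and right operands of + (x and 1).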
def simplify_literals(expression, root=True):
    """
    Fold literal arithmetic and negation.

    Non-connector binary operators are folded pairwise by `_simplify_binary` through
    `_flat_simplify`. Negation of a numeric literal is folded into the literal itself:
    -(5) -> -5 and -(-5) -> 5.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
    """
    Remove parentheses that do not change how the expression parses.

    Parens are always kept around subqueries (SELECT). Otherwise they are dropped when the
    parent is not an operator, when they are doubled, when the wrapped node is not a binary
    operator, or when operator precedence makes them redundant (e.g. a + (b + c),
    (a * b) + c, (a = b) AND c).

    NOT is an exception to "not a binary operator": it binds more loosely than comparisons
    and arithmetic. `(NOT a) = b` must keep its parens, because `NOT a = b` parses as
    `NOT (a = b)`.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parens around NOT are only redundant under AND / OR or a non-operator parent;
    # under a predicate or an arithmetic / comparison operator they change the parse.
    not_needs_parens = isinstance(this, exp.Not) and (
        parent_is_predicate
        or (isinstance(parent, exp.Binary) and not isinstance(parent, exp.Connector))
    )

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (not isinstance(this, exp.Binary) and not not_needs_parens)
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    - COALESCE(x) -> x, unless the COALESCE is a Spark partitioning hint.
    - A comparison between a COALESCE and a constant, where the COALESCE has a constant
      argument, is expanded into a form that later passes can fold. Roughly:

          COALESCE(x, 1) = 2
          -> (x IS NOT NULL AND x = 2) OR (x IS NULL AND 1 = 2)

      The comparison keeps its operands on their original sides, so
      `5 > COALESCE(x, 1)` produces `... OR (x IS NULL AND 5 > 1)`.

    Note that the COALESCE node's argument list is truncated in place before the
    expanded expression is built.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Compare the constant fallback with `other`, keeping each operand on its original
    # side. If COMPARISONS includes order comparisons (<, >=, ...), swapping the sides
    # would invert the result.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """
    Merge each run of adjacent string-literal arguments of a concatenation into one literal.

    CONCAT_WS is left untouched. If only one argument remains, it is returned directly.
    Otherwise a Concat (or SafeConcat, when the input was one of SAFE_CONCATS) is built
    from the merged arguments.
    """
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for all_strings, run in itertools.groupby(operands, key=lambda arg: arg.is_string):
        if all_strings:
            merged.append(exp.Literal.string("".join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Keep the "safe" flavour of the original concatenation
    concat_cls = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_cls(expressions=merged)
Reduces each run of adjacent string-literal arguments to a single literal by concatenating them.
def wrapped(expression, *args, **kwargs):
    """
    Call `func` and fall back to the unchanged `expression` if it raises.

    `func` and `exceptions` are closure variables from the enclosing decorator, which is
    outside this view. `exceptions` is presumably a tuple of exception classes; confirm
    against the decorator.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Best-effort: a failing step leaves the input untouched instead of propagating.
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """
    Drop trivially true filters from `expression`, in place.

    - Every WHERE clause whose condition is always true is removed from its parent.
    - Every join whose ON condition is always true, has no USING and no join method, and
      whose (side, kind) is in JOINS is turned into a CROSS join with no ON condition.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        if not always_true(on_condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison node `expression` on the Python values `a` and `b`.

    EQ and IS map to ==, NEQ to !=, and GT / GTE / LT / LTE to the matching operator.
    The outcome is passed to `boolean_literal`. Returns None for any other node type.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda left, right: left == right),
        (exp.NEQ, lambda left, right: left != right),
        (exp.GT, lambda left, right: left > right),
        (exp.GTE, lambda left, right: left >= right),
        (exp.LT, lambda left, right: left < right),
        (exp.LTE, lambda left, right: left <= right),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
    """
    Parse the operand of a CAST to DATE or DATETIME into a Python date / datetime.

    Returns a `datetime.date` for DATE casts and a `datetime.datetime` for DATETIME casts.
    Returns None for any other target type, or when `cast.name` is not an ISO-formatted
    string (e.g. the cast wraps an identifier rather than a literal).
    """
    try:
        target_type = cast.args["to"].this
        if target_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
def interval(unit: str, n: int = 1):
    """
    Return a dateutil `relativedelta` spanning `n` of the given calendar `unit`.

    Supported units: year, quarter (three months), month, week, day, hour, minute, second.
    Raises UnsupportedUnit for anything else.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier)
    units = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, multiplier in units:
        if unit == name:
            return relativedelta(**{keyword: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` down to the start of the given calendar `unit`.

    Supported units: year, quarter, month, week (weeks start on Monday) and day. Only the
    month / day fields change, so any time-of-day part of a `datetime` argument is kept.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        first_month = (d.month - 1) // 3 * 3 + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday, so this steps back to the most recent Monday.
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")