sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised when a date/interval unit is not handled by this module."""

    pass


def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        # keep the wrapped simplifier's name and docstring for debugging / docs
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR connectors whose operands are TRUE, FALSE, NULL, equal to
    each other, or comparisons that `_simplify_comparison` can merge.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons that share a column and compare it against literals.

    Returns the simplified expression, `expression` itself when the operands
    cannot be separated from the shared column, or None when nothing applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        complement = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return complement
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        for eq in find_all_in_scope(expression, exp.EQ):
            l, r = eq.left, eq.right

            # TODO: create a helper that can be used to detect nested literal expressions such
            # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
            if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                pass
            elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                l, r = r, l
            else:
                continue

            # remember which column node defined the constant so it is not replaced itself
            constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    """Return whether `expression` is a numeric literal."""
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    """Return whether `expression` is an interval that `extract_interval` can convert."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            a = l.this
            b = l.interval()
        else:
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations on literals and negations of numeric literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate the binary `expression` over operands `a` and `b` when both are
    constants (NULLs, numbers, strings, or date literals with intervals).

    Returns the folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises; leave it for the engine to evaluate
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Remove a `Paren` wrapper when the checks below deem it redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A single-argument COALESCE is replaced by its argument, and a comparison
    between a COALESCE and a constant is rewritten into explicit NULL checks
    using the first constant argument of the COALESCE.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check, or None if no range exists."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as "outside the range", or None if no range exists."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # A value lies outside [min, max) when it is below min OR at/after max.
    # Combining the two bounds with AND would be unsatisfiable.
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(x) < d  <=>  x < ceil(d): when d is not aligned to the unit, every x whose
    # truncation equals floor(d) also satisfies the predicate
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a date/timestamp truncation and `right` a date literal."""
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # the truncation is on the right-hand side, so flip the operator with the operands
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible always-true joins into CROSS joins (in place)."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value if `expression` is a true boolean or any `exp.Literal`."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly the boolean FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a NULL node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` over Python values `a` and `b`.

    Returns a boolean literal expression, or None if the operator is not a
    supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a datetime, date, or ISO-format string to a date; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a datetime, date, or ISO-format string to a datetime; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to the Python equivalent of the temporal type `to`, or None if unsupported."""
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the Python date/datetime from a (possibly nested) cast of a literal, else None."""
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether `extract_date` can obtain a date from `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an interval expression to a `relativedelta`, or None if it cannot be converted."""
    try:
        # int() raises ValueError for non-integer or non-literal interval values
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build a CAST of the string form of `date` to DATETIME (for datetimes) or DATE."""
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """Return a `relativedelta` of `n` `unit`s.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units handled below.
        ModuleNotFoundError: if `dateutil` is not installed.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of its `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not year, quarter, month, week or day.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; `d` is returned as-is if already on one."""
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, else FALSE."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Apply the pairwise `simplifier` across the flattened operands of `expression`.

    Whenever `simplifier(expression, a, b)` yields a new result, the pair is
    replaced by that result and processing restarts from it. If fewer operands
    remain than there were at the start, the expression is rebuilt from them;
    otherwise the original expression is returned.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    to_sql = cached_generator()

    def _references_group_key(tree, keys):
        # True as soon as any node of `tree` is one of the GROUP BY keys
        return any(node in keys for node, *_ in tree.walk())

    # Expressions tied to a GROUP BY must stay untouched: in
    #   select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection has to keep matching the group by key exactly.
    for group_by in expression.find_all(exp.Group):
        query = group_by.parent
        keys = set(group_by.expressions)
        group_by.meta[FINAL] = True

        for projection in query.selects:
            if _references_group_key(projection, keys):
                projection.meta[FINAL] = True

        having_clause = query.args.get("having")
        if having_clause and _references_group_key(having_clause, keys):
            having_clause.meta[FINAL] = True

    def _simplify(expression, root=True):
        # Nodes flagged FINAL are returned as-is
        if expression.meta.get(FINAL):
            return expression

        # Rules applied before descending into the children
        current = rewrite_between(expression)
        current = uniq_sort(current, to_sql, root)
        current = absorb_and_eliminate(current, root)
        current = simplify_concat(current)

        if constant_propagation:
            current = propagate_constants(current, root)

        exp.replace_children(current, lambda child: _simplify(child, False))

        # Rules applied once the children have been simplified
        current = simplify_not(current)
        current = flatten(current)
        current = simplify_connectors(current, root)
        current = remove_complements(current, root)
        current = simplify_coalesce(current)
        current.parent = expression.parent
        current = simplify_literals(current, root)
        current = simplify_equality(current)
        current = simplify_parens(current)
        current = simplify_datetrunc_predicate(current)

        if root:
            expression.replace(current)

        return current

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes that should be swallowed.

    Returns:
        A decorator. The decorated function behaves like the original, except
        that when one of `exceptions` is raised its first argument
        (`expression`) is returned unchanged. Any other exception propagates.
    """

    def decorator(func):
        # Keep the decorated function's name, docstring and signature metadata
        # so that introspection and documentation tools show the real rule
        # instead of this generic wrapper.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    turned into that form first. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Each bound gets its own copy of the tested operand.
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> (NOT x AND NOT y)
        NOT (x AND y) -> (NOT x OR NOT y)

    Additionally NOT NULL -> NULL, NOT <always true> -> FALSE,
    NOT FALSE -> TRUE and NOT NOT x -> x.

    Expressions that are not a NOT are returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    this = expression.this

    if is_null(this):
        return exp.null()

    if isinstance(this, exp.Paren):
        condition = this.unnest()

        # The rewritten connector takes the place of the NOT node. NOT binds
        # tighter than AND/OR, so the result has to stay parenthesized:
        # `a AND NOT (b AND c)` must not turn into `a AND NOT b OR NOT c`.
        # simplify_parens removes the Paren again wherever it is redundant.
        if isinstance(condition, exp.And):
            return exp.paren(
                exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if isinstance(condition, exp.Or):
            return exp.paren(
                exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if is_null(condition):
            return exp.null()

    if always_true(this):
        return exp.false()
    if is_false(this):
        return exp.true()
    if isinstance(this, exp.Not):
        # double negation
        # NOT NOT x -> x
        return this.this

    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop parentheses around an operand that is the same kind of connector.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold the operands of an AND / OR using boolean and NULL identities.

    Non-connector expressions are returned unchanged; connectors are handed to
    `_flat_simplify` together with a pairwise reducer.
    """

    def _reduce_and(connector, lhs, rhs):
        if is_false(lhs) or is_false(rhs):
            return exp.false()
        if is_null(lhs) or is_null(rhs):
            return exp.null()
        if always_true(lhs):
            return exp.true() if always_true(rhs) else rhs
        if always_true(rhs):
            return lhs
        return _simplify_comparison(connector, lhs, rhs)

    def _reduce_or(connector, lhs, rhs):
        if always_true(lhs) or always_true(rhs):
            return exp.true()
        if is_false(lhs) and is_false(rhs):
            return exp.false()
        # NULL OR NULL, NULL OR FALSE and FALSE OR NULL are all NULL
        if (is_null(lhs) and (is_null(rhs) or is_false(rhs))) or (
            is_false(lhs) and is_null(rhs)
        ):
            return exp.null()
        if is_false(lhs):
            return rhs
        if is_false(rhs):
            return lhs
        return _simplify_comparison(connector, lhs, rhs, or_=True)

    def _reduce(connector, lhs, rhs):
        # x AND x -> x, x OR x -> x
        if lhs == rhs:
            return lhs
        if isinstance(connector, exp.And):
            return _reduce_and(connector, lhs, rhs)
        if isinstance(connector, exp.Or):
            return _reduce_or(connector, lhs, rhs)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _reduce, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    collapsed = exp.false() if isinstance(expression, exp.And) else exp.true()

    # Ordered pairs: both (a, b) and (b, a) are handed to is_complement.
    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(lhs, rhs) for lhs, rhs in operand_pairs):
        return collapsed

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by generated SQL: duplicates collapse onto one entry.
    by_sql = {generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        ordered = (operand for _, operand in sorted(by_sql.items()))
        return combine(*ordered, copy=False)

    # already in order, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are not removed here: they are overwritten with TRUE / FALSE (or
    with the surviving operand), and `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connectors of interest are of the opposite kind:
        # OR operands inside an AND, AND operands inside an OR.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands is inspected; `a` is the nested connector.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # the operand of `a` that is contradicted by `b` becomes the
                    # neutral element of `a` (TRUE for AND, FALSE for OR)
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # the operands of `b` are a proper subset of those of `a`:
                    # `a` is overwritten with FALSE when it is an AND and with
                    # TRUE when it is an OR
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # `a` and `b` share one operand and disagree on the other:
                    # both are replaced by the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (identity of the column node in the defining equality, literal)
    anchors = {}
    for equality in find_all_in_scope(expression, exp.EQ):
        lhs, rhs = equality.left, equality.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            column, literal = lhs, rhs
        elif isinstance(rhs, exp.Column) and isinstance(lhs, exp.Literal):
            column, literal = rhs, lhs
        else:
            continue

        anchors[column] = (id(column), literal)

    if not anchors:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        anchor = anchors.get(column)
        if not anchor:
            continue

        anchor_id, literal = anchor

        # the occurrence inside `column = literal` itself has to stay
        if id(column) == anchor_id:
            continue

        # `column IS NULL` is left alone
        parent = column.parent
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    """Call `func`; if one of `exceptions` is raised, return `expression` unchanged.

    `func` and `exceptions` are free variables here; they are bound by the
    enclosing `catch(*exceptions)` decorator.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary expressions and negated numbers.

    Binary expressions other than AND / OR are handed to `_flat_simplify`
    with `_simplify_binary`. A negation of a number is replaced by a single
    number literal with the sign flipped. Anything else is returned unchanged.
    """
    is_connector = isinstance(expression, exp.Connector)
    if isinstance(expression, exp.Binary) and not is_connector:
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # -(-5) -> 5, -(5) -> -5
    flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(flipped)
def simplify_parens(expression):
    """Return the content of a parenthesized expression when the parentheses are redundant.

    Anything that is not a Paren, as well as a parenthesized SELECT, is
    returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    # the surrounding node cannot regroup the content
    if not isinstance(outer, (exp.Condition, exp.Binary)) or isinstance(outer, exp.Paren):
        return inner

    # only binary content can be regrouped in the first place
    if not isinstance(inner, exp.Binary):
        return inner

    # a predicate inside a non-predicate
    if isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate):
        return inner

    # (a + b) + c, (a * b) * c, (a * b) + c, (a * b) - c
    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner
    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls.

    Two rewrites are performed:

    - `COALESCE(x)` with a single argument becomes `x` (unless it is used
      as a hint).
    - A comparison between a COALESCE and a constant, where the COALESCE has a
      constant argument `c`, is split on whether the arguments before `c`
      are NULL, e.g. `COALESCE(x, c) = k` becomes
      `(NOT x IS NULL AND x = k) OR (x IS NULL AND c = k)`.
      The operand order of the original comparison is preserved in both
      branches, so `k < COALESCE(x, c)` yields `k < c`.

    Any other expression is returned unchanged. The COALESCE node is modified
    in place when the second rewrite applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    coalesce_on_left = isinstance(expression.left, exp.Coalesce)
    if coalesce_on_left:
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Arguments from the first constant onwards can never be reached
    # once that constant has been split off into its own branch.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # In the NULL branch the constant stands in for the COALESCE, so it has to
    # sit on the same side of the comparison as the COALESCE did; otherwise
    # non-symmetric comparisons such as `<` would be reversed.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    if with_separator:
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        if not expression.expressions[0].is_string:
            return expression

        separator_node, *operands = expression.expressions
        separator = separator_node.name
        result_type = exp.ConcatWs
    else:
        operands = expression.expressions
        separator = ""
        result_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if not all_strings:
            merged.extend(run)
            continue
        # adjacent string literals become a single literal
        joined = separator.join(literal.name for literal in run)
        merged.append(exp.Literal.string(joined))

    # everything folded into one string: the call itself is no longer needed
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if with_separator:
        merged = [separator_node] + merged

    return result_type(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
def wrapped(expression, *args, **kwargs):
    """Call `func`; if one of `exceptions` is raised, return `expression` unchanged.

    `func` and `exceptions` are free variables here; they are bound by the
    enclosing `catch(*exceptions)` decorator.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """Drop WHERE clauses and JOIN conditions that are always true, in place.

    A WHERE whose condition is always true is removed from its parent. A join
    whose ON condition is always true, that has neither USING nor a method and
    whose (side, kind) pair is listed in JOINS, is turned into a CROSS join.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns `boolean_literal(<result>)` for =, IS, <>, >, >=, < and <=,
    and None for every other kind of expression.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime.

    A datetime is returned as is, a date becomes midnight of that day and
    anything else is parsed as an ISO formatted string. None is returned when
    the string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so it has to be singled out here
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime.combine(value, datetime.time.min)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python equivalent of the temporal type `to`.

    Args:
        value: the value to convert; a falsy value yields None.
        to: the target data type.

    Returns:
        The result of `cast_as_date` when `to` is DATE, the result of
        `cast_as_datetime` when `to` is any other type listed in
        `exp.DataType.TEMPORAL_TYPES`, and None otherwise.
    """
    if not value:
        return None
    # DATE is checked first so that it is never handled by the datetime branch.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a (possibly nested) cast of a literal to a Python date or datetime.

    Args:
        cast: a Cast or TsOrDsToDate expression; TsOrDsToDate is treated as a
            cast to DATE.

    Returns:
        The value produced by `cast_value` for the innermost literal,
        converted through every enclosing cast, or None when `cast` is not
        one of the supported expressions or its operand is neither a literal
        nor another supported cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # nested casts are evaluated from the inside out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times `unit`.

    Supported units are year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    conversions = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, per_unit in conversions:
        if unit == name:
            return relativedelta(**{keyword: per_unit * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the `unit` it falls into.

    Supported units are year, quarter, month, week and day.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # quarters start in January, April, July and October
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")