sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 dialect: DialectType = None, 43 max_depth: t.Optional[int] = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 
47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped 58 Returns: 59 sqlglot.Expression: simplified expression 60 """ 61 62 dialect = Dialect.get_or_raise(dialect) 63 64 def _simplify(expression, root=True): 65 if ( 66 max_depth 67 and isinstance(expression, exp.Connector) 68 and not isinstance(expression.parent, exp.Connector) 69 ): 70 depth = connector_depth(expression) 71 if depth > max_depth: 72 logger.info( 73 f"Skipping simplification because connector depth {depth} exceeds max {max_depth}" 74 ) 75 return expression 76 77 if expression.meta.get(FINAL): 78 return expression 79 80 # group by expressions cannot be simplified, for example 81 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 82 # the projection must exactly match the group by key 83 group = expression.args.get("group") 84 85 if group and hasattr(expression, "selects"): 86 groups = set(group.expressions) 87 group.meta[FINAL] = True 88 89 for e in expression.selects: 90 for node in e.walk(): 91 if node in groups: 92 e.meta[FINAL] = True 93 break 94 95 having = expression.args.get("having") 96 if having: 97 for node in having.walk(): 98 if node in groups: 99 having.meta[FINAL] = True 100 break 101 102 # Pre-order transformations 103 node = expression 104 node = rewrite_between(node) 105 node = uniq_sort(node, root) 106 node = absorb_and_eliminate(node, root) 107 node = simplify_concat(node) 108 node = simplify_conditionals(node) 109 110 if constant_propagation: 111 node = propagate_constants(node, root) 112 113 exp.replace_children(node, lambda e: _simplify(e, False)) 114 115 # Post-order transformations 116 node = simplify_not(node) 117 node = flatten(node) 118 node = simplify_connectors(node, root) 119 node = 
remove_complements(node, root) 120 node = simplify_coalesce(node) 121 node.parent = expression.parent 122 node = simplify_literals(node, root) 123 node = simplify_equality(node) 124 node = simplify_parens(node) 125 node = simplify_datetrunc(node, dialect) 126 node = sort_comparison(node) 127 node = simplify_startswith(node) 128 129 if root: 130 expression.replace(node) 131 return node 132 133 expression = while_changing(expression, _simplify) 134 remove_where_true(expression) 135 return expression 136 137 138def connector_depth(expression: exp.Expression) -> int: 139 """ 140 Determine the maximum depth of a tree of Connectors. 141 142 For example: 143 >>> from sqlglot import parse_one 144 >>> connector_depth(parse_one("a AND b AND c AND d")) 145 3 146 """ 147 stack = deque([(expression, 0)]) 148 max_depth = 0 149 150 while stack: 151 expression, depth = stack.pop() 152 153 if not isinstance(expression, exp.Connector): 154 continue 155 156 depth += 1 157 max_depth = max(depth, max_depth) 158 159 stack.append((expression.left, depth)) 160 stack.append((expression.right, depth)) 161 162 return max_depth 163 164 165def catch(*exceptions): 166 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 167 168 def decorator(func): 169 def wrapped(expression, *args, **kwargs): 170 try: 171 return func(expression, *args, **kwargs) 172 except exceptions: 173 return expression 174 175 return wrapped 176 177 return decorator 178 179 180def rewrite_between(expression: exp.Expression) -> exp.Expression: 181 """Rewrite x between y and z to x >= y AND x <= z. 182 183 This is done because comparison simplification is only done on lt/lte/gt/gte. 
184 """ 185 if isinstance(expression, exp.Between): 186 negate = isinstance(expression.parent, exp.Not) 187 188 expression = exp.and_( 189 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 190 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 191 copy=False, 192 ) 193 194 if negate: 195 expression = exp.paren(expression, copy=False) 196 197 return expression 198 199 200COMPLEMENT_COMPARISONS = { 201 exp.LT: exp.GTE, 202 exp.GT: exp.LTE, 203 exp.LTE: exp.GT, 204 exp.GTE: exp.LT, 205 exp.EQ: exp.NEQ, 206 exp.NEQ: exp.EQ, 207} 208 209 210def simplify_not(expression): 211 """ 212 Demorgan's Law 213 NOT (x OR y) -> NOT x AND NOT y 214 NOT (x AND y) -> NOT x OR NOT y 215 """ 216 if isinstance(expression, exp.Not): 217 this = expression.this 218 if is_null(this): 219 return exp.null() 220 if this.__class__ in COMPLEMENT_COMPARISONS: 221 return COMPLEMENT_COMPARISONS[this.__class__]( 222 this=this.this, expression=this.expression 223 ) 224 if isinstance(this, exp.Paren): 225 condition = this.unnest() 226 if isinstance(condition, exp.And): 227 return exp.paren( 228 exp.or_( 229 exp.not_(condition.left, copy=False), 230 exp.not_(condition.right, copy=False), 231 copy=False, 232 ) 233 ) 234 if isinstance(condition, exp.Or): 235 return exp.paren( 236 exp.and_( 237 exp.not_(condition.left, copy=False), 238 exp.not_(condition.right, copy=False), 239 copy=False, 240 ) 241 ) 242 if is_null(condition): 243 return exp.null() 244 if always_true(this): 245 return exp.false() 246 if is_false(this): 247 return exp.true() 248 if isinstance(this, exp.Not): 249 # double negation 250 # NOT NOT x -> x 251 return this.this 252 return expression 253 254 255def flatten(expression): 256 """ 257 A AND (B AND C) -> A AND B AND C 258 A OR (B OR C) -> A OR B OR C 259 """ 260 if isinstance(expression, exp.Connector): 261 for node in expression.args.values(): 262 child = node.unnest() 263 if isinstance(child, expression.__class__): 264 node.replace(child) 
265 return expression 266 267 268def simplify_connectors(expression, root=True): 269 def _simplify_connectors(expression, left, right): 270 if isinstance(expression, exp.And): 271 if is_false(left) or is_false(right): 272 return exp.false() 273 if is_zero(left) or is_zero(right): 274 return exp.false() 275 if is_null(left) or is_null(right): 276 return exp.null() 277 if always_true(left) and always_true(right): 278 return exp.true() 279 if always_true(left): 280 return right 281 if always_true(right): 282 return left 283 return _simplify_comparison(expression, left, right) 284 elif isinstance(expression, exp.Or): 285 if always_true(left) or always_true(right): 286 return exp.true() 287 if ( 288 (is_null(left) and is_null(right)) 289 or (is_null(left) and always_false(right)) 290 or (always_false(left) and is_null(right)) 291 ): 292 return exp.null() 293 if is_false(left): 294 return right 295 if is_false(right): 296 return left 297 return _simplify_comparison(expression, left, right, or_=True) 298 elif isinstance(expression, exp.Xor): 299 if left == right: 300 return exp.false() 301 302 if isinstance(expression, exp.Connector): 303 return _flat_simplify(expression, _simplify_connectors, root) 304 return expression 305 306 307LT_LTE = (exp.LT, exp.LTE) 308GT_GTE = (exp.GT, exp.GTE) 309 310COMPARISONS = ( 311 *LT_LTE, 312 *GT_GTE, 313 exp.EQ, 314 exp.NEQ, 315 exp.Is, 316) 317 318INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 319 exp.LT: exp.GT, 320 exp.GT: exp.LT, 321 exp.LTE: exp.GTE, 322 exp.GTE: exp.LTE, 323} 324 325NONDETERMINISTIC = (exp.Rand, exp.Randn) 326AND_OR = (exp.And, exp.Or) 327 328 329def _simplify_comparison(expression, left, right, or_=False): 330 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 331 ll, lr = left.args.values() 332 rl, rr = right.args.values() 333 334 largs = {ll, lr} 335 rargs = {rl, rr} 336 337 matching = largs & rargs 338 columns = {m for m in matching if not _is_constant(m) and 
not m.find(*NONDETERMINISTIC)} 339 340 if matching and columns: 341 try: 342 l = first(largs - columns) 343 r = first(rargs - columns) 344 except StopIteration: 345 return expression 346 347 if l.is_number and r.is_number: 348 l = l.to_py() 349 r = r.to_py() 350 elif l.is_string and r.is_string: 351 l = l.name 352 r = r.name 353 else: 354 l = extract_date(l) 355 if not l: 356 return None 357 r = extract_date(r) 358 if not r: 359 return None 360 # python won't compare date and datetime, but many engines will upcast 361 l, r = cast_as_datetime(l), cast_as_datetime(r) 362 363 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 364 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 365 return left if (av > bv if or_ else av <= bv) else right 366 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 367 return left if (av < bv if or_ else av >= bv) else right 368 369 # we can't ever shortcut to true because the column could be null 370 if not or_: 371 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 372 if av <= bv: 373 return exp.false() 374 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 375 if av >= bv: 376 return exp.false() 377 elif isinstance(a, exp.EQ): 378 if isinstance(b, exp.LT): 379 return exp.false() if av >= bv else a 380 if isinstance(b, exp.LTE): 381 return exp.false() if av > bv else a 382 if isinstance(b, exp.GT): 383 return exp.false() if av <= bv else a 384 if isinstance(b, exp.GTE): 385 return exp.false() if av < bv else a 386 if isinstance(b, exp.NEQ): 387 return exp.false() if av == bv else a 388 return None 389 390 391def remove_complements(expression, root=True): 392 """ 393 Removing complements. 
394 395 A AND NOT A -> FALSE 396 A OR NOT A -> TRUE 397 """ 398 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 399 ops = set(expression.flatten()) 400 for op in ops: 401 if isinstance(op, exp.Not) and op.this in ops: 402 return exp.false() if isinstance(expression, exp.And) else exp.true() 403 404 return expression 405 406 407def uniq_sort(expression, root=True): 408 """ 409 Uniq and sort a connector. 410 411 C AND A AND B AND B -> A AND B AND C 412 """ 413 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 414 flattened = tuple(expression.flatten()) 415 416 if isinstance(expression, exp.Xor): 417 result_func = exp.xor 418 # Do not deduplicate XOR as A XOR A != A if A == True 419 deduped = None 420 arr = tuple((gen(e), e) for e in flattened) 421 else: 422 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 423 deduped = {gen(e): e for e in flattened} 424 arr = tuple(deduped.items()) 425 426 # check if the operands are already sorted, if not sort them 427 # A AND C AND B -> A AND B AND C 428 for i, (sql, e) in enumerate(arr[1:]): 429 if sql < arr[i][0]: 430 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 431 break 432 else: 433 # we didn't have to sort but maybe we need to dedup 434 if deduped and len(deduped) < len(flattened): 435 expression = result_func(*deduped.values(), copy=False) 436 437 return expression 438 439 440def absorb_and_eliminate(expression, root=True): 441 """ 442 absorption: 443 A AND (A OR B) -> A 444 A OR (A AND B) -> A 445 A AND (NOT A OR B) -> A AND B 446 A OR (NOT A AND B) -> A OR B 447 elimination: 448 (A AND B) OR (A AND NOT B) -> A 449 (A OR B) AND (A OR NOT B) -> A 450 """ 451 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 452 kind = exp.Or if isinstance(expression, exp.And) else exp.And 453 454 ops = tuple(expression.flatten()) 455 456 # Initialize lookup tables: 457 # Set of all operands, used to find 
complements for absorption. 458 op_set = set() 459 # Sub-operands, used to find subsets for absorption. 460 subops = defaultdict(list) 461 # Pairs of complements, used for elimination. 462 pairs = defaultdict(list) 463 464 # Populate the lookup tables 465 for op in ops: 466 op_set.add(op) 467 468 if not isinstance(op, kind): 469 # In cases like: A OR (A AND B) 470 # Subop will be: ^ 471 subops[op].append({op}) 472 continue 473 474 # In cases like: (A AND B) OR (A AND B AND C) 475 # Subops will be: ^ ^ 476 subset = set(op.flatten()) 477 for i in subset: 478 subops[i].append(subset) 479 480 a, b = op.unnest_operands() 481 if isinstance(a, exp.Not): 482 pairs[frozenset((a.this, b))].append((op, b)) 483 if isinstance(b, exp.Not): 484 pairs[frozenset((a, b.this))].append((op, a)) 485 486 for op in ops: 487 if not isinstance(op, kind): 488 continue 489 490 a, b = op.unnest_operands() 491 492 # Absorb 493 if isinstance(a, exp.Not) and a.this in op_set: 494 a.replace(exp.true() if kind == exp.And else exp.false()) 495 continue 496 if isinstance(b, exp.Not) and b.this in op_set: 497 b.replace(exp.true() if kind == exp.And else exp.false()) 498 continue 499 superset = set(op.flatten()) 500 if any(any(subset < superset for subset in subops[i]) for i in superset): 501 op.replace(exp.false() if kind == exp.And else exp.true()) 502 continue 503 504 # Eliminate 505 for other, complement in pairs[frozenset((a, b))]: 506 op.replace(complement) 507 other.replace(complement) 508 509 return expression 510 511 512def propagate_constants(expression, root=True): 513 """ 514 Propagate constants for conjunctions in DNF: 515 516 SELECT * FROM t WHERE a = b AND b = 5 becomes 517 SELECT * FROM t WHERE a = 5 AND b = 5 518 519 Reference: https://www.sqlite.org/optoverview.html 520 """ 521 522 if ( 523 isinstance(expression, exp.And) 524 and (root or not expression.same_parent) 525 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 526 ): 527 constant_mapping = {} 528 for expr in 
walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 529 if isinstance(expr, exp.EQ): 530 l, r = expr.left, expr.right 531 532 # TODO: create a helper that can be used to detect nested literal expressions such 533 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 534 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 535 constant_mapping[l] = (id(l), r) 536 537 if constant_mapping: 538 for column in find_all_in_scope(expression, exp.Column): 539 parent = column.parent 540 column_id, constant = constant_mapping.get(column) or (None, None) 541 if ( 542 column_id is not None 543 and id(column) != column_id 544 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 545 ): 546 column.replace(constant.copy()) 547 548 return expression 549 550 551INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 552 exp.DateAdd: exp.Sub, 553 exp.DateSub: exp.Add, 554 exp.DatetimeAdd: exp.Sub, 555 exp.DatetimeSub: exp.Add, 556} 557 558INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 559 **INVERSE_DATE_OPS, 560 exp.Add: exp.Sub, 561 exp.Sub: exp.Add, 562} 563 564 565def _is_number(expression: exp.Expression) -> bool: 566 return expression.is_number 567 568 569def _is_interval(expression: exp.Expression) -> bool: 570 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 571 572 573@catch(ModuleNotFoundError, UnsupportedUnit) 574def simplify_equality(expression: exp.Expression) -> exp.Expression: 575 """ 576 Use the subtraction and addition properties of equality to simplify expressions: 577 578 x + 1 = 3 becomes x = 2 579 580 There are two binary operations in the above expression: + and = 581 Here's how we reference all the operands in the code below: 582 583 l r 584 x + 1 = 3 585 a b 586 """ 587 if isinstance(expression, COMPARISONS): 588 l, r = expression.left, expression.right 589 590 if l.__class__ not in INVERSE_OPS: 
591 return expression 592 593 if r.is_number: 594 a_predicate = _is_number 595 b_predicate = _is_number 596 elif _is_date_literal(r): 597 a_predicate = _is_date_literal 598 b_predicate = _is_interval 599 else: 600 return expression 601 602 if l.__class__ in INVERSE_DATE_OPS: 603 l = t.cast(exp.IntervalOp, l) 604 a = l.this 605 b = l.interval() 606 else: 607 l = t.cast(exp.Binary, l) 608 a, b = l.left, l.right 609 610 if not a_predicate(a) and b_predicate(b): 611 pass 612 elif not a_predicate(b) and b_predicate(a): 613 a, b = b, a 614 else: 615 return expression 616 617 return expression.__class__( 618 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 619 ) 620 return expression 621 622 623def simplify_literals(expression, root=True): 624 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 625 return _flat_simplify(expression, _simplify_binary, root) 626 627 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 628 return expression.this.this 629 630 if type(expression) in INVERSE_DATE_OPS: 631 return _simplify_binary(expression, expression.this, expression.interval()) or expression 632 633 return expression 634 635 636NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 637 638 639def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 640 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 641 this = _simplify_integer_cast(expr.this) 642 else: 643 this = expr.this 644 645 if isinstance(expr, exp.Cast) and this.is_int: 646 num = this.to_py() 647 648 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 649 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 650 # engine-dependent 651 if ( 652 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 653 ) or ( 654 UTINYINT_MIN <= num <= UTINYINT_MAX 655 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 656 ): 657 return this 658 659 return expr 660 661 662def _simplify_binary(expression, a, b): 663 if isinstance(expression, COMPARISONS): 664 a = _simplify_integer_cast(a) 665 b = _simplify_integer_cast(b) 666 667 if isinstance(expression, exp.Is): 668 if isinstance(b, exp.Not): 669 c = b.this 670 not_ = True 671 else: 672 c = b 673 not_ = False 674 675 if is_null(c): 676 if isinstance(a, exp.Literal): 677 return exp.true() if not_ else exp.false() 678 if is_null(a): 679 return exp.false() if not_ else exp.true() 680 elif isinstance(expression, NULL_OK): 681 return None 682 elif is_null(a) or is_null(b): 683 return exp.null() 684 685 if a.is_number and b.is_number: 686 num_a = a.to_py() 687 num_b = b.to_py() 688 689 if isinstance(expression, exp.Add): 690 return exp.Literal.number(num_a + num_b) 691 if isinstance(expression, exp.Mul): 692 return exp.Literal.number(num_a * num_b) 693 694 # We only simplify Sub, Div if a and b have the same parent because they're not associative 695 if isinstance(expression, exp.Sub): 696 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 697 if isinstance(expression, exp.Div): 698 # engines have differing int div behavior so intdiv is not safe 699 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 700 return None 701 return exp.Literal.number(num_a / num_b) 702 703 boolean = eval_boolean(expression, num_a, num_b) 704 705 if boolean: 706 return boolean 707 elif a.is_string and b.is_string: 708 boolean = eval_boolean(expression, a.this, b.this) 709 710 if boolean: 711 return boolean 712 elif _is_date_literal(a) and isinstance(b, 
exp.Interval): 713 date, b = extract_date(a), extract_interval(b) 714 if date and b: 715 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 716 return date_literal(date + b, extract_type(a)) 717 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 718 return date_literal(date - b, extract_type(a)) 719 elif isinstance(a, exp.Interval) and _is_date_literal(b): 720 a, date = extract_interval(a), extract_date(b) 721 # you cannot subtract a date from an interval 722 if a and b and isinstance(expression, exp.Add): 723 return date_literal(a + date, extract_type(b)) 724 elif _is_date_literal(a) and _is_date_literal(b): 725 if isinstance(expression, exp.Predicate): 726 a, b = extract_date(a), extract_date(b) 727 boolean = eval_boolean(expression, a, b) 728 if boolean: 729 return boolean 730 731 return None 732 733 734def simplify_parens(expression): 735 if not isinstance(expression, exp.Paren): 736 return expression 737 738 this = expression.this 739 parent = expression.parent 740 parent_is_predicate = isinstance(parent, exp.Predicate) 741 742 if ( 743 not isinstance(this, exp.Select) 744 and not isinstance(parent, exp.SubqueryPredicate) 745 and ( 746 not isinstance(parent, (exp.Condition, exp.Binary)) 747 or isinstance(parent, exp.Paren) 748 or ( 749 not isinstance(this, exp.Binary) 750 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 751 ) 752 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 753 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 754 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 755 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 756 ) 757 ): 758 return this 759 return expression 760 761 762def _is_nonnull_constant(expression: exp.Expression) -> bool: 763 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 764 765 766def _is_constant(expression: exp.Expression) -> bool: 767 return 
isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 768 769 770def simplify_coalesce(expression): 771 # COALESCE(x) -> x 772 if ( 773 isinstance(expression, exp.Coalesce) 774 and (not expression.expressions or _is_nonnull_constant(expression.this)) 775 # COALESCE is also used as a Spark partitioning hint 776 and not isinstance(expression.parent, exp.Hint) 777 ): 778 return expression.this 779 780 if not isinstance(expression, COMPARISONS): 781 return expression 782 783 if isinstance(expression.left, exp.Coalesce): 784 coalesce = expression.left 785 other = expression.right 786 elif isinstance(expression.right, exp.Coalesce): 787 coalesce = expression.right 788 other = expression.left 789 else: 790 return expression 791 792 # This transformation is valid for non-constants, 793 # but it really only does anything if they are both constants. 794 if not _is_constant(other): 795 return expression 796 797 # Find the first constant arg 798 for arg_index, arg in enumerate(coalesce.expressions): 799 if _is_constant(arg): 800 break 801 else: 802 return expression 803 804 coalesce.set("expressions", coalesce.expressions[:arg_index]) 805 806 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 807 # since we already remove COALESCE at the top of this function. 
808 coalesce = coalesce if coalesce.expressions else coalesce.this 809 810 # This expression is more complex than when we started, but it will get simplified further 811 return exp.paren( 812 exp.or_( 813 exp.and_( 814 coalesce.is_(exp.null()).not_(copy=False), 815 expression.copy(), 816 copy=False, 817 ), 818 exp.and_( 819 coalesce.is_(exp.null()), 820 type(expression)(this=arg.copy(), expression=other.copy()), 821 copy=False, 822 ), 823 copy=False, 824 ) 825 ) 826 827 828CONCATS = (exp.Concat, exp.DPipe) 829 830 831def simplify_concat(expression): 832 """Reduces all groups that contain string literals by concatenating them.""" 833 if not isinstance(expression, CONCATS) or ( 834 # We can't reduce a CONCAT_WS call if we don't statically know the separator 835 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 836 ): 837 return expression 838 839 if isinstance(expression, exp.ConcatWs): 840 sep_expr, *expressions = expression.expressions 841 sep = sep_expr.name 842 concat_type = exp.ConcatWs 843 args = {} 844 else: 845 expressions = expression.expressions 846 sep = "" 847 concat_type = exp.Concat 848 args = { 849 "safe": expression.args.get("safe"), 850 "coalesce": expression.args.get("coalesce"), 851 } 852 853 new_args = [] 854 for is_string_group, group in itertools.groupby( 855 expressions or expression.flatten(), lambda e: e.is_string 856 ): 857 if is_string_group: 858 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 859 else: 860 new_args.extend(group) 861 862 if len(new_args) == 1 and new_args[0].is_string: 863 return new_args[0] 864 865 if concat_type is exp.ConcatWs: 866 new_args = [sep_expr] + new_args 867 elif isinstance(expression, exp.DPipe): 868 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 869 870 return concat_type(expressions=new_args, **args) 871 872 873def simplify_conditionals(expression): 874 """Simplifies expressions like IF, CASE if their condition is statically 
known.""" 875 if isinstance(expression, exp.Case): 876 this = expression.this 877 for case in expression.args["ifs"]: 878 cond = case.this 879 if this: 880 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 881 cond = cond.replace(this.pop().eq(cond)) 882 883 if always_true(cond): 884 return case.args["true"] 885 886 if always_false(cond): 887 case.pop() 888 if not expression.args["ifs"]: 889 return expression.args.get("default") or exp.null() 890 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 891 if always_true(expression.this): 892 return expression.args["true"] 893 if always_false(expression.this): 894 return expression.args.get("false") or exp.null() 895 896 return expression 897 898 899def simplify_startswith(expression: exp.Expression) -> exp.Expression: 900 """ 901 Reduces a prefix check to either TRUE or FALSE if both the string and the 902 prefix are statically known. 903 904 Example: 905 >>> from sqlglot import parse_one 906 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 907 'TRUE' 908 """ 909 if ( 910 isinstance(expression, exp.StartsWith) 911 and expression.this.is_string 912 and expression.expression.is_string 913 ): 914 return exp.convert(expression.name.startswith(expression.expression.name)) 915 916 return expression 917 918 919DateRange = t.Tuple[datetime.date, datetime.date] 920 921 922def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 923 """ 924 Get the date range for a DATE_TRUNC equality comparison: 925 926 Example: 927 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 928 Returns: 929 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 930 """ 931 floor = date_floor(date, unit, dialect) 932 933 if date != floor: 934 # This will always be False, except for NULL values. 
935 return None 936 937 return floor, floor + interval(unit) 938 939 940def _datetrunc_eq_expression( 941 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 942) -> exp.Expression: 943 """Get the logical expression for a date range""" 944 return exp.and_( 945 left >= date_literal(drange[0], target_type), 946 left < date_literal(drange[1], target_type), 947 copy=False, 948 ) 949 950 951def _datetrunc_eq( 952 left: exp.Expression, 953 date: datetime.date, 954 unit: str, 955 dialect: Dialect, 956 target_type: t.Optional[exp.DataType], 957) -> t.Optional[exp.Expression]: 958 drange = _datetrunc_range(date, unit, dialect) 959 if not drange: 960 return None 961 962 return _datetrunc_eq_expression(left, drange, target_type) 963 964 965def _datetrunc_neq( 966 left: exp.Expression, 967 date: datetime.date, 968 unit: str, 969 dialect: Dialect, 970 target_type: t.Optional[exp.DataType], 971) -> t.Optional[exp.Expression]: 972 drange = _datetrunc_range(date, unit, dialect) 973 if not drange: 974 return None 975 976 return exp.and_( 977 left < date_literal(drange[0], target_type), 978 left >= date_literal(drange[1], target_type), 979 copy=False, 980 ) 981 982 983DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 984 exp.LT: lambda l, dt, u, d, t: l 985 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 986 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 987 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 988 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 989 exp.EQ: _datetrunc_eq, 990 exp.NEQ: _datetrunc_neq, 991} 992DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 993DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 994 995 996def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 997 return isinstance(left, DATETRUNCS) 
and _is_date_literal(right) 998 999 1000@catch(ModuleNotFoundError, UnsupportedUnit) 1001def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1002 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1003 comparison = expression.__class__ 1004 1005 if isinstance(expression, DATETRUNCS): 1006 this = expression.this 1007 trunc_type = extract_type(this) 1008 date = extract_date(this) 1009 if date and expression.unit: 1010 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1011 elif comparison not in DATETRUNC_COMPARISONS: 1012 return expression 1013 1014 if isinstance(expression, exp.Binary): 1015 l, r = expression.left, expression.right 1016 1017 if not _is_datetrunc_predicate(l, r): 1018 return expression 1019 1020 l = t.cast(exp.DateTrunc, l) 1021 trunc_arg = l.this 1022 unit = l.unit.name.lower() 1023 date = extract_date(r) 1024 1025 if not date: 1026 return expression 1027 1028 return ( 1029 DATETRUNC_BINARY_COMPARISONS[comparison]( 1030 trunc_arg, date, unit, dialect, extract_type(r) 1031 ) 1032 or expression 1033 ) 1034 1035 if isinstance(expression, exp.In): 1036 l = expression.this 1037 rs = expression.expressions 1038 1039 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1040 l = t.cast(exp.DateTrunc, l) 1041 unit = l.unit.name.lower() 1042 1043 ranges = [] 1044 for r in rs: 1045 date = extract_date(r) 1046 if not date: 1047 return expression 1048 drange = _datetrunc_range(date, unit, dialect) 1049 if drange: 1050 ranges.append(drange) 1051 1052 if not ranges: 1053 return expression 1054 1055 ranges = merge_ranges(ranges) 1056 target_type = extract_type(*rs) 1057 1058 return exp.or_( 1059 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1060 ) 1061 1062 return expression 1063 1064 1065def sort_comparison(expression: exp.Expression) -> exp.Expression: 1066 if expression.__class__ in COMPLEMENT_COMPARISONS: 1067 
l, r = expression.this, expression.expression 1068 l_column = isinstance(l, exp.Column) 1069 r_column = isinstance(r, exp.Column) 1070 l_const = _is_constant(l) 1071 r_const = _is_constant(r) 1072 1073 if (l_column and not r_column) or (r_const and not l_const): 1074 return expression 1075 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1076 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1077 this=r, expression=l 1078 ) 1079 return expression 1080 1081 1082# CROSS joins result in an empty table if the right table is empty. 1083# So we can only simplify certain types of joins to CROSS. 1084# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1085JOINS = { 1086 ("", ""), 1087 ("", "INNER"), 1088 ("RIGHT", ""), 1089 ("RIGHT", "OUTER"), 1090} 1091 1092 1093def remove_where_true(expression): 1094 for where in expression.find_all(exp.Where): 1095 if always_true(where.this): 1096 where.pop() 1097 for join in expression.find_all(exp.Join): 1098 if ( 1099 always_true(join.args.get("on")) 1100 and not join.args.get("using") 1101 and not join.args.get("method") 1102 and (join.side, join.kind) in JOINS 1103 ): 1104 join.args["on"].pop() 1105 join.set("side", None) 1106 join.set("kind", "CROSS") 1107 1108 1109def always_true(expression): 1110 return (isinstance(expression, exp.Boolean) and expression.this) or ( 1111 isinstance(expression, exp.Literal) and not is_zero(expression) 1112 ) 1113 1114 1115def always_false(expression): 1116 return is_false(expression) or is_null(expression) or is_zero(expression) 1117 1118 1119def is_zero(expression): 1120 return isinstance(expression, exp.Literal) and expression.to_py() == 0 1121 1122 1123def is_complement(a, b): 1124 return isinstance(b, exp.Not) and b.this == a 1125 1126 1127def is_false(a: exp.Expression) -> bool: 1128 return type(a) is exp.Boolean and not a.this 1129 1130 1131def is_null(a: exp.Expression) -> bool: 1132 return type(a) is exp.Null 1133 1134 1135def 
eval_boolean(expression, a, b): 1136 if isinstance(expression, (exp.EQ, exp.Is)): 1137 return boolean_literal(a == b) 1138 if isinstance(expression, exp.NEQ): 1139 return boolean_literal(a != b) 1140 if isinstance(expression, exp.GT): 1141 return boolean_literal(a > b) 1142 if isinstance(expression, exp.GTE): 1143 return boolean_literal(a >= b) 1144 if isinstance(expression, exp.LT): 1145 return boolean_literal(a < b) 1146 if isinstance(expression, exp.LTE): 1147 return boolean_literal(a <= b) 1148 return None 1149 1150 1151def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1152 if isinstance(value, datetime.datetime): 1153 return value.date() 1154 if isinstance(value, datetime.date): 1155 return value 1156 try: 1157 return datetime.datetime.fromisoformat(value).date() 1158 except ValueError: 1159 return None 1160 1161 1162def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1163 if isinstance(value, datetime.datetime): 1164 return value 1165 if isinstance(value, datetime.date): 1166 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1167 try: 1168 return datetime.datetime.fromisoformat(value) 1169 except ValueError: 1170 return None 1171 1172 1173def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1174 if not value: 1175 return None 1176 if to.is_type(exp.DataType.Type.DATE): 1177 return cast_as_date(value) 1178 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1179 return cast_as_datetime(value) 1180 return None 1181 1182 1183def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1184 if isinstance(cast, exp.Cast): 1185 to = cast.to 1186 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1187 to = exp.DataType.build(exp.DataType.Type.DATE) 1188 else: 1189 return None 1190 1191 if isinstance(cast.this, exp.Literal): 1192 value: t.Any = cast.this.name 1193 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 
1194 value = extract_date(cast.this) 1195 else: 1196 return None 1197 return cast_value(value, to) 1198 1199 1200def _is_date_literal(expression: exp.Expression) -> bool: 1201 return extract_date(expression) is not None 1202 1203 1204def extract_interval(expression): 1205 try: 1206 n = int(expression.this.to_py()) 1207 unit = expression.text("unit").lower() 1208 return interval(unit, n) 1209 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1210 return None 1211 1212 1213def extract_type(*expressions): 1214 target_type = None 1215 for expression in expressions: 1216 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1217 if target_type: 1218 break 1219 1220 return target_type 1221 1222 1223def date_literal(date, target_type=None): 1224 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1225 target_type = ( 1226 exp.DataType.Type.DATETIME 1227 if isinstance(date, datetime.datetime) 1228 else exp.DataType.Type.DATE 1229 ) 1230 1231 return exp.cast(exp.Literal.string(date), target_type) 1232 1233 1234def interval(unit: str, n: int = 1): 1235 from dateutil.relativedelta import relativedelta 1236 1237 if unit == "year": 1238 return relativedelta(years=1 * n) 1239 if unit == "quarter": 1240 return relativedelta(months=3 * n) 1241 if unit == "month": 1242 return relativedelta(months=1 * n) 1243 if unit == "week": 1244 return relativedelta(weeks=1 * n) 1245 if unit == "day": 1246 return relativedelta(days=1 * n) 1247 if unit == "hour": 1248 return relativedelta(hours=1 * n) 1249 if unit == "minute": 1250 return relativedelta(minutes=1 * n) 1251 if unit == "second": 1252 return relativedelta(seconds=1 * n) 1253 1254 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1255 1256 1257def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1258 if unit == "year": 1259 return d.replace(month=1, day=1) 1260 if unit == "quarter": 1261 if d.month <= 3: 1262 return d.replace(month=1, day=1) 
1263 elif d.month <= 6: 1264 return d.replace(month=4, day=1) 1265 elif d.month <= 9: 1266 return d.replace(month=7, day=1) 1267 else: 1268 return d.replace(month=10, day=1) 1269 if unit == "month": 1270 return d.replace(month=d.month, day=1) 1271 if unit == "week": 1272 # Assuming week starts on Monday (0) and ends on Sunday (6) 1273 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1274 if unit == "day": 1275 return d 1276 1277 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1278 1279 1280def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1281 floor = date_floor(d, unit, dialect) 1282 1283 if floor == d: 1284 return d 1285 1286 return floor + interval(unit) 1287 1288 1289def boolean_literal(condition): 1290 return exp.true() if condition else exp.false() 1291 1292 1293def _flat_simplify(expression, simplifier, root=True): 1294 if root or not expression.same_parent: 1295 operands = [] 1296 queue = deque(expression.flatten(unnest=False)) 1297 size = len(queue) 1298 1299 while queue: 1300 a = queue.popleft() 1301 1302 for b in queue: 1303 result = simplifier(expression, a, b) 1304 1305 if result and result is not expression: 1306 queue.remove(b) 1307 queue.appendleft(result) 1308 break 1309 else: 1310 operands.append(a) 1311 1312 if len(operands) < size: 1313 return functools.reduce( 1314 lambda a, b: expression.__class__(this=a, expression=b), operands 1315 ) 1316 return expression 1317 1318 1319def gen(expression: t.Any) -> str: 1320 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1321 1322 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1323 generator is expensive so we have a bare minimum sql generator here. 
1324 """ 1325 return Gen().gen(expression) 1326 1327 1328class Gen: 1329 def __init__(self): 1330 self.stack = [] 1331 self.sqls = [] 1332 1333 def gen(self, expression: exp.Expression) -> str: 1334 self.stack = [expression] 1335 self.sqls.clear() 1336 1337 while self.stack: 1338 node = self.stack.pop() 1339 1340 if isinstance(node, exp.Expression): 1341 exp_handler_name = f"{node.key}_sql" 1342 1343 if hasattr(self, exp_handler_name): 1344 getattr(self, exp_handler_name)(node) 1345 elif isinstance(node, exp.Func): 1346 self._function(node) 1347 else: 1348 key = node.key.upper() 1349 self.stack.append(f"{key} " if self._args(node) else key) 1350 elif type(node) is list: 1351 for n in reversed(node): 1352 if n is not None: 1353 self.stack.extend((n, ",")) 1354 if node: 1355 self.stack.pop() 1356 else: 1357 if node is not None: 1358 self.sqls.append(str(node)) 1359 1360 return "".join(self.sqls) 1361 1362 def add_sql(self, e: exp.Add) -> None: 1363 self._binary(e, " + ") 1364 1365 def alias_sql(self, e: exp.Alias) -> None: 1366 self.stack.extend( 1367 ( 1368 e.args.get("alias"), 1369 " AS ", 1370 e.args.get("this"), 1371 ) 1372 ) 1373 1374 def and_sql(self, e: exp.And) -> None: 1375 self._binary(e, " AND ") 1376 1377 def anonymous_sql(self, e: exp.Anonymous) -> None: 1378 this = e.this 1379 if isinstance(this, str): 1380 name = this.upper() 1381 elif isinstance(this, exp.Identifier): 1382 name = this.this 1383 name = f'"{name}"' if this.quoted else name.upper() 1384 else: 1385 raise ValueError( 1386 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
1387 ) 1388 1389 self.stack.extend( 1390 ( 1391 ")", 1392 e.expressions, 1393 "(", 1394 name, 1395 ) 1396 ) 1397 1398 def between_sql(self, e: exp.Between) -> None: 1399 self.stack.extend( 1400 ( 1401 e.args.get("high"), 1402 " AND ", 1403 e.args.get("low"), 1404 " BETWEEN ", 1405 e.this, 1406 ) 1407 ) 1408 1409 def boolean_sql(self, e: exp.Boolean) -> None: 1410 self.stack.append("TRUE" if e.this else "FALSE") 1411 1412 def bracket_sql(self, e: exp.Bracket) -> None: 1413 self.stack.extend( 1414 ( 1415 "]", 1416 e.expressions, 1417 "[", 1418 e.this, 1419 ) 1420 ) 1421 1422 def column_sql(self, e: exp.Column) -> None: 1423 for p in reversed(e.parts): 1424 self.stack.extend((p, ".")) 1425 self.stack.pop() 1426 1427 def datatype_sql(self, e: exp.DataType) -> None: 1428 self._args(e, 1) 1429 self.stack.append(f"{e.this.name} ") 1430 1431 def div_sql(self, e: exp.Div) -> None: 1432 self._binary(e, " / ") 1433 1434 def dot_sql(self, e: exp.Dot) -> None: 1435 self._binary(e, ".") 1436 1437 def eq_sql(self, e: exp.EQ) -> None: 1438 self._binary(e, " = ") 1439 1440 def from_sql(self, e: exp.From) -> None: 1441 self.stack.extend((e.this, "FROM ")) 1442 1443 def gt_sql(self, e: exp.GT) -> None: 1444 self._binary(e, " > ") 1445 1446 def gte_sql(self, e: exp.GTE) -> None: 1447 self._binary(e, " >= ") 1448 1449 def identifier_sql(self, e: exp.Identifier) -> None: 1450 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1451 1452 def ilike_sql(self, e: exp.ILike) -> None: 1453 self._binary(e, " ILIKE ") 1454 1455 def in_sql(self, e: exp.In) -> None: 1456 self.stack.append(")") 1457 self._args(e, 1) 1458 self.stack.extend( 1459 ( 1460 "(", 1461 " IN ", 1462 e.this, 1463 ) 1464 ) 1465 1466 def intdiv_sql(self, e: exp.IntDiv) -> None: 1467 self._binary(e, " DIV ") 1468 1469 def is_sql(self, e: exp.Is) -> None: 1470 self._binary(e, " IS ") 1471 1472 def like_sql(self, e: exp.Like) -> None: 1473 self._binary(e, " Like ") 1474 1475 def literal_sql(self, e: exp.Literal) -> None: 
1476 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1477 1478 def lt_sql(self, e: exp.LT) -> None: 1479 self._binary(e, " < ") 1480 1481 def lte_sql(self, e: exp.LTE) -> None: 1482 self._binary(e, " <= ") 1483 1484 def mod_sql(self, e: exp.Mod) -> None: 1485 self._binary(e, " % ") 1486 1487 def mul_sql(self, e: exp.Mul) -> None: 1488 self._binary(e, " * ") 1489 1490 def neg_sql(self, e: exp.Neg) -> None: 1491 self._unary(e, "-") 1492 1493 def neq_sql(self, e: exp.NEQ) -> None: 1494 self._binary(e, " <> ") 1495 1496 def not_sql(self, e: exp.Not) -> None: 1497 self._unary(e, "NOT ") 1498 1499 def null_sql(self, e: exp.Null) -> None: 1500 self.stack.append("NULL") 1501 1502 def or_sql(self, e: exp.Or) -> None: 1503 self._binary(e, " OR ") 1504 1505 def paren_sql(self, e: exp.Paren) -> None: 1506 self.stack.extend( 1507 ( 1508 ")", 1509 e.this, 1510 "(", 1511 ) 1512 ) 1513 1514 def sub_sql(self, e: exp.Sub) -> None: 1515 self._binary(e, " - ") 1516 1517 def subquery_sql(self, e: exp.Subquery) -> None: 1518 self._args(e, 2) 1519 alias = e.args.get("alias") 1520 if alias: 1521 self.stack.append(alias) 1522 self.stack.extend((")", e.this, "(")) 1523 1524 def table_sql(self, e: exp.Table) -> None: 1525 self._args(e, 4) 1526 alias = e.args.get("alias") 1527 if alias: 1528 self.stack.append(alias) 1529 for p in reversed(e.parts): 1530 self.stack.extend((p, ".")) 1531 self.stack.pop() 1532 1533 def tablealias_sql(self, e: exp.TableAlias) -> None: 1534 columns = e.columns 1535 1536 if columns: 1537 self.stack.extend((")", columns, "(")) 1538 1539 self.stack.extend((e.this, " AS ")) 1540 1541 def var_sql(self, e: exp.Var) -> None: 1542 self.stack.append(e.this) 1543 1544 def _binary(self, e: exp.Binary, op: str) -> None: 1545 self.stack.extend((e.expression, op, e.this)) 1546 1547 def _unary(self, e: exp.Unary, op: str) -> None: 1548 self.stack.extend((e.this, op)) 1549 1550 def _function(self, e: exp.Func) -> None: 1551 self.stack.extend( 1552 ( 1553 ")", 1554 
list(e.args.values()), 1555 "(", 1556 e.sql_name(), 1557 ) 1558 ) 1559 1560 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1561 kvs = [] 1562 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1563 1564 for k in arg_types or arg_types: 1565 v = node.args.get(k) 1566 1567 if v is not None: 1568 kvs.append([f":{k}", v]) 1569 if kvs: 1570 self.stack.append(kvs) 1571 return True 1572 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and handed to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure connector depth at the top of a connector chain
        # (i.e. when the parent is not itself a Connector).
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes flagged FINAL (see the GROUP BY handling below) are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a group-by key
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Freeze the HAVING clause too if it contains a group-by key
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Point the (possibly new) node at the original parent before the remaining rules run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call swaps the result into the tree itself
        if root:
            expression.replace(node)
        return node

    # NOTE(review): while_changing presumably re-applies _simplify until the
    # expression stops changing -- confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
Returns:
sqlglot.Expression: simplified expression
def connector_depth(expression: exp.Expression) -> int:
    """
    Return how many Connector levels deep the tree rooted at `expression` goes.

    A node that is not a Connector contributes nothing, so a bare column has depth 0.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    # Each entry carries the level the node would occupy if it is a Connector
    pending = [(expression, 1)]

    while pending:
        node, level = pending.pop()

        if isinstance(node, exp.Connector):
            if level > deepest:
                deepest = level
            pending.extend(((node.left, level + 1), (node.right, level + 1)))

    return deepest
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.
    Returns:
        A decorator. The decorated function behaves like the original, except that
        when one of `exceptions` is raised it returns its first argument
        (`expression`) unchanged. Any other exception propagates.
    """

    def decorator(func):
        # Preserve the wrapped function's __name__, __doc__, etc., so decorated
        # simplifiers are not all reported as "wrapped" by docs and tracebacks.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded first. When the BETWEEN sits directly under a NOT, the conjunction
    is wrapped in parentheses. Any other expression is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if under_not else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT node; anything else is returned unchanged.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also handled: NOT NULL -> NULL, NOT <comparison> -> complementary comparison,
    NOT <always true> -> FALSE, NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        complement_type = COMPLEMENT_COMPARISONS[operand_type]
        return complement_type(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # AND flips to OR and vice versa, with both sides negated
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            negated_left = exp.not_(inner.left, copy=False)
            negated_right = exp.not_(inner.right, copy=False)
            return exp.paren(combine(negated_left, negated_right, copy=False))

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Unwrap nested connectors of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only replace the operand when it unnests to the same connector class
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR / XOR connectors.

    Non-connector expressions are returned unchanged. For connectors, operand
    pairs are handed to `_flat_simplify` together with the pairwise rule below.
    """

    def _reduce_pair(connector, lhs, rhs):
        # Returns the replacement for the pair, or None when nothing applies
        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs) or is_zero(lhs) or is_zero(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()

            lhs_true = always_true(lhs)
            rhs_true = always_true(rhs)
            if lhs_true and rhs_true:
                return exp.true()
            if lhs_true:
                return rhs
            if rhs_true:
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_null = is_null(lhs)
            rhs_null = is_null(rhs)
            if (lhs_null and (rhs_null or always_false(rhs))) or (
                always_false(lhs) and rhs_null
            ):
                return exp.null()

            if is_false(lhs):
                return rhs
            if is_false(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        if isinstance(connector, exp.Xor) and lhs == rhs:
            return exp.false()

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement_pair = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if has_complement_pair:
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their `gen` string. XOR operands are sorted but never
    deduplicated.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # A AND C AND B -> A AND B AND C: rebuild only if some adjacent pair is out of order
    out_of_order = any(
        keyed[pos + 1][0] < keyed[pos][0] for pos in range(len(keyed) - 1)
    )

    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still need to be dropped
    if unique and len(unique) < len(operands):
        return build(*unique.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND/OR connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten through `replace` on the existing nodes, and the
    `expression` argument itself is what gets returned. Expressions that are not
    AND_OR, or that are a non-root node with `same_parent` set, are left untouched.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the *opposite* connector: the type of the nested operands we inspect
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # the sub-operand is the plain operand A itself
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # the sub-operands are the members of each nested connector
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Record (operand, NOT operand) pairs keyed by their un-negated form
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated member whose positive form is a sibling operand is replaced
            # by a boolean constant (TRUE inside an AND, FALSE inside an OR).
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # A nested operand that is a strict superset of another operand's members
            # is replaced by a boolean constant (FALSE for AND, TRUE for OR).
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # Both members of a complementary pair are replaced by the shared operand.
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only `column = literal` equalities are collected. A column is not replaced
    when it is the very node the constant was collected from, or when it is the
    subject of an `IS NULL` check.

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node in the defining equality, literal)
    known = {}
    for candidate in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            known[lhs] = (id(lhs), rhs)

    if not known:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        entry = known.get(column)
        if not entry:
            continue

        source_id, literal = entry
        if id(column) == source_id:
            # This is the column of the defining `col = literal` itself
            continue

        owner = column.parent
        if isinstance(owner, exp.Is) and isinstance(owner.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5
becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    """Call `func`; if it raises one of `exceptions`, return `expression` unchanged.

    `func` and `exceptions` are not defined in this listing: this is the inner
    closure created by the `catch` decorator, shown here in place of the
    decorated function's own source.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    - Binary nodes that are not Connectors go through `_flat_simplify` with `_simplify_binary`.
    - A doubly negated value (`Neg(Neg(x))`) is reduced to `x`.
    - Nodes whose exact type is in INVERSE_DATE_OPS are folded with
      `_simplify_binary`; the node is kept if that yields nothing.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        inner = expression.this
        if isinstance(inner, exp.Neg):
            return inner.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded if folded else expression

    return expression
def simplify_parens(expression):
    """
    Drop a Paren node when removing it is considered safe.

    Parentheses around a Select, or directly under a SubqueryPredicate, are
    always kept. Otherwise the wrapped expression is returned when any of the
    conditions listed below holds; if none does, the Paren is returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select) or isinstance(outer, exp.SubqueryPredicate):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)

    redundant = (
        # the parent is neither a Condition nor a Binary
        not isinstance(outer, (exp.Condition, exp.Binary))
        # nested parentheses
        or isinstance(outer, exp.Paren)
        # a non-binary operand, unless it is NOT / IS sitting under a predicate
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        # a predicate whose parent is not a predicate
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        # (a + b) + c
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        # (a * b) * c, (a * b) + c, (a * b) - c
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if redundant else expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - `COALESCE(x)`, or a COALESCE whose first argument is a non-null constant,
      becomes that first argument (unless the parent is a Hint).
    - `COALESCE(x, ..., c) <cmp> k`, with `c` the first constant fallback and `k`
      constant, becomes
      `((NOT <coalesce> IS NULL AND <comparison>) OR (<coalesce> IS NULL AND c <cmp> k))`,
      where the fallbacks from `c` onwards are dropped from the COALESCE first.
    """
    # COALESCE(x) -> x
    if isinstance(expression, exp.Coalesce):
        reducible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if reducible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        (
            (position, candidate)
            for position, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if first_constant is None:
        return expression

    cut, constant_arg = first_constant
    coalesce.set("expressions", coalesce.expressions[:cut])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    target = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_not_null = exp.and_(
        target.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        target.is_(exp.null()),
        type(expression)(this=constant_arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal.

    Handles the node types in CONCATS. A CONCAT_WS call is only touched when its
    separator is a string literal; the separator is used to join merged literals.
    If everything collapses into a single string literal, that literal is returned.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        separator_node, *operands = expression.expressions
        separator = separator_node.name
    else:
        operands = expression.expressions
        separator = ""

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            joined = separator.join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[separator_node, *merged])

    if isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return exp.Concat(
        expressions=merged,
        safe=expression.args.get("safe"),
        coalesce=expression.args.get("coalesce"),
    )
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
        - `CASE x WHEN v ...` branches are first rewritten to `WHEN x = v`.
        - The first branch whose condition is always true yields its result.
        - Branches whose condition is always false are removed; if none remain,
          the default (or NULL) is returned.
    IF (when not nested inside a CASE):
        - An always-true condition yields the "true" argument, an always-false
          condition yields the "false" argument (or NULL).

    Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below removes the branch from the
        # CASE (the code relies on `ifs` shrinking), and looping over the live list
        # while it shrinks would skip the branch that follows each removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Evaluate a STARTSWITH call whose string and prefix are both string literals.

    The call is replaced by the boolean result; any other expression is
    returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    subject, prefix = expression.this, expression.expression
    if subject.is_string and prefix.is_string:
        return exp.convert(expression.name.startswith(prefix.name))

    return expression
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def wrapped(expression, *args, **kwargs):
    """Call `func`; if it raises one of `exceptions`, return `expression` unchanged.

    `func` and `exceptions` are not defined in this listing: this is the inner
    closure created by the `catch` decorator, shown here in place of the
    decorated function's own source.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Puts the operands of a comparison into a canonical order.

    Only classes listed in COMPLEMENT_COMPARISONS are considered. The operands
    are left alone when the left side is the only column or the right side is
    the only constant. They are swapped (using the class from
    INVERSE_COMPARISONS when one is registered) when the right side is the only
    column, the left side is the only constant, or the left operand's `gen`
    string sorts after the right operand's.
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    in_order = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if in_order:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not needs_swap:
        return expression

    flipped = INVERSE_COMPARISONS.get(comparison, comparison)
    return flipped(this=right, expression=left)
def remove_where_true(expression):
    """
    Strips conditions that are always true from `expression`, in place.

    WHERE clauses with an always-true condition are popped. A JOIN whose ON
    condition is always true, that has neither `using` nor `method` set and
    whose (side, kind) pair is listed in JOINS, loses its ON condition and is
    turned into a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        if not always_true(on_condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        on_condition.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluates the comparison represented by `expression` on the values `a` and `b`.

    Returns the result wrapped by `boolean_literal`, or None when `expression`
    is not one of the handled comparison types.
    """
    # Checked top to bottom; the comparison itself only runs for the matching entry
    dispatch = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in dispatch:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Converts `value` to a `datetime.date`, if possible.

    Args:
        value: A `datetime.datetime`, a `datetime.date` or an ISO-8601 string.

    Returns:
        The corresponding date, or None when `value` is a string that cannot be
        parsed or is of an unsupported type.
    """
    # datetime is a subclass of date, so it has to be tested first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Converts `value` to a `datetime.datetime`, if possible.

    Args:
        value: A `datetime.datetime`, a `datetime.date` (promoted to midnight of
            that day) or an ISO-8601 string.

    Returns:
        The corresponding datetime, or None when `value` is a string that cannot
        be parsed or is of an unsupported type.
    """
    # datetime is a subclass of date, so it has to be tested first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Converts `value` to the Python date/datetime matching the data type `to`.

    Args:
        value: The value to convert; falsy values yield None.
        to: The target data type.

    Returns:
        The result of `cast_as_date` when `to` is DATE, the result of
        `cast_as_datetime` when `to` is any other temporal type, and None
        otherwise.
    """
    if not value:
        return None
    # DATE is checked before the broader temporal check so it keeps date precision
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluates a (possibly nested) cast of a literal to a Python date/datetime.

    Args:
        cast: An `exp.Cast`, or an `exp.TsOrDsToDate` without a `format`
            argument (treated as a cast to DATE).

    Returns:
        The value produced by `cast_value` for the innermost literal, converted
        through each enclosing cast, or None when `cast` has another shape or
        its operand is neither a literal nor another supported cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Resolve the inner cast first; a None result falls through to cast_value
        value = extract_date(operand)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """
    Builds a cast of the string form of `date` to a temporal type.

    `target_type` is used when it is given and is one of the temporal types;
    otherwise DATETIME is chosen for `datetime.datetime` values and DATE for
    everything else.
    """
    usable_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not usable_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """
    Returns a `relativedelta` spanning `n` times the given unit.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, number of that keyword in one unit)
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in unit_table:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Rounds `d` down to the start of the `unit` that contains it.

    Args:
        d: The date to floor.
        unit: One of "year", "quarter", "month", "week" or "day".
        dialect: Only consulted for "week", where `dialect.WEEK_OFFSET` shifts
            the first day of the week relative to Monday
            (presumably 0 = Monday, -1 = Sunday; confirm against the dialect docs).

    Returns:
        The floored date; never later than `d`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The distance to the start of
        # the week must wrap into 0..6: without the modulo, a date that already is
        # the first day of an offset week moved back a full week, and a positive
        # offset could move the date forward.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Builds a cheap, SQL-like string for `expression` that can be sorted and compared.

    Optimization needs to sort and deduplicate expressions; running the real
    generator for that is expensive, so a minimal generator is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo-SQL generator for quickly generating sortable and unique strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """Minimal pseudo-SQL generator used to build sortable, comparable strings.

    Generation is iterative: `stack` holds pending work items (expressions,
    lists and plain values) and `sqls` collects the emitted fragments. Because
    items are popped from the end of `stack`, every handler pushes its pieces
    in reverse output order.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Returns the pseudo-SQL string for `expression`.

        Expressions are dispatched to a `<key>_sql` handler when one exists, to
        `_function` when they are functions, and otherwise emitted as their
        upper-cased key followed by their arguments. Lists are emitted
        comma-separated, skipping None entries. Anything else is converted with
        `str`.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # _args pushes the arguments first so that the key is popped before them
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Drop the separator that would precede the first element. This
                    # must not happen when every entry was None, otherwise an
                    # unrelated item is removed from the stack.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emits `NAME(args)`; quoted identifier names keep their case and quotes.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the separator that would precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Skip the first declared argument; the type name is emitted from e.this.name
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        # Everything after the first declared argument goes inside the parentheses
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case keyword, unlike the other handlers; left
        # untouched so that the generated strings stay exactly the same.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # The first two declared arguments are emitted explicitly below
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # The first four declared arguments are covered by e.parts and the alias below
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the separator that would precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Emits `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Emits `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Emits `<sql_name>(<comma-separated argument values>)`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Pushes the set arguments of `node` as `:name,value` pairs.

        Args:
            node: The expression whose arguments are emitted.
            arg_index: Number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression) -> str:
    """Returns the pseudo-SQL string for `expression`.

    Work items are popped from the end of `self.stack`. Expressions are
    dispatched to a `<key>_sql` handler when one exists, to `_function` when
    they are functions, and otherwise emitted as their upper-cased key followed
    by their arguments. Lists are emitted comma-separated, skipping None
    entries. Anything else is converted with `str` and appended to `self.sqls`.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            handler = getattr(self, f"{node.key}_sql", None)

            if handler is not None:
                handler(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                # _args pushes the arguments first so that the key is popped before them
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            if pushed:
                # Drop the separator that would precede the first element. This
                # must not happen when every entry was None, otherwise an
                # unrelated item is removed from the stack.
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    # Emits NAME(args) for an anonymous function call. Plain string names and
    # unquoted identifiers are upper-cased; quoted identifiers keep their case
    # and are wrapped in double quotes. Any other kind of name is rejected.
    target = e.this

    if isinstance(target, str):
        name = target.upper()
    elif isinstance(target, exp.Identifier):
        raw_name = target.this
        if target.quoted:
            name = f'"{raw_name}"'
        else:
            name = raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )

    # Pushed in reverse because the stack is consumed from its end
    pieces = (")", e.expressions, "(", name)
    self.stack.extend(pieces)