# sqlglot.optimizer.simplify
from __future__ import annotations

import datetime
import logging
import functools
import itertools
import typing as t
from collections import deque, defaultdict
from decimal import Decimal
from functools import reduce

import sqlglot
from sqlglot import Dialect, exp
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
    ]

logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255


class UnsupportedUnit(Exception):
    """Raised by `interval` when it is given a unit it has no mapping for."""

    pass


def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect (name or instance) resolved with `Dialect.get_or_raise`;
            it is forwarded to the DATE_TRUNC simplification
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure the depth at the top of a connector chain; nested connectors
        # belong to a chain that has already been checked via their parent.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING, which may also reference the group keys
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The steps above may have produced a detached node; re-attach the parent
        # pointer before the remaining steps run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the rewrite until `while_changing` reports no further change
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
def connector_depth(expression: exp.Expression) -> int:
    """
    Determine the maximum depth of a tree of Connectors.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3

    Args:
        expression: root of the (possibly nested) connector tree.

    Returns:
        The number of Connector nodes on the longest root-to-leaf path, or 0 if
        `expression` itself is not a Connector.
    """
    # Iterative depth-first traversal, so deep chains cannot hit the recursion limit
    stack = deque([(expression, 0)])
    max_depth = 0

    while stack:
        expression, depth = stack.pop()

        if not isinstance(expression, exp.Connector):
            continue

        depth += 1
        max_depth = max(depth, max_depth)

        stack.append((expression.left, depth))
        stack.append((expression.right, depth))

    return max_depth


def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the wrapped function raises one of `exceptions`, the first positional
    argument (the expression) is returned unchanged. Any other exception propagates.

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator; the decorated function keeps its name and docstring.
    """

    def decorator(func):
        # functools.wraps preserves __name__/__doc__ so decorated simplifiers stay
        # introspectable (docs, doctests, debugging output)
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.

    Args:
        expression: any expression; only `exp.Between` nodes are rewritten.

    Returns:
        The rewritten conjunction (parenthesized when the BETWEEN was directly
        negated), or `expression` unchanged.
    """
    if isinstance(expression, exp.Between):
        # Checked before rebinding `expression`: the new AND node has no parent yet
        negate = isinstance(expression.parent, exp.Not)

        expression = exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )

        if negate:
            # Keep NOT applying to the whole conjunction rather than its first operand
            expression = exp.paren(expression, copy=False)

    return expression
# Maps a comparison class to the class that expresses its logical negation.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}


def simplify_not(expression):
    """
    Eliminate or push down a NOT.

    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    NOT NULL becomes NULL, a negated comparison becomes its complement, NOT of a
    constant is folded and a double negation is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    negated_cls = COMPLEMENT_COMPARISONS.get(operand.__class__)
    if negated_cls is not None:
        return negated_cls(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # De Morgan: the negated connector turns into its dual
        if isinstance(inner, exp.And):
            dual = exp.or_
        elif isinstance(inner, exp.Or):
            dual = exp.and_
        else:
            dual = None

        if dual is not None:
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression


def flatten(expression):
    """
    Remove redundant nesting between connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_cls = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_cls):
            operand.replace(unwrapped)

    return expression
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors whose operands are constants or duplicates.

    Args:
        expression: any expression; only `exp.Connector` nodes are processed.
        root: forwarded to `_flat_simplify`.

    Returns:
        The result of `_flat_simplify` for connectors, otherwise `expression`.
    """

    def _simplify_connectors(expression, left, right):
        # Returning None (implicitly, e.g. for XOR with different operands)
        # means "no simplification found for this pair"
        if left == right:
            if isinstance(expression, exp.Xor):
                return exp.false()
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


# Comparison classes grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates that the comparison-based simplifications operate on
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Comparison to use when the two operands of an inequality are swapped
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Expressions containing these are never treated as a shared "column" operand
NONDETERMINISTIC = (exp.Rand, exp.Randn)
AND_OR = (exp.And, exp.Or)


def _simplify_comparison(expression, left, right, or_=False):
    """
    Combine two comparisons that share a non-constant operand, e.g. `x < 1 AND x < 2`.

    Args:
        expression: the connector being simplified; returned as-is when no
            constant operand can be extracted from one of the sides.
        left: left operand of the connector.
        right: right operand of the connector.
        or_: True when the connector is an OR, False for AND.

    Returns:
        A replacement expression (one of the operands, FALSE, or `expression`),
        or None when nothing can be concluded.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear on both sides; only non-constant, deterministic
        # ones count as the shared "column"
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand on each side is the value compared against
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values to comparable Python objects of the same kind
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try both (left, right) and (right, left) as the (a, b) pair
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression

    # Nested links of a connector chain are handled through the top of the chain
    if not (root or not expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by the string produced by `gen`. XOR operands are
    sorted but never deduplicated.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            # Keyed by generated string, so textually equal operands collapse
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            # `i` indexes the element preceding `sql`, because enumeration starts at arr[1]
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place (replaced by a constant or by the
    surviving operand); the returned node is the `expression` that was passed in.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the dual connector: the type the nested operands must have
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^ ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            if isinstance(a, exp.Not) and a.this in op_set:
                # The negated operand becomes the neutral element of `kind`
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            superset = set(op.flatten())
            # A strict subset of this operand's parts is already an operand of the
            # outer connector, so this operand is redundant
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
def propagate_constants(expression, root=True):
    """
    Substitute columns that are pinned to a literal within a conjunction in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node inside its defining equality, literal)
    constant_mapping = {}
    for candidate in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constant_mapping[lhs] = (id(lhs), rhs)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        entry = constant_mapping.get(column)
        if entry is None:
            continue

        defining_id, constant = entry

        # Leave the defining `col = literal` equality itself untouched
        if id(column) == defining_id:
            continue

        # `col IS NULL` must keep the column
        parent = column.parent
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(constant.copy())

    return expression
# Date arithmetic functions and the binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every operation `simplify_equality` knows how to move across a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    """Return whether `expression` reports itself as a number (`is_number`)."""
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    """Return whether `expression` is an Interval that `extract_interval` can evaluate."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    Args:
        expression: any expression; only comparisons whose left side is one of
            `INVERSE_OPS` and whose right side is a number or date literal are rewritten.

    Returns:
        The rewritten comparison, or `expression` unchanged.
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # The operands may only be swapped for a commutative operation.
            # Swapping them for a subtraction would turn `5 - x = 3` into
            # `x = 3 + 5`, which is not equivalent.
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression
def simplify_literals(expression, root=True):
    """
    Fold literal operands of binary operations, negations and date arithmetic.

    Non-connector binaries are delegated to `_flat_simplify`; a negated numeric
    literal is folded into a signed literal; date add/sub functions are passed to
    `_simplify_binary` and kept as-is when that yields nothing.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already negative literal drops the sign instead of stacking it
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression


# Binary operators for which a NULL operand must not be folded to NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)


def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """
    Strip a cast around a byte-sized integer literal when the target is an integer type.

    Nested casts are unwrapped from the inside out. Returns the bare literal when
    the cast can be dropped, otherwise `expr` itself.
    """
    is_cast = isinstance(expr, exp.Cast)

    if is_cast and isinstance(expr.this, exp.Cast):
        inner = _simplify_integer_cast(expr.this)
    else:
        inner = expr.this

    if not (is_cast and inner.is_int):
        return expr

    num = int(inner.name)

    # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect
    # free. Downcasts on any integer type might cause overflow, thus the cast cannot be
    # eliminated and the behavior is engine-dependent
    if TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES:
        return inner

    if (
        UTINYINT_MIN <= num <= UTINYINT_MAX
        and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
    ):
        return inner

    return expr
def _simplify_binary(expression, a, b):
    """
    Evaluate a binary operation whose operands `a` and `b` are statically known.

    Args:
        expression: the binary (or date add/sub) node being simplified.
        a: left operand.
        b: right operand.

    Returns:
        A literal / boolean / NULL expression replacing the operation, or None
        when the operation cannot (or must not) be folded.
    """
    if isinstance(expression, COMPARISONS):
        a = _simplify_integer_cast(a)
        b = _simplify_integer_cast(b)

    if isinstance(expression, exp.Is):
        # `x IS NULL` / `x IS NOT NULL`
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Dividing by a zero literal would raise inside the optimizer
            # (decimal.DivisionByZero / InvalidOperation); leave it to the engine.
            if not num_b:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        date, delta = extract_date(a), extract_interval(b)
        # Compare against None explicitly: a zero-length interval is falsy
        if date is not None and delta is not None:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(date + delta, extract_type(a))
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(date - delta, extract_type(a))
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        delta, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if delta is not None and date is not None and isinstance(expression, exp.Add):
            return date_literal(delta + date, extract_type(b))
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None
def simplify_parens(expression):
    """
    Remove a Paren wrapper when the condition below deems it redundant.

    Returns the wrapped expression when the parentheses are dropped, otherwise
    `expression` itself.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Never unwrap a parenthesized SELECT or the operand of a subquery predicate
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # Any single one of the following is enough to drop the parentheses
            not isinstance(parent, (exp.Condition, exp.Binary))
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return whether `expression` is one of `exp.NONNULL_CONSTANTS` or a date literal."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Return whether `expression` is one of `exp.CONSTANTS` or a date literal."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - `COALESCE(x)`, or a COALESCE whose first argument is a non-null constant,
      is replaced by that first argument.
    - A comparison between a COALESCE and a constant is expanded into an explicit
      null check, cutting the COALESCE off at its first constant argument.

    Returns the rewritten expression, or `expression` unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # The COALESCE may sit on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it; note that this mutates the
    # COALESCE node that is still part of `expression`
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # Remaining COALESCE is non-null: the original comparison still applies
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # Remaining COALESCE is null: compare the constant fallback instead
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
# Expression types handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """
    Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged (joined with the separator for CONCAT_WS).
    If everything collapses into one string, that literal is returned; otherwise a
    new CONCAT / CONCAT_WS / `||` chain is built from the merged arguments.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # `expressions` is empty for `a || b`, whose operands come from flatten()
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false (or NULL) are removed. A
    branch whose condition is always true replaces the whole CASE only when no
    earlier branch with an undecided condition remains, because WHEN clauses are
    evaluated in order and such an earlier branch could still match at run time.

    Returns the selected branch / default / NULL when the outcome is known,
    otherwise `expression` (possibly with fewer branches).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # True once a branch whose condition is not statically known has been seen
        undecided = False

        # Iterate over a snapshot: `case.pop()` below may shrink the live list
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided:
                    return case.args["true"]
                # An earlier branch may win at run time, so the CASE must stay
                continue

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into TRUE or FALSE when both the string and the prefix
    are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    prefix = expression.expression.name
    return exp.convert(expression.name.startswith(prefix))


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Compute the half-open range of values whose DATE_TRUNC to `unit` equals `date`.

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    lower = date_floor(date, unit, dialect)

    if date == lower:
        return lower, lower + interval(unit)

    # `date` is not aligned to `unit`, so the equality will always be False,
    # except for NULL values.
    return None
def _datetrunc_eq_expression(
    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
) -> exp.Expression:
    """Get the logical expression for a date range: `left >= min AND left < max`."""
    return exp.and_(
        left >= date_literal(drange[0], target_type),
        left < date_literal(drange[1], target_type),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """
    Rewrite `DATE_TRUNC(unit, left) = date` as a range check on `left`.

    Returns None when `date` is not aligned to `unit` (no range exists).
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange, target_type)


def _datetrunc_neq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """
    Rewrite `DATE_TRUNC(unit, left) <> date` as the complement of the equality range.

    Returns None when `date` is not aligned to `unit` (no range exists).
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # The complement of [min, max) is `left < min OR left >= max`. Joining the two
    # halves with AND would describe an empty set and make the predicate always false.
    return exp.or_(
        left < date_literal(drange[0], target_type),
        left >= date_literal(drange[1], target_type),
        copy=False,
    )


# Comparison class -> function(trunc_arg, date, unit, dialect, target_type) producing
# the equivalent predicate directly on the truncated argument (or None)
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a date truncation and `right` is a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled:
    - a truncation of a date literal is folded into a literal,
    - a binary comparison of a truncation with a date literal becomes a comparison
      on the truncated argument,
    - `DATE_TRUNC(...) IN (<date literals>)` becomes an OR of range checks.

    Returns the rewritten expression, or `expression` unchanged.
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        this = expression.this
        trunc_type = extract_type(this)
        date = extract_date(this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        trunc_arg = l.this
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The transform may return None, in which case the input is kept
        return (
            DATETRUNC_BINARY_COMPARISONS[comparison](
                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
            )
            or expression
        )

    if isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates that are not aligned to `unit` yield no range and are skipped
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(l, *rs)

            return exp.or_(
                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
            )

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Only classes listed in `COMPLEMENT_COMPARISONS` are considered. When the
    operands are swapped, an inequality is replaced by its mirrored operator.
    """
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        # Already canonical: column on the left and/or constant on the right
        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        # Otherwise swap when that moves a column left / a constant right, or when
        # the generated text of the left operand sorts after the right one
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# NOTE(review): RIGHT [OUTER] JOIN x ON TRUE likewise keeps the rows of x when the
# left side is empty, whereas a CROSS JOIN would return nothing - confirm that the
# RIGHT entries below are intended.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """
    Drop WHERE clauses that are always true and turn eligible `JOIN ... ON TRUE`
    joins into CROSS joins. `expression` is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """
    Return whether `expression` is statically known to be true.

    That is the case for the boolean TRUE and for a non-zero numeric literal.
    Other literals (zero, strings) are not treated as true: `0` is false, and
    this function makes no assumption about how an engine coerces a string.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if isinstance(expression, exp.Literal) and expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # Not parseable as a decimal number: stay on the safe side
            return False

    return False


def always_false(expression):
    """Return whether `expression` is the boolean FALSE or NULL."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly the boolean literal FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a NULL node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns the result wrapped by `boolean_literal`, or None when `expression`
    is not one of the supported comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None
isinstance(expression, exp.NEQ): 1142 return boolean_literal(a != b) 1143 if isinstance(expression, exp.GT): 1144 return boolean_literal(a > b) 1145 if isinstance(expression, exp.GTE): 1146 return boolean_literal(a >= b) 1147 if isinstance(expression, exp.LT): 1148 return boolean_literal(a < b) 1149 if isinstance(expression, exp.LTE): 1150 return boolean_literal(a <= b) 1151 return None 1152 1153 1154def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1155 if isinstance(value, datetime.datetime): 1156 return value.date() 1157 if isinstance(value, datetime.date): 1158 return value 1159 try: 1160 return datetime.datetime.fromisoformat(value).date() 1161 except ValueError: 1162 return None 1163 1164 1165def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1166 if isinstance(value, datetime.datetime): 1167 return value 1168 if isinstance(value, datetime.date): 1169 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1170 try: 1171 return datetime.datetime.fromisoformat(value) 1172 except ValueError: 1173 return None 1174 1175 1176def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1177 if not value: 1178 return None 1179 if to.is_type(exp.DataType.Type.DATE): 1180 return cast_as_date(value) 1181 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1182 return cast_as_datetime(value) 1183 return None 1184 1185 1186def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1187 if isinstance(cast, exp.Cast): 1188 to = cast.to 1189 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1190 to = exp.DataType.build(exp.DataType.Type.DATE) 1191 else: 1192 return None 1193 1194 if isinstance(cast.this, exp.Literal): 1195 value: t.Any = cast.this.name 1196 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1197 value = extract_date(cast.this) 1198 else: 1199 return None 1200 return cast_value(value, to) 1201 1202 1203def 
_is_date_literal(expression: exp.Expression) -> bool: 1204 return extract_date(expression) is not None 1205 1206 1207def extract_interval(expression): 1208 try: 1209 n = int(expression.name) 1210 unit = expression.text("unit").lower() 1211 return interval(unit, n) 1212 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1213 return None 1214 1215 1216def extract_type(*expressions): 1217 target_type = None 1218 for expression in expressions: 1219 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1220 if target_type: 1221 break 1222 1223 return target_type 1224 1225 1226def date_literal(date, target_type=None): 1227 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1228 target_type = ( 1229 exp.DataType.Type.DATETIME 1230 if isinstance(date, datetime.datetime) 1231 else exp.DataType.Type.DATE 1232 ) 1233 1234 return exp.cast(exp.Literal.string(date), target_type) 1235 1236 1237def interval(unit: str, n: int = 1): 1238 from dateutil.relativedelta import relativedelta 1239 1240 if unit == "year": 1241 return relativedelta(years=1 * n) 1242 if unit == "quarter": 1243 return relativedelta(months=3 * n) 1244 if unit == "month": 1245 return relativedelta(months=1 * n) 1246 if unit == "week": 1247 return relativedelta(weeks=1 * n) 1248 if unit == "day": 1249 return relativedelta(days=1 * n) 1250 if unit == "hour": 1251 return relativedelta(hours=1 * n) 1252 if unit == "minute": 1253 return relativedelta(minutes=1 * n) 1254 if unit == "second": 1255 return relativedelta(seconds=1 * n) 1256 1257 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1258 1259 1260def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1261 if unit == "year": 1262 return d.replace(month=1, day=1) 1263 if unit == "quarter": 1264 if d.month <= 3: 1265 return d.replace(month=1, day=1) 1266 elif d.month <= 6: 1267 return d.replace(month=4, day=1) 1268 elif d.month <= 9: 1269 return d.replace(month=7, day=1) 
1270 else: 1271 return d.replace(month=10, day=1) 1272 if unit == "month": 1273 return d.replace(month=d.month, day=1) 1274 if unit == "week": 1275 # Assuming week starts on Monday (0) and ends on Sunday (6) 1276 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1277 if unit == "day": 1278 return d 1279 1280 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1281 1282 1283def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1284 floor = date_floor(d, unit, dialect) 1285 1286 if floor == d: 1287 return d 1288 1289 return floor + interval(unit) 1290 1291 1292def boolean_literal(condition): 1293 return exp.true() if condition else exp.false() 1294 1295 1296def _flat_simplify(expression, simplifier, root=True): 1297 if root or not expression.same_parent: 1298 operands = [] 1299 queue = deque(expression.flatten(unnest=False)) 1300 size = len(queue) 1301 1302 while queue: 1303 a = queue.popleft() 1304 1305 for b in queue: 1306 result = simplifier(expression, a, b) 1307 1308 if result and result is not expression: 1309 queue.remove(b) 1310 queue.appendleft(result) 1311 break 1312 else: 1313 operands.append(a) 1314 1315 if len(operands) < size: 1316 return functools.reduce( 1317 lambda a, b: expression.__class__(this=a, expression=b), operands 1318 ) 1319 return expression 1320 1321 1322def gen(expression: t.Any) -> str: 1323 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1324 1325 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1326 generator is expensive so we have a bare minimum sql generator here. 
1327 """ 1328 return Gen().gen(expression) 1329 1330 1331class Gen: 1332 def __init__(self): 1333 self.stack = [] 1334 self.sqls = [] 1335 1336 def gen(self, expression: exp.Expression) -> str: 1337 self.stack = [expression] 1338 self.sqls.clear() 1339 1340 while self.stack: 1341 node = self.stack.pop() 1342 1343 if isinstance(node, exp.Expression): 1344 exp_handler_name = f"{node.key}_sql" 1345 1346 if hasattr(self, exp_handler_name): 1347 getattr(self, exp_handler_name)(node) 1348 elif isinstance(node, exp.Func): 1349 self._function(node) 1350 else: 1351 key = node.key.upper() 1352 self.stack.append(f"{key} " if self._args(node) else key) 1353 elif type(node) is list: 1354 for n in reversed(node): 1355 if n is not None: 1356 self.stack.extend((n, ",")) 1357 if node: 1358 self.stack.pop() 1359 else: 1360 if node is not None: 1361 self.sqls.append(str(node)) 1362 1363 return "".join(self.sqls) 1364 1365 def add_sql(self, e: exp.Add) -> None: 1366 self._binary(e, " + ") 1367 1368 def alias_sql(self, e: exp.Alias) -> None: 1369 self.stack.extend( 1370 ( 1371 e.args.get("alias"), 1372 " AS ", 1373 e.args.get("this"), 1374 ) 1375 ) 1376 1377 def and_sql(self, e: exp.And) -> None: 1378 self._binary(e, " AND ") 1379 1380 def anonymous_sql(self, e: exp.Anonymous) -> None: 1381 this = e.this 1382 if isinstance(this, str): 1383 name = this.upper() 1384 elif isinstance(this, exp.Identifier): 1385 name = this.this 1386 name = f'"{name}"' if this.quoted else name.upper() 1387 else: 1388 raise ValueError( 1389 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
1390 ) 1391 1392 self.stack.extend( 1393 ( 1394 ")", 1395 e.expressions, 1396 "(", 1397 name, 1398 ) 1399 ) 1400 1401 def between_sql(self, e: exp.Between) -> None: 1402 self.stack.extend( 1403 ( 1404 e.args.get("high"), 1405 " AND ", 1406 e.args.get("low"), 1407 " BETWEEN ", 1408 e.this, 1409 ) 1410 ) 1411 1412 def boolean_sql(self, e: exp.Boolean) -> None: 1413 self.stack.append("TRUE" if e.this else "FALSE") 1414 1415 def bracket_sql(self, e: exp.Bracket) -> None: 1416 self.stack.extend( 1417 ( 1418 "]", 1419 e.expressions, 1420 "[", 1421 e.this, 1422 ) 1423 ) 1424 1425 def column_sql(self, e: exp.Column) -> None: 1426 for p in reversed(e.parts): 1427 self.stack.extend((p, ".")) 1428 self.stack.pop() 1429 1430 def datatype_sql(self, e: exp.DataType) -> None: 1431 self._args(e, 1) 1432 self.stack.append(f"{e.this.name} ") 1433 1434 def div_sql(self, e: exp.Div) -> None: 1435 self._binary(e, " / ") 1436 1437 def dot_sql(self, e: exp.Dot) -> None: 1438 self._binary(e, ".") 1439 1440 def eq_sql(self, e: exp.EQ) -> None: 1441 self._binary(e, " = ") 1442 1443 def from_sql(self, e: exp.From) -> None: 1444 self.stack.extend((e.this, "FROM ")) 1445 1446 def gt_sql(self, e: exp.GT) -> None: 1447 self._binary(e, " > ") 1448 1449 def gte_sql(self, e: exp.GTE) -> None: 1450 self._binary(e, " >= ") 1451 1452 def identifier_sql(self, e: exp.Identifier) -> None: 1453 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1454 1455 def ilike_sql(self, e: exp.ILike) -> None: 1456 self._binary(e, " ILIKE ") 1457 1458 def in_sql(self, e: exp.In) -> None: 1459 self.stack.append(")") 1460 self._args(e, 1) 1461 self.stack.extend( 1462 ( 1463 "(", 1464 " IN ", 1465 e.this, 1466 ) 1467 ) 1468 1469 def intdiv_sql(self, e: exp.IntDiv) -> None: 1470 self._binary(e, " DIV ") 1471 1472 def is_sql(self, e: exp.Is) -> None: 1473 self._binary(e, " IS ") 1474 1475 def like_sql(self, e: exp.Like) -> None: 1476 self._binary(e, " Like ") 1477 1478 def literal_sql(self, e: exp.Literal) -> None: 
1479 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1480 1481 def lt_sql(self, e: exp.LT) -> None: 1482 self._binary(e, " < ") 1483 1484 def lte_sql(self, e: exp.LTE) -> None: 1485 self._binary(e, " <= ") 1486 1487 def mod_sql(self, e: exp.Mod) -> None: 1488 self._binary(e, " % ") 1489 1490 def mul_sql(self, e: exp.Mul) -> None: 1491 self._binary(e, " * ") 1492 1493 def neg_sql(self, e: exp.Neg) -> None: 1494 self._unary(e, "-") 1495 1496 def neq_sql(self, e: exp.NEQ) -> None: 1497 self._binary(e, " <> ") 1498 1499 def not_sql(self, e: exp.Not) -> None: 1500 self._unary(e, "NOT ") 1501 1502 def null_sql(self, e: exp.Null) -> None: 1503 self.stack.append("NULL") 1504 1505 def or_sql(self, e: exp.Or) -> None: 1506 self._binary(e, " OR ") 1507 1508 def paren_sql(self, e: exp.Paren) -> None: 1509 self.stack.extend( 1510 ( 1511 ")", 1512 e.this, 1513 "(", 1514 ) 1515 ) 1516 1517 def sub_sql(self, e: exp.Sub) -> None: 1518 self._binary(e, " - ") 1519 1520 def subquery_sql(self, e: exp.Subquery) -> None: 1521 self._args(e, 2) 1522 alias = e.args.get("alias") 1523 if alias: 1524 self.stack.append(alias) 1525 self.stack.extend((")", e.this, "(")) 1526 1527 def table_sql(self, e: exp.Table) -> None: 1528 self._args(e, 4) 1529 alias = e.args.get("alias") 1530 if alias: 1531 self.stack.append(alias) 1532 for p in reversed(e.parts): 1533 self.stack.extend((p, ".")) 1534 self.stack.pop() 1535 1536 def tablealias_sql(self, e: exp.TableAlias) -> None: 1537 columns = e.columns 1538 1539 if columns: 1540 self.stack.extend((")", columns, "(")) 1541 1542 self.stack.extend((e.this, " AS ")) 1543 1544 def var_sql(self, e: exp.Var) -> None: 1545 self.stack.append(e.this) 1546 1547 def _binary(self, e: exp.Binary, op: str) -> None: 1548 self.stack.extend((e.expression, op, e.this)) 1549 1550 def _unary(self, e: exp.Unary, op: str) -> None: 1551 self.stack.extend((e.this, op)) 1552 1553 def _function(self, e: exp.Func) -> None: 1554 self.stack.extend( 1555 ( 1556 ")", 1557 
list(e.args.values()), 1558 "(", 1559 e.sql_name(), 1560 ) 1561 ) 1562 1563 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1564 kvs = [] 1565 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1566 1567 for k in arg_types or arg_types: 1568 v = node.args.get(k) 1569 1570 if v is not None: 1571 kvs.append([f":{k}", v]) 1572 if kvs: 1573 self.stack.append(kvs) 1574 return True 1575 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and handed to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure the depth at the top of a connector chain, not at every link
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Give the (possibly new) node the original parent before the remaining rules run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the rules for as long as `while_changing` detects a change
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- max_depth: Chains of Connectors (AND, OR, etc.) exceeding `max_depth` will be skipped
Returns:
sqlglot.Expression: simplified expression
def connector_depth(expression: exp.Expression) -> int:
    """
    Determine the maximum depth of a tree of Connectors.

    Operands that are not Connectors do not add to the depth.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    pending = deque([(expression, 0)])

    while pending:
        node, level = pending.pop()

        if isinstance(node, exp.Connector):
            level += 1
            if level > deepest:
                deepest = level

            # Both operands sit one level below the connector that holds them
            pending.extend(((node.left, level), (node.right, level)))

    return deepest
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one >>> connector_depth(parse_one("a AND b AND c AND d")) 3
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    The decorated function is called normally; if it raises one of `exceptions`, its
    first argument (the expression) is returned unchanged instead. Other exceptions
    propagate. The wrapper keeps the metadata (name, docstring, ...) of the function
    it wraps.
    """

    def decorator(func):
        # Preserve `func`'s name and docstring so that introspection tools and
        # tracebacks refer to the real simplification rule rather than to `wrapped`
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    The result is parenthesized when the BETWEEN is the direct operand of a NOT.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a negation.

    Applies De Morgan's laws to a parenthesized connector:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    It also folds NOT NULL to NULL, complements comparisons, evaluates NOT over
    always-true / false operands and removes double negations.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # The connector flips: AND becomes OR and vice versa
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Remove parentheses around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Simplify AND / OR / XOR chains by reducing pairs of their operands.

    Non-connector expressions are returned unchanged.
    """

    def _reduce_pair(connector, lhs, rhs):
        # Identical operands: x XOR x is FALSE, x AND x / x OR x is x
        if lhs == rhs:
            return exp.false() if isinstance(connector, exp.Xor) else lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()

            lhs_true, rhs_true = always_true(lhs), always_true(rhs)
            if lhs_true:
                return exp.true() if rhs_true else rhs
            if rhs_true:
                return lhs

            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_false, rhs_false = is_false(lhs), is_false(rhs)
            lhs_null, rhs_null = is_null(lhs), is_null(rhs)

            if lhs_false and rhs_false:
                return exp.false()
            if (lhs_null and (rhs_null or rhs_false)) or (lhs_false and rhs_null):
                return exp.null()
            if lhs_false:
                return rhs
            if rhs_false:
                return lhs

            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression

    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their `gen` string. XOR operands are sorted but never
    deduplicated.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # Only rebuild when some neighbouring pair is out of order
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))

    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still have to go
    if unique and len(unique) < len(operands):
        return build(*unique.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND / OR chain, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten with `replace`; the same `expression` object is returned.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # The connector type we look for among the operands is the opposite one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Register (X, NOT Y) operands under the un-negated pair {X, Y}, together
            # with the operand that survives an elimination
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated sub-operand whose complement is a top-level operand is
            # replaced by the neutral element of `kind`
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand that strictly contains another operand's sub-operands is redundant
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place and `expression` itself is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not isinstance(expression, exp.And):
        return expression
    if not root and expression.same_parent:
        return expression
    if not sqlglot.optimizer.normalize.normalized(expression, dnf=True):
        return expression

    # column -> (id of the column node that defines it, literal it equals)
    known = {}
    for candidate in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            known[lhs] = (id(lhs), rhs)

    if not known:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        entry = known.get(column)
        if not entry:
            continue

        defining_id, literal = entry

        # Leave the `column = literal` predicate that defines the constant alone
        if id(column) == defining_id:
            continue

        # `column IS NULL` is not rewritten
        holder = column.parent
        if isinstance(holder, exp.Is) and isinstance(holder.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # NOTE: `func` and `exceptions` are free variables; they are bound by an
    # enclosing scope that is not part of this excerpt.
    try:
        # Call through with the arguments unchanged
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of the listed exceptions was raised: give the input back untouched
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """Fold literal operands.

    Non-connector binaries are reduced pairwise, the negation of a numeric literal
    becomes a signed literal, and the date operations listed in `INVERSE_DATE_OPS`
    are evaluated against their interval when possible.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already negative literal drops its sign
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Return the content of a `Paren` when the parentheses are deemed removable.

    Anything that is not a `Paren`, or whose parentheses have to stay, is returned
    unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Never unwrap a parenthesized SELECT or the operand of a subquery predicate
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # The parent is neither a condition nor a binary
            not isinstance(parent, (exp.Condition, exp.Binary))
            # Directly nested parentheses
            or isinstance(parent, exp.Paren)
            # A non-binary content, unless it is a NOT / IS under a predicate
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # A predicate under a non-predicate parent
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # Add under Add, Mul under Mul, and Mul under Add / Sub
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons of a COALESCE with a constant.

    `COALESCE(x)` (or a COALESCE whose first argument is a non-null constant) becomes
    its first argument. A comparison between a COALESCE and a constant is split at
    the first constant argument `c` of the COALESCE, e.g. for `COALESCE(x, c) = k`:

        (NOT COALESCE(x) IS NULL AND COALESCE(x) = k) OR (COALESCE(x) IS NULL AND c = k)

    The operand order of the original comparison is preserved in the second branch,
    which matters for the non-symmetric comparisons (<, <=, >, >=).
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Everything from the first constant onwards is handled by the second branch below
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the constant argument on the side of the comparison the COALESCE was on
    if coalesce_on_left:
        fallback_comparison = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback_comparison = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback_comparison,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal.

    Applies to the expression types in `CONCATS`. A CONCAT_WS call is left untouched
    unless its separator is a string literal.
    """
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if with_separator and not expression.expressions[0].is_string:
        return expression

    if with_separator:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    operands = expressions or expression.flatten()
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            merged = sep.join(literal.name for literal in run)
            new_args.append(exp.Literal.string(merged))
        else:
            new_args.extend(run)

    # Everything collapsed into a single string literal
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if with_separator:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are examined in order: a branch that is always false is
    removed, and a branch that is always true replaces the whole CASE, but only if no
    earlier branch with an undecidable condition was kept, since that earlier branch
    could still match first. `CASE x WHEN v ...` is converted to `CASE WHEN x = v ...`
    along the way.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a branch is kept whose condition can't be decided statically
        undecided = False

        # Iterate over a snapshot: `case.pop()` below removes branches from the CASE
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided:
                    return case.args["true"]
                # An earlier branch may still win, so this branch has to stay
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Evaluate a prefix check whose string and prefix are both string literals,
    replacing it with TRUE or FALSE.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    text, prefix = expression.this, expression.expression
    if not (text.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # NOTE: `func` and `exceptions` are free variables; they are bound by an
    # enclosing scope that is not part of this excerpt.
    try:
        # Call through with the arguments unchanged
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of the listed exceptions was raised: give the input back untouched
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Puts the operands of a comparison into a canonical order.

    Only expressions whose class is in COMPLEMENT_COMPARISONS are considered.
    The comparison is left as is when the left side is a column and the right
    side is not, or when the right side is constant and the left side is not.
    It is flipped (using the class from INVERSE_COMPARISONS, falling back to the
    same class) when the opposite holds, or when the left operand's `gen`
    string sorts after the right operand's.

    Returns:
        The original expression, or a new comparison with swapped operands.
    """
    kind = expression.__class__
    if kind not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    # Already in canonical order: nothing to do.
    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not needs_swap:
        return expression

    flipped = INVERSE_COMPARISONS.get(kind, kind)
    return flipped(this=right, expression=left)
def remove_where_true(expression):
    """
    Drops WHERE clauses and join conditions that are always true.

    Every WHERE whose condition is always true is popped. A join whose ON
    condition is always true, that has neither USING nor a method, and whose
    (side, kind) pair is in JOINS, loses its ON condition and side and gets its
    kind set to "CROSS". The expression is modified in place; nothing is returned.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.pop()

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        if not always_true(on_condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        on_condition.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluates a comparison node against two concrete Python values.

    Args:
        expression: the comparison node; its type selects the operator.
        a: left-hand value.
        b: right-hand value.

    Returns:
        The result of `boolean_literal` applied to the comparison outcome, or
        None when the node type is not one of the handled comparisons.
    """
    # Checked in order; the comparison is only evaluated for the matching entry.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Converts a value to a `datetime.date`.

    A datetime is truncated to its date part, a date is returned as is, and
    anything else is parsed with `datetime.fromisoformat`.

    Returns:
        The date, or None when the value is not a valid ISO-format string.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so test it inside.
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Converts a value to a `datetime.datetime`.

    A datetime is returned as is, a date becomes midnight of that day, and
    anything else is parsed with `datetime.fromisoformat`.

    Returns:
        The datetime, or None when the value is not a valid ISO-format string.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so test it inside.
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None
    return parsed
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Converts a value to the Python date/datetime matching the target data type.

    Args:
        value: the value to convert; any falsy value yields None.
        to: the target type. DATE uses `cast_as_date`; any other type in
            `exp.DataType.TEMPORAL_TYPES` uses `cast_as_datetime`.

    Returns:
        The converted value, or None for a falsy value, a non-temporal target
        type, or a value the converter could not handle.
    """
    # NOTE(review): the return annotation lists datetime.date twice; the second
    # member was presumably meant to be datetime.datetime - confirm.
    if not value:
        return None

    converter = None
    if to.is_type(exp.DataType.Type.DATE):
        converter = cast_as_date
    elif to.is_type(*exp.DataType.TEMPORAL_TYPES):
        converter = cast_as_datetime

    return converter(value) if converter else None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Extracts the date/datetime value from a cast-like expression.

    Handles `exp.Cast` (using its target type) and `exp.TsOrDsToDate` without a
    format (treated as a cast to DATE). The operand must be a literal or another
    such cast, which is resolved recursively.

    Returns:
        The result of `cast_value` for the extracted value and target type, or
        None when the expression has an unsupported shape.
    """
    # NOTE(review): the return annotation lists datetime.date twice; the second
    # member was presumably meant to be datetime.datetime - confirm.
    if isinstance(cast, exp.Cast):
        target = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        target = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        raw: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        raw = extract_date(operand)
    else:
        return None

    return cast_value(raw, target)
def date_literal(date, target_type=None):
    """
    Builds a cast of a string literal holding `date`.

    Args:
        date: the date or datetime value to render into the literal.
        target_type: type to cast to. When it is missing or not one of
            `exp.DataType.TEMPORAL_TYPES`, DATETIME is used for datetime values
            and DATE otherwise.

    Returns:
        The cast expression.
    """
    has_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not has_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """
    Builds a `relativedelta` spanning `n` units.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour",
            "minute" or "second".
        n: number of units.

    Returns:
        The corresponding `dateutil.relativedelta.relativedelta`.

    Raises:
        UnsupportedUnit: if the unit is not one of the supported names.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, amount of that keyword per unit)
    units = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in units:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Rounds a date down to the start of the given unit.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the shift of the first day of the week
            relative to Monday (0); only read for the "week" unit.

    Returns:
        The first day of the year / quarter / month / week containing `d`, or
        `d` itself for "day".

    Raises:
        UnsupportedUnit: if the unit is not one of the supported names.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        quarter_start_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=quarter_start_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is Monday (0) .. Sunday (6). The distance to the start of the
        # week must wrap into 0..6: without the modulo, a non-zero WEEK_OFFSET can
        # give a "floor" after `d` or a whole week too early.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Simple pseudo-SQL generator for quickly generating sortable and unique strings.

    Sorting and deduplicating SQL is a necessary step for optimization. Calling the
    actual generator is expensive, so a bare-minimum generator (`Gen`) is used here.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo-SQL generator for quickly generating sortable and unique strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """
    Bare-minimum, stack-based pseudo-SQL generator.

    `gen` drives an explicit LIFO stack instead of recursing: nodes are popped
    from the end of `self.stack`, expressions are expanded by a `<key>_sql`
    handler (or a generic fallback) which pushes their pieces, and everything
    else is stringified into `self.sqls`. Because the stack is LIFO, every
    handler pushes its pieces in the reverse of the order they should appear
    in the output.
    """

    def __init__(self):
        # Pending nodes / string fragments, consumed from the end.
        self.stack = []
        # Output fragments in emission order.
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Returns the pseudo-SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key, followed by the args (if any).
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed_any = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed_any = True
                # Drop the separator that would end up in front of the first item.
                # Only do so when something was pushed: for a non-empty list holding
                # nothing but None, popping would remove an unrelated stack entry.
                if pushed_any:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Pushes `NAME(args)`; the name is upper-cased unless it is a quoted identifier."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dot-separated parts; the final pop drops the leading separator.
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Args after the first arg type, preceded by the type name.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # The mixed-case operator text is part of the generated string and is
        # kept exactly as is.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Remaining args (after the first two arg types), then alias, then "(this)".
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Remaining args (after the first four arg types), then alias, then the
        # dot-separated parts; the final pop drops the leading separator.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Pushes `this op expression`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Pushes `op this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Pushes `SQL_NAME(arg values...)` for a function without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Pushes the node's set args as a list of `[":<name>", value]` pairs.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading arg types to skip.

        Returns:
            True if at least one arg was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
1337 def gen(self, expression: exp.Expression) -> str: 1338 self.stack = [expression] 1339 self.sqls.clear() 1340 1341 while self.stack: 1342 node = self.stack.pop() 1343 1344 if isinstance(node, exp.Expression): 1345 exp_handler_name = f"{node.key}_sql" 1346 1347 if hasattr(self, exp_handler_name): 1348 getattr(self, exp_handler_name)(node) 1349 elif isinstance(node, exp.Func): 1350 self._function(node) 1351 else: 1352 key = node.key.upper() 1353 self.stack.append(f"{key} " if self._args(node) else key) 1354 elif type(node) is list: 1355 for n in reversed(node): 1356 if n is not None: 1357 self.stack.extend((n, ",")) 1358 if node: 1359 self.stack.pop() 1360 else: 1361 if node is not None: 1362 self.sqls.append(str(node)) 1363 1364 return "".join(self.sqls)
1381 def anonymous_sql(self, e: exp.Anonymous) -> None: 1382 this = e.this 1383 if isinstance(this, str): 1384 name = this.upper() 1385 elif isinstance(this, exp.Identifier): 1386 name = this.this 1387 name = f'"{name}"' if this.quoted else name.upper() 1388 else: 1389 raise ValueError( 1390 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1391 ) 1392 1393 self.stack.extend( 1394 ( 1395 ")", 1396 e.expressions, 1397 "(", 1398 name, 1399 ) 1400 )