sqlglot.optimizer.simplify
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    def _simplify(expression, root=True):
        # Nodes flagged "final" in their metadata are never rewritten.
        if expression.meta.get("final"):
            return expression
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        # Simplify the children first, then apply the remaining rules to this node.
        exp.replace_children(node, lambda e: _simplify(e, False))
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        # The rules above may return a new node; give it the original parent so the
        # parent-aware rules below (simplify_literals / simplify_parens) see it.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        if root:
            expression.replace(node)
        return node

    # Presumably re-applies _simplify until the tree stops changing
    # (see sqlglot.helper.while_changing).
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT TRUE -> FALSE, NOT FALSE -> TRUE and
    NOT NOT x -> x.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """
    Fold AND/OR connectors using constant operands, duplicates and mergeable
    comparisons.

    SQL uses three-valued logic, so a NULL operand is only folded when the result
    is NULL whatever the other operand is. ``NULL AND x`` is left alone because it
    is FALSE when x is FALSE.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # Only NULL AND NULL is certainly NULL; NULL AND FALSE is FALSE.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
)

INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against number/string literals.

    Returns the merged comparison, FALSE, or None when nothing can be merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # Both sides are columns (e.g. x = y AND x < y): no literal bound to
                # compare. Returning the whole connector here would make the caller
                # nest it inside itself, so report "no simplification".
                return None

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv:
                        return _pick_by_strictness(left, right, or_)
                    return left if (av > bv if or_ else av < bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv:
                        return _pick_by_strictness(left, right, or_)
                    return left if (av < bv if or_ else av > bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def _pick_by_strictness(left, right, or_):
    """Choose between two same-direction bounds with equal values (x < 5 vs x <= 5).

    AND keeps the strict bound (LT/GT); OR keeps the non-strict one (LTE/GTE).
    """
    strict = (exp.LT, exp.GT)
    if or_:
        return right if isinstance(left, strict) else left
    return left if isinstance(left, strict) else right


def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, _) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def simplify_literals(expression, root=True):
    """Fold constant non-connector binary expressions and NEG of number literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Fold one pair of operands of a non-connector binary expression.

    Returns the folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Division by zero would raise here (ZeroDivisionError / decimal
            # signals); leave it for the engine to handle.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Unwrap a Paren unless it wraps a SELECT or is needed for precedence."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return expression.this
    return expression


def remove_where_true(expression):
    """Drop always-true WHERE clauses; turn always-true ON joins into CROSS JOINs."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return truthy if expression is TRUE or a literal other than a numeric zero.

    A numeric zero literal is not treated as true: ``WHERE 0`` must not be dropped
    and ``JOIN ... ON 0`` must not become a CROSS JOIN.
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def _is_zero(expression):
    """Return True if expression is a number literal whose value is zero."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return float(expression.name) == 0
    except ValueError:
        return False


def is_complement(a, b):
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate a comparison on Python values; TRUE/FALSE literal or None."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None


def extract_interval(interval):
    """Convert an INTERVAL to a relativedelta; None if it cannot be converted."""
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # e.g. INTERVAL x DAY (a column) or INTERVAL '1.5' DAY: nothing to fold.
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Apply a pairwise simplifier across the flattened operands of expression.

    Returns a rebuilt expression if any pair was folded, else expression itself.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                # A simplifier handing back the whole expression means "no change";
                # accepting it would nest expression inside itself.
                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def
simplify(expression):
def simplify(expression):
    """
    Rewrite a sqlglot AST so that its expressions are simplified.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): the expression to simplify.
    Returns:
        sqlglot.Expression: the simplified expression.
    """
    to_sql = cached_generator()

    def _simplify(expression, root=True):
        # Expressions whose metadata marks them "final" are never rewritten.
        if expression.meta.get("final"):
            return expression

        # Rules applied before the children are visited.
        node = absorb_and_eliminate(
            uniq_sort(rewrite_between(expression), to_sql, root), root
        )
        exp.replace_children(node, lambda child: _simplify(child, False))

        # Rules applied after the children have been simplified.
        for rule in (simplify_not, flatten):
            node = rule(node)
        node = remove_compliments(simplify_connectors(node, root), root)

        # Reattach the original parent so the parent-aware rules below see it.
        node.parent = expression.parent
        node = simplify_parens(simplify_literals(node, root))

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn ``x BETWEEN y AND z`` into ``x >= y AND x <= z``.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into that form first. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    low = expression.args["low"]
    high = expression.args["high"]
    lower_bound = exp.GTE(this=expression.this.copy(), expression=low)
    upper_bound = exp.LTE(this=expression.this.copy(), expression=high)
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def
simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression using De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over NULL, always-true and FALSE operands, and removes
    double negation. Expressions that are not NOT are returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # AND becomes OR and vice versa, with each side negated.
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def
flatten(expression):
def flatten(expression):
    """
    Pull nested connectors of the same kind up into their parent:
        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    kind = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, kind):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def
simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold AND/OR connectors using constant operands, duplicates and mergeable
    comparisons.

    SQL uses three-valued logic, so a NULL operand is only folded when the result
    is NULL whatever the other operand is:
        FALSE AND x -> FALSE          TRUE OR x -> TRUE
        NULL AND NULL -> NULL         NULL OR NULL -> NULL, NULL OR FALSE -> NULL
        TRUE AND x -> x               FALSE OR x -> x
    (here TRUE means anything ``always_true`` accepts). ``NULL AND x`` is left
    alone because it evaluates to FALSE when x is FALSE.

    Args:
        expression: any expression; only exp.Connector nodes are rewritten.
        root: forwarded to _flat_simplify.
    Returns:
        The simplified expression, or ``expression`` when nothing applies.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # Only NULL AND NULL is certainly NULL; NULL AND FALSE is FALSE.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE =
(<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE =
(<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
COMPARISONS =
(<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>)
INVERSE_COMPARISONS =
{<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def
remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Replace a connector that contains an operand and its negation by a constant.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(x, y) for x, y in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Remove complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def
uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and sort them by their generated SQL.

        C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # SQL text -> operand; a later duplicate overwrites the value but keeps the
    # key's first-seen position.
    by_sql = {}
    for operand in operands:
        by_sql[generate(operand)] = operand
    keys = list(by_sql)

    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        return build(*(by_sql[key] for key in sorted(by_sql)), copy=False)
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)
    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def
absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    is_and = isinstance(expression, exp.And)
    # Operands of the opposite connector kind are the candidates for rewriting.
    inner_kind = exp.Or if is_and else exp.And
    # Neutral element of the inner connector (replaces a negated term) ...
    inner_identity = exp.false if is_and else exp.true
    # ... and of the outer connector (replaces an absorbed operand).
    outer_identity = exp.true if is_and else exp.false

    for term, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(term, inner_kind):
            continue

        first_op, second_op = term.unnest_operands()
        other_terms = set(other.flatten()) if isinstance(other, inner_kind) else {other}

        if is_complement(other, first_op):
            first_op.replace(inner_identity())
        elif is_complement(other, second_op):
            second_op.replace(inner_identity())
        elif other_terms < set(term.flatten()):
            term.replace(outer_identity())
        elif isinstance(other, inner_kind):
            other_ops = other.unnest_operands()
            other_first, other_second = other_ops

            if first_op in other_ops and (
                is_complement(second_op, other_first) or is_complement(second_op, other_second)
            ):
                term.replace(first_op)
                other.replace(first_op)
            elif second_op in other_ops and (
                is_complement(first_op, other_first) or is_complement(first_op, other_second)
            ):
                term.replace(second_op)
                other.replace(second_op)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def
simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold constant operands of non-connector binary expressions and NEG literals.

    Binary (non-AND/OR) expressions are folded pairwise via _simplify_binary;
    NEG of a number literal becomes the literal with its sign flipped.
    """
    is_foldable_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_foldable_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    return expression
def
simplify_parens(expression):
def simplify_parens(expression):
    """Unwrap a Paren node when the parentheses are not needed.

    A parenthesised SELECT is always kept. Otherwise the parentheses are removed
    when the parent is not a Condition/Binary, when the inner node is a Predicate
    or not a Binary, or for ADD inside ADD, MUL inside MUL, and MUL inside ADD/SUB.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    outer = expression.parent
    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def
remove_where_true(expression):
def remove_where_true(expression):
    """Drop conditions that are always true, mutating expression in place.

    A WHERE whose condition is always true is removed from its parent. A JOIN
    whose ON is always true and that has no USING or method becomes a CROSS JOIN.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        args = join.args
        if not always_true(args.get("on")) or args.get("using") or args.get("method"):
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def
always_true(expression):
def
is_complement(a, b):
def
eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate a comparison node on plain Python values a and b.

    Returns a TRUE/FALSE literal for EQ/IS, NEQ, GT, GTE, LT and LTE, and None
    for any other expression type.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in checks:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
def
extract_date(cast):
def extract_date(cast):
    """Return the date/datetime value of CAST(... AS DATE/DATETIME), else None.

    The cast's name is parsed with fromisoformat. When it is not an ISO date
    string (e.g. the cast wraps an identifier), the parse fails and None is
    returned.
    """
    target = cast.args["to"].this
    if target == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
def
extract_interval(interval):
def extract_interval(interval):
    """Convert an INTERVAL expression into a dateutil ``relativedelta``.

    Returns None when dateutil is not installed, when the interval amount is not
    an integer (e.g. a column or '1.5'), or when the unit is not one of
    year/month/week/day.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Previously this escaped and aborted simplification; it is simply
        # not foldable.
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
def
date_literal(date):
def
boolean_literal(condition):