sqlglot.optimizer.canonicalize
from __future__ import annotations

import itertools

from sqlglot import exp


def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Converts a sql expression into a standard form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonicalized expression. This can be a different node than the
        one passed in, because some rules return a replacement node.
    """
    # Hand the children to this same function before rewriting the node itself.
    exp.replace_children(expression, canonicalize)

    expression = add_text_to_concat(expression)
    expression = coerce_type(expression)
    expression = remove_redundant_casts(expression)
    expression = ensure_bool_predicates(expression)
    expression = remove_ascending_order(expression)

    return expression


def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn a text-typed ``Add`` into a ``Concat`` of its two operands.

    Args:
        node: The expression to inspect.

    Returns:
        A new ``exp.Concat`` when ``node`` is an ``exp.Add`` whose annotated
        type is one of ``exp.DataType.TEXT_TYPES``; otherwise ``node`` itself.
    """
    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
        node = exp.Concat(expressions=[node.left, node.right])
    return node


def coerce_type(node: exp.Expression) -> exp.Expression:
    """Insert casts so that date/datetime operands have matching types.

    * ``Binary``: the two operands are passed to ``_coerce_date``.
    * ``Between``: ``this`` and the ``low`` bound are passed to ``_coerce_date``.
    * ``Extract``: an operand whose annotated type is not temporal is wrapped
      in a cast to ``datetime``.

    Args:
        node: The expression to inspect. Casts are spliced into its children
            in place.

    Returns:
        ``node`` itself.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        # NOTE(review): only the "low" bound is coerced here; confirm whether
        # the "high" bound is intentionally left alone.
        _coerce_date(node.this, node.args["low"])
    elif isinstance(node, exp.Extract):
        operand = node.expression
        operand_type = operand.type
        # An operand without a type annotation is left untouched instead of
        # raising AttributeError on `None.this`; this matches the
        # `x.type and ...` guards used by the other rules in this module.
        if operand_type and operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            _replace_cast(operand, "datetime")
    return node


def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drop a ``Cast`` whose target type equals the type of the cast operand.

    Args:
        expression: The expression to inspect.

    Returns:
        The cast operand (``expression.this``) when ``expression`` is an
        ``exp.Cast`` and both ``expression.to.type`` and
        ``expression.this.type`` are set and share the same ``this``;
        otherwise ``expression`` itself.
    """
    if (
        isinstance(expression, exp.Cast)
        and expression.to.type
        and expression.this.type
        and expression.to.type.this == expression.this.type.this
    ):
        return expression.this
    return expression


def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
    """Rewrite integer-typed predicates as explicit ``<> 0`` comparisons.

    Both operands of a ``Connector`` and the condition of a ``Where`` or
    ``Having`` are passed to ``_replace_int_predicate``.

    Args:
        expression: The expression to inspect. Its operands are rewritten in place.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        _replace_int_predicate(expression.left)
        _replace_int_predicate(expression.right)

    elif isinstance(expression, (exp.Where, exp.Having)):
        _replace_int_predicate(expression.this)

    return expression


def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clear an explicit ascending flag on an ``Ordered`` node.

    Args:
        expression: The expression to inspect. It is modified in place.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
        # Convert ORDER BY a ASC to ORDER BY a
        expression.set("desc", None)

    return expression


def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Cast the non-date side of a pair to ``date`` when the other side is a DATE.

    Both orderings of ``(a, b)`` are checked. A side is cast only when it has a
    type annotation that is neither DATE nor INTERVAL.

    Args:
        a: One operand.
        b: The other operand.
    """
    for date_side, other_side in itertools.permutations([a, b]):
        if (
            date_side.type
            and date_side.type.this == exp.DataType.Type.DATE
            and other_side.type
            and other_side.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
        ):
            _replace_cast(other_side, "date")


def _replace_cast(node: exp.Expression, to: str) -> None:
    """Replace ``node`` in place with a ``Cast`` of a copy of itself to type ``to``.

    Args:
        node: The expression to wrap.
        to: The target type, as accepted by ``exp.DataType.build``.
    """
    data_type = exp.DataType.build(to)
    cast = exp.Cast(this=node.copy(), to=data_type)
    # Annotate the new node by hand so the later type-based rules can see its type.
    cast.type = data_type
    node.replace(cast)


def _replace_int_predicate(expression: exp.Expression) -> None:
    """Replace an integer-typed ``expression`` in place with ``expression <> 0``.

    Expressions with no type annotation, or with a non-integer type, are left alone.

    Args:
        expression: The predicate candidate.
    """
    if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Rewrite a SQL expression into a standard form.

    Many of the rewrites inspect type annotations, so this relies on
    annotate_types.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The result of the last rewrite rule in the chain.
    """
    # Hand the children to this same function before rewriting the node itself.
    exp.replace_children(expression, canonicalize)

    # Applied in this exact order; each rule receives the previous rule's result.
    rewrites = (
        add_text_to_concat,
        coerce_type,
        remove_redundant_casts,
        ensure_bool_predicates,
        remove_ascending_order,
    )
    for rewrite in rewrites:
        expression = rewrite(expression)

    return expression
Converts a SQL expression into a standard form.
This function relies on annotate_types having been run first, because many of the conversions depend on inferred types.
Arguments:
- expression: The expression to canonicalize.
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Insert casts so that date/datetime operands have matching types.

    * ``Binary``: the two operands are passed to ``_coerce_date``.
    * ``Between``: ``this`` and the ``low`` bound are passed to ``_coerce_date``.
    * ``Extract``: an operand whose annotated type is not temporal is wrapped
      in a cast to ``datetime``.

    Args:
        node: The expression to inspect. Casts are spliced into its children
            in place.

    Returns:
        ``node`` itself.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        # NOTE(review): only the "low" bound is coerced here; confirm whether
        # the "high" bound is intentionally left alone.
        _coerce_date(node.this, node.args["low"])
    elif isinstance(node, exp.Extract):
        operand = node.expression
        operand_type = operand.type
        # An operand without a type annotation is left untouched instead of
        # raising AttributeError on `None.this`.
        if operand_type and operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            _replace_cast(operand, "datetime")
    return node
def
remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def
ensure_bool_predicates( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
    """Pass the predicate operands of ``expression`` to ``_replace_int_predicate``.

    For a ``Connector`` the operands are ``left`` and ``right`` (in that
    order); for a ``Where`` or ``Having`` the operand is ``this``. Any other
    node is returned without its operands being touched.

    Args:
        expression: The expression to inspect.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        operand_attrs = ("left", "right")
    elif isinstance(expression, (exp.Where, exp.Having)):
        operand_attrs = ("this",)
    else:
        operand_attrs = ()

    # Each operand is read only when its turn comes, so the second operand is
    # looked up after the first has been handled.
    for attr in operand_attrs:
        _replace_int_predicate(getattr(expression, attr))

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: