sqlglot.optimizer.canonicalize
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import exp 7from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime 8 9 10def canonicalize(expression: exp.Expression) -> exp.Expression: 11 """Converts a sql expression into a standard form. 12 13 This method relies on annotate_types because many of the 14 conversions rely on type inference. 15 16 Args: 17 expression: The expression to canonicalize. 18 """ 19 20 def _canonicalize(expression: exp.Expression) -> exp.Expression: 21 expression = add_text_to_concat(expression) 22 expression = replace_date_funcs(expression) 23 expression = coerce_type(expression) 24 expression = remove_redundant_casts(expression) 25 expression = ensure_bools(expression, _replace_int_predicate) 26 expression = remove_ascending_order(expression) 27 return expression 28 29 return exp.replace_tree(expression, _canonicalize) 30 31 32def add_text_to_concat(node: exp.Expression) -> exp.Expression: 33 if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: 34 node = exp.Concat(expressions=[node.left, node.right]) 35 return node 36 37 38def replace_date_funcs(node: exp.Expression) -> exp.Expression: 39 if ( 40 isinstance(node, (exp.Date, exp.TsOrDsToDate)) 41 and not node.expressions 42 and not node.args.get("zone") 43 and node.this.is_string 44 and is_iso_date(node.this.name) 45 ): 46 return exp.cast(node.this, to=exp.DataType.Type.DATE) 47 if isinstance(node, exp.Timestamp) and not node.args.get("zone"): 48 if not node.type: 49 from sqlglot.optimizer.annotate_types import annotate_types 50 51 node = annotate_types(node) 52 return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP) 53 54 return node 55 56 57COERCIBLE_DATE_OPS = ( 58 exp.Add, 59 exp.Sub, 60 exp.EQ, 61 exp.NEQ, 62 exp.GT, 63 exp.GTE, 64 exp.LT, 65 exp.LTE, 66 exp.NullSafeEQ, 67 exp.NullSafeNEQ, 68) 69 70 71def coerce_type(node: exp.Expression) -> exp.Expression: 72 if 
isinstance(node, COERCIBLE_DATE_OPS): 73 _coerce_date(node.left, node.right) 74 elif isinstance(node, exp.Between): 75 _coerce_date(node.this, node.args["low"]) 76 elif isinstance(node, exp.Extract) and not node.expression.type.is_type( 77 *exp.DataType.TEMPORAL_TYPES 78 ): 79 _replace_cast(node.expression, exp.DataType.Type.DATETIME) 80 elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): 81 _coerce_timeunit_arg(node.this, node.unit) 82 elif isinstance(node, exp.DateDiff): 83 _coerce_datediff_args(node) 84 85 return node 86 87 88def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: 89 if ( 90 isinstance(expression, exp.Cast) 91 and expression.this.type 92 and expression.to.this == expression.this.type.this 93 ): 94 return expression.this 95 if ( 96 isinstance(expression, (exp.Date, exp.TsOrDsToDate)) 97 and expression.this.type 98 and expression.this.type.this == exp.DataType.Type.DATE 99 ): 100 return expression.this 101 return expression 102 103 104def ensure_bools( 105 expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] 106) -> exp.Expression: 107 if isinstance(expression, exp.Connector): 108 replace_func(expression.left) 109 replace_func(expression.right) 110 elif isinstance(expression, exp.Not): 111 replace_func(expression.this) 112 # We can't replace num in CASE x WHEN num ..., because it's not the full predicate 113 elif isinstance(expression, exp.If) and not ( 114 isinstance(expression.parent, exp.Case) and expression.parent.this 115 ): 116 replace_func(expression.this) 117 elif isinstance(expression, (exp.Where, exp.Having)): 118 replace_func(expression.this) 119 120 return expression 121 122 123def remove_ascending_order(expression: exp.Expression) -> exp.Expression: 124 if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: 125 # Convert ORDER BY a ASC to ORDER BY a 126 expression.set("desc", None) 127 128 return expression 129 130 131def _coerce_date(a: exp.Expression, 
b: exp.Expression) -> None: 132 for a, b in itertools.permutations([a, b]): 133 if isinstance(b, exp.Interval): 134 a = _coerce_timeunit_arg(a, b.unit) 135 if ( 136 a.type 137 and a.type.this in exp.DataType.TEMPORAL_TYPES 138 and b.type 139 and b.type.this in exp.DataType.TEXT_TYPES 140 ): 141 _replace_cast(b, exp.DataType.Type.DATETIME) 142 143 144def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: 145 if not arg.type: 146 return arg 147 148 if arg.type.this in exp.DataType.TEXT_TYPES: 149 date_text = arg.name 150 is_iso_date_ = is_iso_date(date_text) 151 152 if is_iso_date_ and is_date_unit(unit): 153 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) 154 155 # An ISO date is also an ISO datetime, but not vice versa 156 if is_iso_date_ or is_iso_datetime(date_text): 157 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 158 159 elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): 160 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 161 162 return arg 163 164 165def _coerce_datediff_args(node: exp.DateDiff) -> None: 166 for e in (node.this, node.expression): 167 if e.type.this not in exp.DataType.TEMPORAL_TYPES: 168 e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) 169 170 171def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: 172 node.replace(exp.cast(node.copy(), to=to)) 173 174 175# this was originally designed for presto, there is a similar transform for tsql 176# this is different in that it only operates on int types, this is because 177# presto has a boolean type whereas tsql doesn't (people use bits) 178# with y as (select true as x) select x = 0 FROM y -- illegal presto query 179def _replace_int_predicate(expression: exp.Expression) -> None: 180 if isinstance(expression, exp.Coalesce): 181 for child in expression.iter_expressions(): 182 _replace_int_predicate(child) 183 elif expression.type and 
expression.type.this in exp.DataType.INTEGER_TYPES: 184 expression.replace(expression.neq(0))
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Rewrites a SQL expression into a standard form.

    Many of the rewrites depend on type inference, so this method relies on
    annotate_types.

    Args:
        expression: The expression to canonicalize.
    """

    def _canonicalize(node: exp.Expression) -> exp.Expression:
        # Single-argument rewrites, applied in order; each one feeds the next.
        for rewrite in (
            add_text_to_concat,
            replace_date_funcs,
            coerce_type,
            remove_redundant_casts,
        ):
            node = rewrite(node)

        node = ensure_bools(node, _replace_int_predicate)
        return remove_ascending_order(node)

    return exp.replace_tree(expression, _canonicalize)
Converts a SQL expression into a standard form.
This method relies on annotate_types because many of the conversions rely on type inference.
Arguments:
- expression: The expression to canonicalize.
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Swaps date/timestamp constructor calls for equivalent casts.

    A `Date` / `TsOrDsToDate` over an ISO-date string literal (with no extra
    expressions and no "zone" arg) becomes a cast of that literal to DATE. A
    `Timestamp` without a "zone" arg becomes a cast of its operand to the node's
    type, or TIMESTAMP when no type is available. Anything else is returned as-is.
    """
    if isinstance(node, (exp.Date, exp.TsOrDsToDate)):
        is_plain_call = not node.expressions and not node.args.get("zone")
        if is_plain_call and node.this.is_string and is_iso_date(node.this.name):
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
        if not node.type:
            # Lazy import: only needed for Timestamp nodes that lack a type.
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node)

        target_type = node.type or exp.DataType.Type.TIMESTAMP
        return exp.cast(node.this, to=target_type)

    return node
COERCIBLE_DATE_OPS =
(<class 'sqlglot.expressions.Add'>, <class 'sqlglot.expressions.Sub'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NullSafeEQ'>, <class 'sqlglot.expressions.NullSafeNEQ'>)
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Inserts casts around the operands of date-related expressions.

    The operands are modified in place; `node` itself is always returned.

    Args:
        node: The expression to inspect.

    Returns:
        The same `node` that was passed in.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        # NOTE(review): only the lower bound is coerced here; the "high" arg is
        # not passed to _coerce_date - confirm whether that is intended.
        _coerce_date(node.this, node.args["low"])
    elif (
        isinstance(node, exp.Extract)
        # Operands without a type are left untouched instead of raising
        # AttributeError on `None.is_type`.
        and node.expression.type
        and not node.expression.type.is_type(*exp.DataType.TEMPORAL_TYPES)
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def
remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Unwraps casts and date constructors that leave the operand's type unchanged.

    Returns the operand when `expression` is a `Cast` to the operand's own
    annotated type, or a `Date` / `TsOrDsToDate` over an operand already annotated
    as DATE. Otherwise returns `expression` unchanged.
    """
    if isinstance(expression, exp.Cast):
        operand_type = expression.this.type
        if operand_type and expression.to.this == operand_type.this:
            return expression.this

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand_type = expression.this.type
        if operand_type and operand_type.this == exp.DataType.Type.DATE:
            return expression.this

    return expression
def
ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Runs `replace_func` on the operands of nodes that carry a predicate.

    Args:
        expression: The node to inspect.
        replace_func: Callback invoked on each predicate operand.

    Returns:
        The same `expression` that was passed in.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    if isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)
        return expression

    if isinstance(expression, exp.If):
        parent = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        if not (isinstance(parent, exp.Case) and parent.this):
            replace_func(expression.this)

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: