sqlglot.optimizer.canonicalize
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime 9from sqlglot.optimizer.annotate_types import TypeAnnotator 10 11 12def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression: 13 """Converts a sql expression into a standard form. 14 15 This method relies on annotate_types because many of the 16 conversions rely on type inference. 17 18 Args: 19 expression: The expression to canonicalize. 20 """ 21 22 dialect = Dialect.get_or_raise(dialect) 23 24 def _canonicalize(expression: exp.Expression) -> exp.Expression: 25 expression = add_text_to_concat(expression) 26 expression = replace_date_funcs(expression, dialect=dialect) 27 expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE) 28 expression = remove_redundant_casts(expression) 29 expression = ensure_bools(expression, _replace_int_predicate) 30 expression = remove_ascending_order(expression) 31 return expression 32 33 return exp.replace_tree(expression, _canonicalize) 34 35 36def add_text_to_concat(node: exp.Expression) -> exp.Expression: 37 if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: 38 node = exp.Concat(expressions=[node.left, node.right]) 39 return node 40 41 42def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression: 43 if ( 44 isinstance(node, (exp.Date, exp.TsOrDsToDate)) 45 and not node.expressions 46 and not node.args.get("zone") 47 and node.this.is_string 48 and is_iso_date(node.this.name) 49 ): 50 return exp.cast(node.this, to=exp.DataType.Type.DATE) 51 if isinstance(node, exp.Timestamp) and not node.args.get("zone"): 52 if not node.type: 53 from sqlglot.optimizer.annotate_types import annotate_types 54 55 node = annotate_types(node, dialect=dialect) 56 return exp.cast(node.this, to=node.type or 
exp.DataType.Type.TIMESTAMP) 57 58 return node 59 60 61COERCIBLE_DATE_OPS = ( 62 exp.Add, 63 exp.Sub, 64 exp.EQ, 65 exp.NEQ, 66 exp.GT, 67 exp.GTE, 68 exp.LT, 69 exp.LTE, 70 exp.NullSafeEQ, 71 exp.NullSafeNEQ, 72) 73 74 75def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression: 76 if isinstance(node, COERCIBLE_DATE_OPS): 77 _coerce_date(node.left, node.right, promote_to_inferred_datetime_type) 78 elif isinstance(node, exp.Between): 79 _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type) 80 elif isinstance(node, exp.Extract) and not node.expression.type.is_type( 81 *exp.DataType.TEMPORAL_TYPES 82 ): 83 _replace_cast(node.expression, exp.DataType.Type.DATETIME) 84 elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): 85 _coerce_timeunit_arg(node.this, node.unit) 86 elif isinstance(node, exp.DateDiff): 87 _coerce_datediff_args(node) 88 89 return node 90 91 92def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: 93 if ( 94 isinstance(expression, exp.Cast) 95 and expression.this.type 96 and expression.to == expression.this.type 97 ): 98 return expression.this 99 100 if ( 101 isinstance(expression, (exp.Date, exp.TsOrDsToDate)) 102 and expression.this.type 103 and expression.this.type.this == exp.DataType.Type.DATE 104 and not expression.this.type.expressions 105 ): 106 return expression.this 107 108 return expression 109 110 111def ensure_bools( 112 expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] 113) -> exp.Expression: 114 if isinstance(expression, exp.Connector): 115 replace_func(expression.left) 116 replace_func(expression.right) 117 elif isinstance(expression, exp.Not): 118 replace_func(expression.this) 119 # We can't replace num in CASE x WHEN num ..., because it's not the full predicate 120 elif isinstance(expression, exp.If) and not ( 121 isinstance(expression.parent, exp.Case) and expression.parent.this 122 ): 123 
replace_func(expression.this) 124 elif isinstance(expression, (exp.Where, exp.Having)): 125 replace_func(expression.this) 126 127 return expression 128 129 130def remove_ascending_order(expression: exp.Expression) -> exp.Expression: 131 if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: 132 # Convert ORDER BY a ASC to ORDER BY a 133 expression.set("desc", None) 134 135 return expression 136 137 138def _coerce_date( 139 a: exp.Expression, 140 b: exp.Expression, 141 promote_to_inferred_datetime_type: bool, 142) -> None: 143 for a, b in itertools.permutations([a, b]): 144 if isinstance(b, exp.Interval): 145 a = _coerce_timeunit_arg(a, b.unit) 146 147 a_type = a.type 148 if ( 149 not a_type 150 or a_type.this not in exp.DataType.TEMPORAL_TYPES 151 or not b.type 152 or b.type.this not in exp.DataType.TEXT_TYPES 153 ): 154 continue 155 156 if promote_to_inferred_datetime_type: 157 if b.is_string: 158 date_text = b.name 159 if is_iso_date(date_text): 160 b_type = exp.DataType.Type.DATE 161 elif is_iso_datetime(date_text): 162 b_type = exp.DataType.Type.DATETIME 163 else: 164 b_type = a_type.this 165 else: 166 # If b is not a datetime string, we conservatively promote it to a DATETIME, 167 # in order to ensure there are no surprising truncations due to downcasting 168 b_type = exp.DataType.Type.DATETIME 169 170 target_type = ( 171 b_type if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type 172 ) 173 else: 174 target_type = a_type 175 176 if target_type != a_type: 177 _replace_cast(a, target_type) 178 179 _replace_cast(b, target_type) 180 181 182def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: 183 if not arg.type: 184 return arg 185 186 if arg.type.this in exp.DataType.TEXT_TYPES: 187 date_text = arg.name 188 is_iso_date_ = is_iso_date(date_text) 189 190 if is_iso_date_ and is_date_unit(unit): 191 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) 192 193 # An 
ISO date is also an ISO datetime, but not vice versa 194 if is_iso_date_ or is_iso_datetime(date_text): 195 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 196 197 elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): 198 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 199 200 return arg 201 202 203def _coerce_datediff_args(node: exp.DateDiff) -> None: 204 for e in (node.this, node.expression): 205 if e.type.this not in exp.DataType.TEMPORAL_TYPES: 206 e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) 207 208 209def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None: 210 node.replace(exp.cast(node.copy(), to=to)) 211 212 213# this was originally designed for presto, there is a similar transform for tsql 214# this is different in that it only operates on int types, this is because 215# presto has a boolean type whereas tsql doesn't (people use bits) 216# with y as (select true as x) select x = 0 FROM y -- illegal presto query 217def _replace_int_predicate(expression: exp.Expression) -> None: 218 if isinstance(expression, exp.Coalesce): 219 for child in expression.iter_expressions(): 220 _replace_int_predicate(child) 221 elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: 222 expression.replace(expression.neq(0))
def
canonicalize( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Rewrite a SQL expression tree into its canonical form.

    The passes depend on type information, so the tree is expected to have been
    processed by annotate_types beforehand.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect to resolve via `Dialect.get_or_raise`.

    Returns:
        The value returned by `exp.replace_tree` once every pass has been applied.
    """

    resolved = Dialect.get_or_raise(dialect)

    def _run_passes(node: exp.Expression) -> exp.Expression:
        # The order matters: each pass consumes the previous pass's output.
        node = add_text_to_concat(node)
        node = replace_date_funcs(node, dialect=resolved)
        node = coerce_type(node, resolved.PROMOTE_TO_INFERRED_DATETIME_TYPE)
        node = remove_redundant_casts(node)
        node = ensure_bools(node, _replace_int_predicate)
        return remove_ascending_order(node)

    return exp.replace_tree(expression, _run_passes)
Converts a SQL expression into a standard form.
This method relies on annotate_types because many of the conversions rely on type inference.
Arguments:
- expression: The expression to canonicalize.
def
replace_date_funcs( node: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Swap date/timestamp constructor calls for equivalent casts.

    Args:
        node: The expression to inspect.
        dialect: Dialect handed to `annotate_types` when a `Timestamp` node is untyped.

    Returns:
        A cast expression when one of the rewrites applies, otherwise `node`.
    """
    # Neither rewrite applies to a node that carries a `zone` argument.
    if node.args.get("zone"):
        return node

    if isinstance(node, (exp.Date, exp.TsOrDsToDate)) and not node.expressions:
        argument = node.this
        if argument.is_string and is_iso_date(argument.name):
            return exp.cast(argument, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp):
        if not node.type:
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node, dialect=dialect)

        # Fall back to TIMESTAMP when annotation did not produce a type.
        target = node.type or exp.DataType.Type.TIMESTAMP
        return exp.cast(node.this, to=target)

    return node
COERCIBLE_DATE_OPS =
(<class 'sqlglot.expressions.Add'>, <class 'sqlglot.expressions.Sub'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NullSafeEQ'>, <class 'sqlglot.expressions.NullSafeNEQ'>)
def
coerce_type( node: sqlglot.expressions.Expression, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Insert casts around the operands of temporal operations, mutating them in place.

    Args:
        node: The expression whose operands may need coercion.
        promote_to_inferred_datetime_type: Forwarded to `_coerce_date`.

    Returns:
        `node` itself; only its children are replaced.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        # NOTE(review): only the `low` bound is coerced here; `high` is left untouched.
        # Confirm whether that is intentional.
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract):
        # An operand without a type annotation is treated like any other non-temporal
        # operand instead of raising AttributeError on `None.is_type`.
        operand_type = node.expression.type
        if operand_type is None or not operand_type.is_type(*exp.DataType.TEMPORAL_TYPES):
            _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def
remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Strip wrappers that leave the type of their argument unchanged.

    Returns:
        The wrapped argument when the wrapper is redundant, otherwise `expression`.
    """
    if isinstance(expression, exp.Cast):
        inner = expression.this
        # The cast is a no-op when its target equals the argument's annotated type.
        if inner.type and expression.to == inner.type:
            return inner

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        inner = expression.this
        inner_type = inner.type
        # Only a plain DATE (no nested type expressions) makes the call redundant.
        if (
            inner_type
            and inner_type.this == exp.DataType.Type.DATE
            and not inner_type.expressions
        ):
            return inner

    return expression
def
ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Invoke `replace_func` on each operand of `expression` that acts as a predicate.

    Args:
        expression: The node whose predicate operands should be visited.
        replace_func: Callback applied to each predicate operand; its result is ignored.

    Returns:
        `expression` itself.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    holds_predicate = isinstance(expression, (exp.Not, exp.Where, exp.Having))

    if not holds_predicate and isinstance(expression, exp.If):
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        enclosing = expression.parent
        holds_predicate = not (isinstance(enclosing, exp.Case) and enclosing.this)

    if holds_predicate:
        replace_func(expression.this)

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: