Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  8
  9
 10def canonicalize(expression: exp.Expression) -> exp.Expression:
 11    """Converts a sql expression into a standard form.
 12
 13    This method relies on annotate_types because many of the
 14    conversions rely on type inference.
 15
 16    Args:
 17        expression: The expression to canonicalize.
 18    """
 19
 20    def _canonicalize(expression: exp.Expression) -> exp.Expression:
 21        expression = add_text_to_concat(expression)
 22        expression = replace_date_funcs(expression)
 23        expression = coerce_type(expression)
 24        expression = remove_redundant_casts(expression)
 25        expression = ensure_bools(expression, _replace_int_predicate)
 26        expression = remove_ascending_order(expression)
 27        return expression
 28
 29    return exp.replace_tree(expression, _canonicalize)
 30
 31
 32def add_text_to_concat(node: exp.Expression) -> exp.Expression:
 33    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 34        node = exp.Concat(expressions=[node.left, node.right])
 35    return node
 36
 37
 38def replace_date_funcs(node: exp.Expression) -> exp.Expression:
 39    if (
 40        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 41        and not node.expressions
 42        and not node.args.get("zone")
 43        and node.this.is_string
 44        and is_iso_date(node.this.name)
 45    ):
 46        return exp.cast(node.this, to=exp.DataType.Type.DATE)
 47    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
 48        if not node.type:
 49            from sqlglot.optimizer.annotate_types import annotate_types
 50
 51            node = annotate_types(node)
 52        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
 53
 54    return node
 55
 56
 57COERCIBLE_DATE_OPS = (
 58    exp.Add,
 59    exp.Sub,
 60    exp.EQ,
 61    exp.NEQ,
 62    exp.GT,
 63    exp.GTE,
 64    exp.LT,
 65    exp.LTE,
 66    exp.NullSafeEQ,
 67    exp.NullSafeNEQ,
 68)
 69
 70
 71def coerce_type(node: exp.Expression) -> exp.Expression:
 72    if isinstance(node, COERCIBLE_DATE_OPS):
 73        _coerce_date(node.left, node.right)
 74    elif isinstance(node, exp.Between):
 75        _coerce_date(node.this, node.args["low"])
 76    elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
 77        *exp.DataType.TEMPORAL_TYPES
 78    ):
 79        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
 80    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
 81        _coerce_timeunit_arg(node.this, node.unit)
 82    elif isinstance(node, exp.DateDiff):
 83        _coerce_datediff_args(node)
 84
 85    return node
 86
 87
 88def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 89    if (
 90        isinstance(expression, exp.Cast)
 91        and expression.this.type
 92        and expression.to.this == expression.this.type.this
 93    ):
 94        return expression.this
 95    if (
 96        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
 97        and expression.this.type
 98        and expression.this.type.this == exp.DataType.Type.DATE
 99    ):
100        return expression.this
101    return expression
102
103
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply `replace_func` to the operands that sit in predicate positions.

    `replace_func` is called on:

    * both sides of an exp.Connector,
    * the operand of an exp.Not,
    * the condition of an exp.If, unless the If belongs to a CASE that has
      its own subject expression,
    * the condition of an exp.Where or exp.Having.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand; its return
            value is ignored.

    Returns:
        The same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    if isinstance(expression, exp.Not):
        replace_func(expression.this)
        return expression

    if isinstance(expression, exp.If):
        owner = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the
        # full predicate.
        inside_case_with_subject = isinstance(owner, exp.Case) and owner.this
        if not inside_case_with_subject:
            replace_func(expression.this)
            return expression

    if isinstance(expression, (exp.Where, exp.Having)):
        replace_func(expression.this)

    return expression
121
122
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clear an explicit ascending flag on an ordering term.

    Converts ORDER BY a ASC to ORDER BY a by resetting the "desc" arg to
    None when it is exactly False.

    Args:
        expression: The expression to inspect.

    Returns:
        The same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Ordered):
        # Only an explicit False is cleared; a missing/None or True value is
        # left untouched.
        if expression.args.get("desc") is False:
            expression.set("desc", None)

    return expression
129
130
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Coerce a pair of operands so temporal and text values line up.

    The pair is examined in both orders, (a, b) and then (b, a). For each
    ordering (first, second):

    * if `second` is an exp.Interval, `first` is passed through
      _coerce_timeunit_arg using the interval's unit;
    * if `first` has a temporal type and `second` a text type, `second` is
      replaced in place by a cast to DATETIME.

    Args:
        a: One operand.
        b: The other operand.
    """
    for first, second in ((a, b), (b, a)):
        if isinstance(second, exp.Interval):
            first = _coerce_timeunit_arg(first, second.unit)

        first_type = first.type
        second_type = second.type
        if not (first_type and second_type):
            # Nothing can be decided without type information on both sides.
            continue

        if (
            first_type.this in exp.DataType.TEMPORAL_TYPES
            and second_type.this in exp.DataType.TEXT_TYPES
        ):
            _replace_cast(second, exp.DataType.Type.DATETIME)
142
143
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast the argument of a unit-based date function to DATE or DATETIME.

    Rules, in order:

    * an argument without a type is returned unchanged;
    * a text-typed argument whose name is an ISO date is cast to DATE when
      `unit` is a date unit, otherwise to DATETIME;
    * a text-typed argument whose name is an ISO datetime is cast to
      DATETIME;
    * a DATE-typed argument is cast to DATETIME when `unit` is not a date
      unit.

    Args:
        arg: The argument expression; it is replaced in its tree when a cast
            is added.
        unit: The unit expression handed to is_date_unit.

    Returns:
        The value returned by `arg.replace(...)` when a cast is added,
        otherwise `arg`.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    kind = arg_type.this
    target = None

    if kind in exp.DataType.TEXT_TYPES:
        literal = arg.name
        if is_iso_date(literal):
            # An ISO date is also an ISO datetime, but not vice versa
            if is_date_unit(unit):
                target = exp.DataType.Type.DATE
            else:
                target = exp.DataType.Type.DATETIME
        elif is_iso_datetime(literal):
            target = exp.DataType.Type.DATETIME
    elif kind == exp.DataType.Type.DATE and not is_date_unit(unit):
        target = exp.DataType.Type.DATETIME

    if target is None:
        return arg

    return arg.replace(exp.cast(arg.copy(), to=target))
163
164
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast non-temporal DATEDIFF operands to DATETIME, in place.

    Both `node.this` and `node.expression` are checked. An operand that has
    no type annotation is left untouched: without type information there is
    no way to tell whether a cast is needed, and dereferencing the missing
    type would raise AttributeError.

    Args:
        node: The DateDiff expression whose operands should be coerced.
    """
    for operand in (node.this, node.expression):
        operand_type = operand.type
        if operand_type and operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            operand.replace(exp.cast(operand.copy(), to=exp.DataType.Type.DATETIME))
169
170
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
    """Replace `node` in its tree with a cast of a copy of itself to `to`.

    Args:
        node: The expression to wrap.
        to: The target data type of the cast.
    """
    # The cast wraps a copy so that the original node can be swapped out.
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)
173
174
# This was originally designed for Presto; there is a similar transform for
# T-SQL. The difference is that this one only operates on integer types,
# because Presto has a boolean type whereas T-SQL doesn't (people use bits).
# Example of an illegal Presto query:
#   with y as (select true as x) select x = 0 FROM y
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate operand `x` into `x <> 0`, in place.

    For an exp.Coalesce the rewrite is applied recursively to each child
    yielded by `iter_expressions()` instead of to the COALESCE itself.

    Args:
        expression: The operand found in a predicate position.
    """
    if isinstance(expression, exp.Coalesce):
        for operand in expression.iter_expressions():
            _replace_int_predicate(operand)
        return

    inferred = expression.type
    if inferred and inferred.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
def canonicalize( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
11def canonicalize(expression: exp.Expression) -> exp.Expression:
12    """Converts a sql expression into a standard form.
13
14    This method relies on annotate_types because many of the
15    conversions rely on type inference.
16
17    Args:
18        expression: The expression to canonicalize.
19    """
20
21    def _canonicalize(expression: exp.Expression) -> exp.Expression:
22        expression = add_text_to_concat(expression)
23        expression = replace_date_funcs(expression)
24        expression = coerce_type(expression)
25        expression = remove_redundant_casts(expression)
26        expression = ensure_bools(expression, _replace_int_predicate)
27        expression = remove_ascending_order(expression)
28        return expression
29
30    return exp.replace_tree(expression, _canonicalize)

Converts a SQL expression into a standard form.

This method relies on annotate_types because many of the conversions rely on type inference.

Arguments:
  • expression: The expression to canonicalize.
def add_text_to_concat(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
33def add_text_to_concat(node: exp.Expression) -> exp.Expression:
34    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
35        node = exp.Concat(expressions=[node.left, node.right])
36    return node
def replace_date_funcs(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
39def replace_date_funcs(node: exp.Expression) -> exp.Expression:
40    if (
41        isinstance(node, (exp.Date, exp.TsOrDsToDate))
42        and not node.expressions
43        and not node.args.get("zone")
44        and node.this.is_string
45        and is_iso_date(node.this.name)
46    ):
47        return exp.cast(node.this, to=exp.DataType.Type.DATE)
48    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
49        if not node.type:
50            from sqlglot.optimizer.annotate_types import annotate_types
51
52            node = annotate_types(node)
53        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
54
55    return node
def coerce_type(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
72def coerce_type(node: exp.Expression) -> exp.Expression:
73    if isinstance(node, COERCIBLE_DATE_OPS):
74        _coerce_date(node.left, node.right)
75    elif isinstance(node, exp.Between):
76        _coerce_date(node.this, node.args["low"])
77    elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
78        *exp.DataType.TEMPORAL_TYPES
79    ):
80        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
81    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
82        _coerce_timeunit_arg(node.this, node.unit)
83    elif isinstance(node, exp.DateDiff):
84        _coerce_datediff_args(node)
85
86    return node
def remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 89def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 90    if (
 91        isinstance(expression, exp.Cast)
 92        and expression.this.type
 93        and expression.to.this == expression.this.type.this
 94    ):
 95        return expression.this
 96    if (
 97        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
 98        and expression.this.type
 99        and expression.this.type.this == exp.DataType.Type.DATE
100    ):
101        return expression.this
102    return expression
def ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
105def ensure_bools(
106    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
107) -> exp.Expression:
108    if isinstance(expression, exp.Connector):
109        replace_func(expression.left)
110        replace_func(expression.right)
111    elif isinstance(expression, exp.Not):
112        replace_func(expression.this)
113        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
114    elif isinstance(expression, exp.If) and not (
115        isinstance(expression.parent, exp.Case) and expression.parent.this
116    ):
117        replace_func(expression.this)
118    elif isinstance(expression, (exp.Where, exp.Having)):
119        replace_func(expression.this)
120
121    return expression
def remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
124def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
125    if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
126        # Convert ORDER BY a ASC to ORDER BY a
127        expression.set("desc", None)
128
129    return expression