Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop().copy())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
 78
 79    return expression
 80
 81
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT that reads from the original query (wrapped as the subquery `_t`) and
        applies the rewritten QUALIFY condition as its WHERE clause. The input expression is
        returned as-is when it is not a SELECT with a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # The outer query refers to the inner projections by name, so every unnamed projection
        # first gets a fresh alias based on "_c".
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Built before any helper projections are added below, so the outer query exposes only
        # the projections of the original query.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])

        # Detach QUALIFY from the query and keep its condition.
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are hoisted; referenced columns are not
        # added (presumably because the star already exposes them -- TODO confirm).
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window function in the inner query under a fresh "_w"-based alias
                # and make the filter refer to that alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window function is the entire QUALIFY condition, so the filter itself
                    # becomes the alias reference.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # A column used by the filter but not projected: add it to the inner query.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
122
123
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types in
    expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with every `DataTypeParam` dropped from the `expressions` of each
        `DataType` node found in it.
    """
    for data_type in expression.find_all(exp.DataType):
        remaining = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", remaining)

    return expression
135
136
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of the SELECT whose target is an UNNEST is removed. For each pair of unnested
    expression and alias column, a LATERAL VIEW is appended that wraps the expression in EXPLODE,
    or in POSEXPLODE when the UNNEST has `ordinality` set.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: the loop removes entries from the live "joins" list, and
        # removing from a list while iterating it skips the element that follows each removal
        # (e.g. the second of two consecutive CROSS JOIN UNNEST clauses).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): an UNNEST without an alias yields no LATERAL VIEW at all (the zip
                # below is empty) even though its join has just been removed -- confirm intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
160
161
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is replaced by references to the columns of
    an aliased UNNEST source, which becomes the FROM clause when the query has none and is
    cross joined otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Imported locally rather than at module level (presumably to avoid a circular import
        # -- TODO confirm).
        from sqlglot.optimizer.scope import Scope

        # Names already in use, so that generated column / source aliases don't collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            # Keep a handle on the whole projection; `select` may be unwrapped below.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap `<explode> AS col` and `<posexplode> AS (pos, col)`, remembering the
            # user-provided aliases.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # No user-provided alias: generate names based on "col" (and "pos").
                # NOTE(review): a POSEXPLODE aliased with a single name skips this branch, so
                # pos_alias is still "" when it is used below -- confirm this case is intended.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The UNNEST alias lists the value column first and the position column
                    # second, while the projections are added as (pos, col). The original
                    # projection is removed and the two new ones are added via `select`.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    # A plain EXPLODE projection is swapped in place for a column reference.
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                # The UNNEST becomes the FROM source if the query has none; otherwise it is
                # cross joined to the existing sources.
                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
221
222
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `PERCENTILE_CONT/DISC(q) WITHIN GROUP (ORDER BY x)` with an `ApproxQuantile` node.

    The quantile is the percentile function's argument and the input value is the operand of
    the first `Ordered` node found in the expression.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` replacement when the expression matches the pattern above,
        otherwise the expression itself.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
234
235
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit list of column names.

    CTEs whose alias already lists columns are left alone. For the others, the names come from
    the projections of the CTE's query (the left-hand side, if the query is a UNION); a
    projection with no name gets one from a shared `_c_` name sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in anchor.selects
        ]
        table_alias.set("columns", column_names)

    return expression
252
253
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of an 'epoch' cast to a temporal type with a timestamp string.

    Applies to `Cast` / `TryCast` nodes whose name is "epoch" (compared case-insensitively) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`: the value being cast is swapped
    for the string literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it matches.
    """
    is_cast = isinstance(expression, (exp.Cast, exp.TryCast))

    if is_cast and expression.name.lower() == "epoch":
        target_type = expression.to.this
        if target_type in exp.DataType.TEMPORAL_TYPES:
            epoch_start = exp.Literal.string("1970-01-01 00:00:00")
            expression.this.replace(epoch_start)

    return expression
263
264
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the (copied) expression is generated as-is.

    Returns:
        Function that can be used as a generator transform. The returned function raises
        `ValueError` when the resulting expression has neither a "_sql" method nor a usable
        `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's expression is never mutated by the transforms. Copying
        # up front (instead of indexing transforms[0]) also means an empty `transforms` list no
        # longer raises IndexError.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY items that refer to a SELECT alias as the ordinal of that projection.

    Only unqualified columns (no table part) whose name matches an aliased projection of the
    enclosing SELECT are rewritten; every other GROUP BY item is left untouched.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map each alias to the 1-based position of its projection; a repeated alias keeps the
    # position of its last occurrence.
    alias_positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for grouped in expression.expressions:
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue

        position = alias_positions.get(grouped.name)
        if position is not None:
            grouped.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a subquery filtered on a ROW_NUMBER() window.

    The DISTINCT clause is removed from the query, which gains an extra ROW_NUMBER() projection
    partitioned by the DISTINCT ON columns. The window is ordered by the query's ORDER BY when
    there is one (the ORDER BY is moved out of the query), otherwise by the DISTINCT ON columns.
    The query is then wrapped in a subquery, and the outer query selects the original projections
    for the rows whose row number is 1.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    partition_cols = distinct.pop().args["on"].expressions

    # Capture the projections and pick the alias before the window projection is added.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")

    window = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)
    ordering = expression.args.get("order")
    window.set(
        "order",
        ordering.pop().copy()
        if ordering
        else exp.Order(expressions=[col.copy() for col in partition_cols]),
    )

    expression.select(exp.alias_(window, row_number_alias), copy=False)

    outer_query = exp.select(*projections)
    return outer_query.from_(expression.subquery()).where(f'"{row_number_alias}" = 1')

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT that reads from the original query (wrapped as the subquery `_t`) and
        applies the rewritten QUALIFY condition as its WHERE clause. The input expression is
        returned as-is when it is not a SELECT with a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # The outer query refers to the inner projections by name, so every unnamed projection
        # first gets a fresh alias based on "_c".
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Built before any helper projections are added below, so the outer query exposes only
        # the projections of the original query.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])

        # Detach QUALIFY from the query and keep its condition.
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are hoisted; referenced columns are not
        # added (presumably because the star already exposes them -- TODO confirm).
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window function in the inner query under a fresh "_w"-based alias
                # and make the filter refer to that alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window function is the entire QUALIFY condition, so the filter itself
                    # becomes the alias reference.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # A column used by the filter but not projected: add it to the inner query.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types in
    expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with every `DataTypeParam` dropped from the `expressions` of each
        `DataType` node found in it.
    """
    for data_type in expression.find_all(exp.DataType):
        remaining = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", remaining)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of the SELECT whose target is an UNNEST is removed. For each pair of unnested
    expression and alias column, a LATERAL VIEW is appended that wraps the expression in EXPLODE,
    or in POSEXPLODE when the UNNEST has `ordinality` set.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: the loop removes entries from the live "joins" list, and
        # removing from a list while iterating it skips the element that follows each removal
        # (e.g. the second of two consecutive CROSS JOIN UNNEST clauses).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): an UNNEST without an alias yields no LATERAL VIEW at all (the zip
                # below is empty) even though its join has just been removed -- confirm intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is replaced by references to the columns of
    an aliased UNNEST source, which becomes the FROM clause when the query has none and is
    cross joined otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Imported locally rather than at module level (presumably to avoid a circular import
        # -- TODO confirm).
        from sqlglot.optimizer.scope import Scope

        # Names already in use, so that generated column / source aliases don't collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            # Keep a handle on the whole projection; `select` may be unwrapped below.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap `<explode> AS col` and `<posexplode> AS (pos, col)`, remembering the
            # user-provided aliases.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # No user-provided alias: generate names based on "col" (and "pos").
                # NOTE(review): a POSEXPLODE aliased with a single name skips this branch, so
                # pos_alias is still "" when it is used below -- confirm this case is intended.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The UNNEST alias lists the value column first and the position column
                    # second, while the projections are added as (pos, col). The original
                    # projection is removed and the two new ones are added via `select`.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    # A plain EXPLODE projection is swapped in place for a column reference.
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                # The UNNEST becomes the FROM source if the query has none; otherwise it is
                # cross joined to the existing sources.
                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `PERCENTILE_CONT/DISC(q) WITHIN GROUP (ORDER BY x)` with an `ApproxQuantile` node.

    The quantile is the percentile function's argument and the input value is the operand of
    the first `Ordered` node found in the expression.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` replacement when the expression matches the pattern above,
        otherwise the expression itself.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit list of column names.

    CTEs whose alias already lists columns are left alone. For the others, the names come from
    the projections of the CTE's query (the left-hand side, if the query is a UNION); a
    projection with no name gets one from a shared `_c_` name sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in anchor.selects
        ]
        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of an 'epoch' cast to a temporal type with a timestamp string.

    Applies to `Cast` / `TryCast` nodes whose name is "epoch" (compared case-insensitively) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`: the value being cast is swapped
    for the string literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it matches.
    """
    is_cast = isinstance(expression, (exp.Cast, exp.TryCast))

    if is_cast and expression.name.lower() == "epoch":
        target_type = expression.to.this
        if target_type in exp.DataType.TEMPORAL_TYPES:
            epoch_start = exp.Literal.string("1970-01-01 00:00:00")
            expression.this.replace(epoch_start)

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the (copied) expression is generated as-is.

    Returns:
        Function that can be used as a generator transform. The returned function raises
        `ValueError` when the resulting expression has neither a "_sql" method nor a usable
        `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's expression is never mutated by the transforms. Copying
        # up front (instead of indexing transforms[0]) also means an empty `transforms` list no
        # longer raises IndexError.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.