Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 97    """
 98    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 99        taken = set(expression.named_selects)
100        for select in expression.selects:
101            if not select.alias_or_name:
102                alias = find_new_name(taken, "_c")
103                select.replace(exp.alias_(select, alias))
104                taken.add(alias)
105
106        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
107        qualify_filters = expression.args["qualify"].pop().this
108
109        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
110        for expr in qualify_filters.find_all(select_candidates):
111            if isinstance(expr, exp.Window):
112                alias = find_new_name(expression.named_selects, "_w")
113                expression.select(exp.alias_(expr, alias), copy=False)
114                column = exp.column(alias)
115
116                if isinstance(expr.parent, exp.Qualify):
117                    qualify_filters = column
118                else:
119                    expr.replace(column)
120            elif expr.name not in expression.named_selects:
121                expression.select(expr.copy(), copy=False)
122
123        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
124            qualify_filters, copy=False
125        )
126
127    return expression
128
129
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types found in
    the given expression.
    """
    for data_type in expression.find_all(exp.DataType):
        # Keep every type argument except the DataTypeParam nodes.
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression
141
142
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the query's joins. For each
    unnested expression that has a matching alias column, a LATERAL VIEW of EXPLODE (or
    POSEXPLODE, when the UNNEST has an offset) is appended to the query instead.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    # Iterate over a snapshot: joins are removed from the underlying list inside the loop, and
    # removing from a list that is being iterated would skip the element after each removal.
    for join in list(expression.args.get("joins") or []):
        unnest = join.this

        if not isinstance(unnest, exp.Unnest):
            continue

        alias = unnest.args.get("alias")
        udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

        expression.args["joins"].remove(join)

        # zip stops at the shorter sequence, so an UNNEST without an alias yields no laterals.
        for e, column in zip(unnest.expressions, alias.columns if alias else []):
            expression.append(
                "laterals",
                exp.Lateral(
                    this=udtf(this=e),
                    view=True,
                    alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                ),
            )

    return expression
166
167
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: the first value of the generated position series. It also determines how the
            array sizes are adjusted before they are used as bounds for the position columns.

    Returns:
        A transform that, for a SELECT, rewrites each projection containing [POS]EXPLODE into a
        conditional column backed by a CROSS JOIN UNNEST source. Other expressions are returned
        as they are.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            # NOTE(review): imported locally, presumably to avoid a circular import -- confirm.
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so that generated names don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name that is not in `names` yet and reserve it right away.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ArraySize expression is collected per exploded array.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # A single position series is shared by all explodes; it is only attached to the query
            # when the first explode is found and its "end" is only set after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases were given: the first names the position, the second the value.
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # alias_ copies the projection here, so the explode node needs to be looked
                        # up again in the new tree.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position name is needed even for a plain EXPLODE, because it is used as the
                    # UNNEST offset and in the generated filter below.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only produced when the series position matches the
                    # position within this array.
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also outputs the position, which is inserted right after the
                        # value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is added once, when the first explode is processed.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `arrays` keeps the unadjusted size; only the filter below uses the adjusted one.
                    if index_offset != 1:
                        size = size - 1

                    # Keep the rows where the positions match, plus the rows where the series has
                    # gone past this array's bound while the array sits at that bound.
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series has to cover the largest of the exploded arrays.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
297
298
# Percentile expression types that add_within_group_for_percentiles and
# remove_within_group_for_percentiles check for.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
300
301
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile in a WITHIN GROUP clause.

    The percentile's first argument becomes the ORDER BY operand of the WITHIN GROUP clause and
    its second argument becomes the percentile's only argument. Percentiles that already sit
    inside a WITHIN GROUP, or that have no second argument, are returned unchanged.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_operand = expression.this.pop()
    fraction = expression.expression.pop()
    expression.set("this", fraction)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_operand)])
    return exp.WithinGroup(this=expression, expression=ordering)
315
316
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile's WITHIN GROUP clause with an ApproxQuantile expression.

    The operand of the first Ordered node found in the clause becomes the quantile's input and
    the percentile's argument becomes its quantile. Other expressions are returned unchanged.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
329
330
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give the CTEs of a recursive WITH clause explicit column names.

    CTEs whose alias already lists columns are left alone. For the others, the names are taken
    from the projections of the CTE's query (of its left-hand side when the query is a union);
    projections without an output name get a generated "_c_"-prefixed name.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression
348
349
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand 'epoch' (compared case-insensitively) of a cast to a temporal type with
    the string literal '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
360
361
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join that has an ON condition is removed from the SELECT, and an EXISTS
    predicate (negated for ANTI joins) over the joined source, filtered by that condition, is
    added to the WHERE clause.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        predicate = exp.Exists(this=exp.select("1").from_(join.this).where(condition))
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
377
378
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query that has a FULL OUTER join as the union of two copies of itself, one using a
    LEFT join and the other a RIGHT join in its place.

    Only queries with exactly one FULL OUTER join are transformed; any other expression is
    returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    matches = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL" and join.kind == "OUTER"
    ]
    if len(matches) != 1:
        return expression

    position, left_join = matches[0]

    # The copy is taken before the original is touched, so both still hold the FULL join here.
    right_query = expression.copy()
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")

    return exp.union(expression, right_query, copy=False)
401
402
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at
    the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation moves all CTEs to the top level so that
    the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    outer_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        # The WITH clause that already belongs to the top-level expression stays in place.
        if with_node.parent is expression:
            continue

        nested_with = with_node.pop()

        if outer_with:
            # A single recursive nested clause makes the merged clause recursive.
            if nested_with.recursive:
                outer_with.set("recursive", True)

            outer_with.expressions.extend(nested_with.expressions)
        else:
            # No top-level clause yet: the first nested one found takes that role.
            outer_with = nested_with
            expression.set("with", outer_with)

    return expression
431
432
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression as it is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression can't be
            converted to SQL, i.e. when neither a "_sql" method nor a usable `TRANSFORMS` entry
            exists for it.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Each transform receives the output of the previous one.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
13def unalias_group(expression: exp.Expression) -> exp.Expression:
14    """
15    Replace references to select aliases in GROUP BY clauses.
16
17    Example:
18        >>> import sqlglot
19        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
20        'SELECT a AS b FROM x GROUP BY 1'
21
22    Args:
23        expression: the expression that will be transformed.
24
25    Returns:
26        The transformed expression.
27    """
28    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
29        aliased_selects = {
30            e.alias: i
31            for i, e in enumerate(expression.parent.expressions, start=1)
32            if isinstance(e, exp.Alias)
33        }
34
35        for group_by in expression.expressions:
36            if (
37                isinstance(group_by, exp.Column)
38                and not group_by.table
39                and group_by.name in aliased_selects
40            ):
41                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
42
43    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
46def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
47    """
48    Convert SELECT DISTINCT ON statements to a subquery with a window function.
49
50    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
51
52    Args:
53        expression: the expression that will be transformed.
54
55    Returns:
56        The transformed expression.
57    """
58    if (
59        isinstance(expression, exp.Select)
60        and expression.args.get("distinct")
61        and expression.args["distinct"].args.get("on")
62        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
63    ):
64        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
65        outer_selects = expression.selects
66        row_number = find_new_name(expression.named_selects, "_row_number")
67        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
68        order = expression.args.get("order")
69
70        if order:
71            window.set("order", order.pop())
72        else:
73            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
74
75        window = exp.alias_(window, row_number)
76        expression.select(window, copy=False)
77
78        return (
79            exp.select(*outer_selects, copy=False)
80            .from_(expression.subquery("_t", copy=False), copy=False)
81            .where(exp.column(row_number).eq(1), copy=False)
82        )
83
84    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 98    """
 99    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
100        taken = set(expression.named_selects)
101        for select in expression.selects:
102            if not select.alias_or_name:
103                alias = find_new_name(taken, "_c")
104                select.replace(exp.alias_(select, alias))
105                taken.add(alias)
106
107        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
108        qualify_filters = expression.args["qualify"].pop().this
109
110        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
111        for expr in qualify_filters.find_all(select_candidates):
112            if isinstance(expr, exp.Window):
113                alias = find_new_name(expression.named_selects, "_w")
114                expression.select(exp.alias_(expr, alias), copy=False)
115                column = exp.column(alias)
116
117                if isinstance(expr.parent, exp.Qualify):
118                    qualify_filters = column
119                else:
120                    expr.replace(column)
121            elif expr.name not in expression.named_selects:
122                expression.select(expr.copy(), copy=False)
123
124        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
125            qualify_filters, copy=False
126        )
127
128    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
131def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
132    """
133    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
134    other expressions. This transforms removes the precision from parameterized types in expressions.
135    """
136    for node in expression.find_all(exp.DataType):
137        node.set(
138            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
139        )
140
141    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
144def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
145    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
146    if isinstance(expression, exp.Select):
147        for join in expression.args.get("joins") or []:
148            unnest = join.this
149
150            if isinstance(unnest, exp.Unnest):
151                alias = unnest.args.get("alias")
152                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
153
154                expression.args["joins"].remove(join)
155
156                for e, column in zip(unnest.expressions, alias.columns if alias else []):
157                    expression.append(
158                        "laterals",
159                        exp.Lateral(
160                            this=udtf(this=e),
161                            view=True,
162                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
163                        ),
164                    )
165
166    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
169def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
170    """Convert explode/posexplode into unnest (used in hive -> presto)."""
171
172    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
173        if isinstance(expression, exp.Select):
174            from sqlglot.optimizer.scope import Scope
175
176            taken_select_names = set(expression.named_selects)
177            taken_source_names = {name for name, _ in Scope(expression).references}
178
179            def new_name(names: t.Set[str], name: str) -> str:
180                name = find_new_name(names, name)
181                names.add(name)
182                return name
183
184            arrays: t.List[exp.Condition] = []
185            series_alias = new_name(taken_select_names, "pos")
186            series = exp.alias_(
187                exp.Unnest(
188                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
189                ),
190                new_name(taken_source_names, "_u"),
191                table=[series_alias],
192            )
193
194            # we use list here because expression.selects is mutated inside the loop
195            for select in list(expression.selects):
196                explode = select.find(exp.Explode)
197
198                if explode:
199                    pos_alias = ""
200                    explode_alias = ""
201
202                    if isinstance(select, exp.Alias):
203                        explode_alias = select.alias
204                        alias = select
205                    elif isinstance(select, exp.Aliases):
206                        pos_alias = select.aliases[0].name
207                        explode_alias = select.aliases[1].name
208                        alias = select.replace(exp.alias_(select.this, "", copy=False))
209                    else:
210                        alias = select.replace(exp.alias_(select, ""))
211                        explode = alias.find(exp.Explode)
212                        assert explode
213
214                    is_posexplode = isinstance(explode, exp.Posexplode)
215                    explode_arg = explode.this
216
217                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
218                    if isinstance(explode_arg, exp.Column):
219                        taken_select_names.add(explode_arg.output_name)
220
221                    unnest_source_alias = new_name(taken_source_names, "_u")
222
223                    if not explode_alias:
224                        explode_alias = new_name(taken_select_names, "col")
225
226                        if is_posexplode:
227                            pos_alias = new_name(taken_select_names, "pos")
228
229                    if not pos_alias:
230                        pos_alias = new_name(taken_select_names, "pos")
231
232                    alias.set("alias", exp.to_identifier(explode_alias))
233
234                    column = exp.If(
235                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
236                        true=exp.column(explode_alias),
237                    )
238
239                    explode.replace(column)
240
241                    if is_posexplode:
242                        expressions = expression.expressions
243                        expressions.insert(
244                            expressions.index(alias) + 1,
245                            exp.If(
246                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
247                                true=exp.column(pos_alias),
248                            ).as_(pos_alias),
249                        )
250                        expression.set("expressions", expressions)
251
252                    if not arrays:
253                        if expression.args.get("from"):
254                            expression.join(series, copy=False)
255                        else:
256                            expression.from_(series, copy=False)
257
258                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
259                    arrays.append(size)
260
261                    # trino doesn't support left join unnest with on conditions
262                    # if it did, this would be much simpler
263                    expression.join(
264                        exp.alias_(
265                            exp.Unnest(
266                                expressions=[explode_arg.copy()],
267                                offset=exp.to_identifier(pos_alias),
268                            ),
269                            unnest_source_alias,
270                            table=[explode_alias],
271                        ),
272                        join_type="CROSS",
273                        copy=False,
274                    )
275
276                    if index_offset != 1:
277                        size = size - 1
278
279                    expression.where(
280                        exp.column(series_alias)
281                        .eq(exp.column(pos_alias))
282                        .or_(
283                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
284                        ),
285                        copy=False,
286                    )
287
288            if arrays:
289                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])
290
291                if index_offset != 1:
292                    end = end - (1 - index_offset)
293                series.expressions[0].set("end", end)
294
295        return expression
296
297    return _explode_to_unnest

Convert explode/posexplode into unnest (used in hive -> presto).

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP clause.

    The node's first argument (`this`) becomes the single ORDER BY key of the new
    WITHIN GROUP, and its second argument (`expression`) is promoted to be the node's
    `this`. Nodes that are not percentiles, that already sit directly under a
    WITHIN GROUP, or that have no second argument are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WITHIN GROUP node wrapping the percentile, or the original expression.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # Detach both arguments before re-attaching them in their new positions.
    ordered_value = expression.this.pop()
    fraction = expression.expression.pop()
    expression.set("this", fraction)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an approximate quantile call.

    The replacement is an `ApproxQuantile` node whose `this` is the first ordered value found
    under the WITHIN GROUP and whose `quantile` is the percentile's own argument. Any other
    expression is returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` node that replaced the WITHIN GROUP, or the original expression.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=expression.this.this)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list.

    CTEs whose alias already declares columns are skipped. For the others, the column names are
    taken from the projections of the CTE's query (for a union, from its left-hand side, i.e.
    `this`); projections without an output name receive a generated `_c_`-prefixed name.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a recursive WITH.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        source = cte.this.this if isinstance(cte.this, exp.Union) else cte.this

        column_names = []
        for projection in source.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the value 'epoch' in temporal casts by the equivalent date literal.

    Applies to CAST and TRY_CAST nodes whose operand is named "epoch" (case-insensitively)
    and whose target type is one of `exp.DataType.TEMPORAL_TYPES`; the operand is swapped
    for the string literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with its operand replaced when the conditions above hold.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI/ANTI join of a SELECT that has an ON condition is removed from the query and
    replaced by a `WHERE [NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)`
    predicate. Joins of other kinds, and SEMI/ANTI joins without an ON condition, are left
    untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a SELECT containing such joins.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the SELECT inside the loop, and
        # walking the live list while it shrinks could skip the join following a removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT containing exactly one FULL OUTER join as a union of two queries.

    The original query has the join's side switched to "left"; a copy of the query (taken
    before that change) has the same join switched to "right". The two are then combined
    with `exp.union`. Queries with zero or several FULL OUTER joins, and non-SELECT
    expressions, are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the left- and right-join variants, or the original expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []
    full_outer_positions = [
        position
        for position, join in enumerate(joins)
        if join.side == "FULL" and join.kind == "OUTER"
    ]

    if len(full_outer_positions) != 1:
        return expression

    (position,) = full_outer_positions

    # The copy has to be taken while the join is still FULL, before either side is rewritten.
    right_variant = expression.copy()
    joins[position].set("side", "left")
    right_variant.args["joins"][position].set("side", "right")

    return exp.union(expression, right_variant, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs hoisted out of the body of another CTE are placed *before* that enclosing CTE, so
    that they are defined ahead of the query that references them; all other hoisted CTEs
    are appended at the end. If any hoisted WITH is recursive, the top-level WITH is marked
    recursive as well.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place so that its only WITH is the top-level one.
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        if node.parent is expression:
            continue

        # Look up the enclosing CTE (if any) before popping: once the node has been removed
        # from the tree its ancestors can no longer be reached from it.
        enclosing_cte = node.find_ancestor(exp.CTE)
        inner_with = node.pop()

        if not top_level_with:
            top_level_with = inner_with
            expression.set("with", top_level_with)
            continue

        if inner_with.recursive:
            top_level_with.set("recursive", True)

        hoisted = top_level_with.expressions

        # Identity (not structural equality) is used so that a look-alike CTE is never matched.
        # When there is no enclosing CTE at the top level, fall back to appending.
        insert_at = next(
            (index for index, cte in enumerate(hoisted) if cte is enclosing_cte),
            len(hoisted),
        )

        # Assign through `set` instead of mutating the list in place, so that the moved CTEs
        # are attached to the top-level WITH rather than left pointing at the removed one.
        top_level_with.set(
            "expressions",
            hoisted[:insert_at] + inner_with.expressions + hoisted[insert_at:],
        )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    The `Generator.TRANSFORMS` function is only used when the generator has no matching "_sql"
    method. If, in that situation, the transformed expression still has the type of the original
    one, calling the `TRANSFORMS` entry would re-enter this very function; function expressions
    are then rendered with `Generator.function_fallback_sql` and anything else is rejected.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if the resulting expression cannot be
            converted to SQL by any of the mechanisms described above.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so that a transform chain that did not change it can be
        # detected further down.
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.