Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 97    """
 98    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 99        taken = set(expression.named_selects)
100        for select in expression.selects:
101            if not select.alias_or_name:
102                alias = find_new_name(taken, "_c")
103                select.replace(exp.alias_(select, alias))
104                taken.add(alias)
105
106        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
107        qualify_filters = expression.args["qualify"].pop().this
108
109        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
110        for expr in qualify_filters.find_all(select_candidates):
111            if isinstance(expr, exp.Window):
112                alias = find_new_name(expression.named_selects, "_w")
113                expression.select(exp.alias_(expr, alias), copy=False)
114                column = exp.column(alias)
115
116                if isinstance(expr.parent, exp.Qualify):
117                    qualify_filters = column
118                else:
119                    expr.replace(column)
120            elif expr.name not in expression.named_selects:
121                expression.select(expr.copy(), copy=False)
122
123        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
124            qualify_filters, copy=False
125        )
126
127    return expression
128
129
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every DataType node found in `expression` (the node itself and all of its descendants) keeps
    only those of its arguments that are not DataTypeParam nodes. The tree is modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept_args)

    return expression
141
142
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Each join whose source is an UNNEST is removed from the SELECT. Every unnested expression
    is re-added as a LATERAL VIEW: POSEXPLODE when the UNNEST has an offset, EXPLODE
    otherwise. Each lateral view reuses the UNNEST's table alias and takes the column alias
    at the same position in the UNNEST's alias column list.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live `joins` list below,
        # and removing items from the list being iterated would skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): if the UNNEST has no alias (or fewer column aliases than
                # expressions), the unmatched expressions produce no lateral view and are thus
                # dropped from the query — confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
166
167
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Returns a transform that rewrites every SELECT projection containing [POS]EXPLODE(arr)
    into a CROSS JOIN of UNNEST(arr) with an offset column. All such unnests are aligned on
    one shared position series, UNNEST(GENERATE_SERIES(index_offset, <end>)), whose end is
    derived from the largest array size. Each projection picks its value with
    IF(series_pos = unnest_pos, value). Together with the WHERE filter added per array,
    arrays shorter than the longest one yield NULL past their end.

    Args:
        index_offset: start value of the generated position series; it also decides how the
            array sizes are turned into last positions. Presumably the array index base of the
            target dialect (0 or 1) — TODO confirm; other values are not specially handled.

    Returns:
        A transform that takes and returns an expression. Non-SELECT expressions are
        returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Picks a fresh name based on `name` and reserves it in `names`.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded argument; used to compute the series' end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The shared position series; its "end" is only set after the loop, once all
            # array sizes are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Work out the projection node that will carry the value's alias:
                    # - `... AS x`: reuse it, x names the value column
                    # - `... AS (p, x)`: p names the position, x the value; re-wrap in a single Alias
                    # - unaliased: wrap it in an Alias whose name is filled in below
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # the Alias wraps a copy, so look the EXPLODE up again in the new node
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position alias is needed even for plain EXPLODE: it names the
                    # UNNEST's offset column used in the alignment conditions.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # IF(series_pos = unnest_pos, unnest_value): NULL when positions differ.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also exposes the position, as an extra projection placed
                    # right after the value projection.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is added once, on the first exploded projection.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Turn the array size into the array's last position (size - 1 unless
                    # positions start at 1).
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches this array's position, or
                    # where the series has run past this array's last position, paired only
                    # with its last element (the IF above then yields NULL).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series spans up to the largest array: GREATEST(sizes), shifted by
            # (1 - index_offset) unless positions start at 1.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
304
305
# Percentile aggregate types handled by the WITHIN GROUP add/remove transforms.
PERCENTILES = (
    exp.PercentileCont,
    exp.PercentileDisc,
)
307
308
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's first argument (`this`) is moved into a WITHIN GROUP (ORDER BY ...)
    wrapper, and its second argument (`expression`) becomes its only argument. Percentiles that
    are already inside a WITHIN GROUP, or that have no second argument, are returned unchanged.
    """
    if not isinstance(expression, PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
322
323
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A `PERCENTILE_*(q) WITHIN GROUP (ORDER BY x)` node is replaced by an ApproxQuantile whose
    input is the expression of the first Ordered node found (`x`) and whose quantile is `q`.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, PERCENTILES) or not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=percentile.this))
336
337
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    For every CTE of a WITH RECURSIVE clause that has no explicit column list, the list is built
    from the projections of the CTE's query (its left-hand side when the query is a UNION).
    Projections without an alias or name get a generated name with the `_c_` prefix.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
355
356
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only a string literal 'epoch' (case-insensitive) cast to a temporal type is rewritten.
    For example, CAST('epoch' AS TIMESTAMP) becomes CAST('1970-01-01 00:00:00' AS TIMESTAMP).
    A cast of a column that merely happens to be named `epoch` is left untouched.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        # `name` is also set for identifiers, so require an actual string literal.
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
367
368
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `... SEMI JOIN y ON <cond>` is removed from the query, and the filter
    `EXISTS(SELECT 1 FROM y WHERE <cond>)` is added to its WHERE clause. An ANTI join gets the
    negated filter instead. SEMI/ANTI joins without an ON condition are left as they are.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot, since matching joins are popped from the query in the loop.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
384
385
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The first UNION operand is the original query with the FULL join turned into a LEFT join.
    The second operand is a copy with it turned into a RIGHT join and without the WITH clause.
    The trailing ORDER BY / LIMIT / OFFSET clauses are kept only on the second operand, so they
    end up after the whole UNION.

    NOTE(review): the operands are combined with `exp.union`'s default set operation (presumably
    a distinct UNION), which also collapses rows that were genuinely duplicated in the FULL OUTER
    JOIN result — confirm this is acceptable for callers.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            # These clauses may only appear once, at the very end of the UNION. If left on the
            # first operand, they'd produce invalid SQL like `... ORDER BY x UNION SELECT ...`.
            for arg in ("order", "limit", "offset"):
                expression.set(arg, None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression
410
411
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression; nested WITH clauses are popped and merged into its
            own WITH clause, which is created from the first nested one if it has none.

    Returns:
        The same (mutated) `expression`.
    """
    top_level_with = expression.args.get("with")
    # NOTE(review): nodes are popped while `find_all` is still traversing the tree; the
    # traversal presumably continues into each popped WITH, so WITHs nested inside the moved
    # CTEs get hoisted too (after the CTEs that contain them) — confirm ordering is acceptable.
    for node in expression.find_all(exp.With):
        # The root's own WITH clause is already at the top level.
        if node.parent is expression:
            continue

        inner_with = node.pop()
        if not top_level_with:
            # No top-level WITH yet: the first nested one is promoted as-is.
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            # Merged CTEs may be recursive, so the combined clause must be too.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # NOTE(review): a plain list extend presumably leaves each moved CTE's `parent`
            # pointing at the detached `inner_with` — verify whether that matters downstream.
            top_level_with.expressions.extend(inner_with.expressions)

    return expression
440
441
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    The optimizer's `canonicalize.ensure_bools` helper is run on every node of the tree with a
    callback. The callback rewrites a node `x` into `x <> 0` when `x` is a number, is typed as
    UNKNOWN or a numeric type, or is a column without a known type. (Which nodes the helper hands
    to the callback is decided by the helper itself — presumably operands in boolean positions.)
    """
    # Aliased so it doesn't shadow this function's own name inside its body.
    from sqlglot.optimizer.canonicalize import ensure_bools as canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if node.is_number or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES):
            node.replace(node.neq(0))
        elif isinstance(node, exp.Column) and not node.type:
            node.replace(node.neq(0))

    for current, *_ in expression.walk():
        canonicalize_ensure_bools(current, _ensure_bool)

    return expression
458
459
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strips the qualifiers (table, db and catalog parts) from every column in `expression`.

    Every part of a column except the last one (the column name itself) is popped, so e.g.
    `db.tbl.col` becomes `col`. The tree is modified in place and returned.
    """
    columns = expression.find_all(exp.Column)
    for column in columns:
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
467
468
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is rendered without transforming it.

    Returns:
        Function that can be used as a generator transform.

    Raises (from the returned function):
        ValueError: if the final expression has neither a `<key>_sql` generator method nor an
            applicable `TRANSFORMS` entry.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        # A dedicated `<key>_sql` generator method takes precedence.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Unqualified GROUP BY columns whose name matches an alias in the parent SELECT's
    projection list are replaced by that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # alias name -> 1-based position of the aliased projection (a later duplicate wins)
    position_of: t.Dict[str, int] = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_of[projection.alias] = position

    for grouped in expression.expressions:
        # Qualified columns (e.g. `x.b`) refer to a source, never to a projection alias.
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue
        if grouped.name in position_of:
            grouped.replace(exp.Literal.number(position_of[grouped.name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The inner query keeps the original projections, each one given an alias so that it can be
    referenced by name. Next to them it projects a ROW_NUMBER() window partitioned by the
    DISTINCT ON columns, ordered by the query's ORDER BY if present, else by the DISTINCT ON
    columns. The outer query selects the projections by name and keeps only rows whose row
    number is 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        # The outer query must reference the projections by name: re-using the projection
        # expressions themselves would be wrong, since they may reference columns that the
        # subquery doesn't expose (e.g. `a + 1 AS x` would make the outer query read `a`).
        outer_selects: t.List[exp.Expression] = []
        taken_names = {row_number}
        for select in list(expression.selects):
            if select.is_star:
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            # Copy the identifier so no node ends up shared between the two queries.
            outer_selects.append(exp.column(select.args["alias"].copy()))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 98    """
 99    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
100        taken = set(expression.named_selects)
101        for select in expression.selects:
102            if not select.alias_or_name:
103                alias = find_new_name(taken, "_c")
104                select.replace(exp.alias_(select, alias))
105                taken.add(alias)
106
107        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
108        qualify_filters = expression.args["qualify"].pop().this
109
110        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
111        for expr in qualify_filters.find_all(select_candidates):
112            if isinstance(expr, exp.Window):
113                alias = find_new_name(expression.named_selects, "_w")
114                expression.select(exp.alias_(expr, alias), copy=False)
115                column = exp.column(alias)
116
117                if isinstance(expr.parent, exp.Qualify):
118                    qualify_filters = column
119                else:
120                    expr.replace(column)
121            elif expr.name not in expression.named_selects:
122                expression.select(expr.copy(), copy=False)
123
124        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
125            qualify_filters, copy=False
126        )
127
128    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every DataType node found in `expression` (the node itself and all of its descendants) keeps
    only those of its arguments that are not DataTypeParam nodes. The tree is modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Each join whose source is an UNNEST is removed from the SELECT. Every unnested expression
    is re-added as a LATERAL VIEW: POSEXPLODE when the UNNEST has an offset, EXPLODE
    otherwise. Each lateral view reuses the UNNEST's table alias and takes the column alias
    at the same position in the UNNEST's alias column list.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live `joins` list below,
        # and removing items from the list being iterated would skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): if the UNNEST has no alias (or fewer column aliases than
                # expressions), the unmatched expressions produce no lateral view and are thus
                # dropped from the query — confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Builds a transform that rewrites each SELECT projection containing `EXPLODE(...)` or
    `POSEXPLODE(...)` into joins against UNNEST sources:

    - a single shared "series" source, `UNNEST(GENERATE_SERIES(...))` starting at
      `index_offset`. It is cross joined, or used as the FROM when the query has none, and its
      end is derived from the largest exploded array size.
    - one `UNNEST(<array>)` source with an offset (position) column per exploded array, cross
      joined.

    The exploded call is replaced by `IF(<series pos> = <array pos>, <array value>)`. A WHERE
    condition keeps only rows whose positions line up. Shorter arrays are padded by pairing the
    extra series positions with their last element.

    Args:
        index_offset: first value of the generated series. The size/end arithmetic below only
            distinguishes `index_offset == 1` from every other value.

    Returns:
        A function that takes an expression and returns the transformed expression.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources; new aliases must not collide.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Pick an unused name derived from `name` and reserve it in `names`.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ArraySize(...) of every exploded array, used to compute the series end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The shared position series; its "end" argument is filled in after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) whose name will hold
                    # the exploded value, capturing any user-provided aliases.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Multiple aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ is called without copy=False here, so re-find the Explode
                        # inside the node that actually ended up in the tree.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position alias is always needed: it names the UNNEST offset column.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # Value of this array at the current series position (IF has no ELSE branch).
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, placed right after the value column.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # First exploded array: attach the shared series source to the query.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    # `arrays` keeps the unadjusted size; `size` is adjusted further below.
                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Turn `size` into this array's last position: `size - 1` unless the index
                    # base is 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches this array's position.
                    # Also keep rows where the series has run past this array's last position,
                    # paired with the last element; the IF projections above then yield NULL.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # Series end: the largest array size, adjusted for the index base.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Applies only to `PERCENTILES` nodes that have a second argument (`expression`) and are not
    already wrapped in a `WithinGroup`. The node's `this` (the ordered value) moves into an
    ORDER BY, and its `expression` argument becomes its new `this`. The result is then wrapped
    as `<percentile> WITHIN GROUP (ORDER BY <value>)`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new `WithinGroup` node, or the input unchanged when the rewrite does not apply.
    """
    applies = (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not applies:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_value)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `<percentile>(q) WITHIN GROUP (ORDER BY x)` is replaced in the tree by an `ApproxQuantile`
    node with `this=x` and `quantile=q`. The value is taken from the first `Ordered` node found
    under the WITHIN GROUP.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement `ApproxQuantile` node, or the input unchanged when it does not match.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression
    if not isinstance(expression.this, PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=first_ordered.this, quantile=expression.this.this)
    return expression.replace(approx)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    For a `WITH RECURSIVE`, every CTE whose alias has no explicit column list gets one built
    from its query's projections: each projection's alias or name, or a generated `_c_<n>`
    name when it has neither. For a UNION body, the left-hand query's projections are used.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body
        column_names = [
            exp.to_identifier(projection.alias_or_name or fallback_name())
            for projection in anchor.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    In `CAST('epoch' AS <temporal type>)` (or TRY_CAST), the string operand is replaced by the
    literal '1970-01-01 00:00:00'. Only string literals are rewritten. A column or identifier
    that happens to be named `epoch` (e.g. `CAST(epoch AS TIMESTAMP)`) is left untouched,
    because `Cast.name` alone cannot tell the two apart.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (possibly mutated) expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `a SEMI JOIN b ON c` becomes a `WHERE EXISTS(SELECT 1 FROM b WHERE c)` condition, and
    `a ANTI JOIN b ON c` becomes `WHERE NOT EXISTS(...)`. The join itself is removed.
    SEMI/ANTI joins without an ON condition are left unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    if isinstance(expression, exp.Select):
        # Snapshot the joins: matching joins are popped from the SELECT inside the loop, so
        # iterating the live list would depend on how `pop` updates it.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            kind = join.kind
            if not on or kind not in ("SEMI", "ANTI"):
                continue

            subquery = exp.select("1").from_(join.this).where(on)
            exists = exp.Exists(this=subquery)
            if kind == "ANTI":
                exists = exists.not_(copy=False)

            join.pop()
            expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join; any other query is returned unchanged.

    The left-hand query keeps the CTEs but loses its LIMIT. The right-hand copy keeps the
    LIMIT but loses the CTEs. The two are combined with `exp.union` using its default settings.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The resulting union, or the input when the rewrite does not apply.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    position, full_join = full_joins[0]
    right_query = expression.copy()

    expression.set("limit", None)
    full_join.set("side", "left")

    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    When a nested WITH sits inside another CTE, its CTEs are inserted right before that
    enclosing CTE. Every CTE therefore stays defined before the CTE that references it.
    Nested WITHs outside any CTE are appended after the existing top-level CTEs.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        if node.parent is expression:
            continue

        if not top_level_with:
            top_level_with = node.pop()
            expression.set("with", top_level_with)
            continue

        if node.recursive:
            top_level_with.set("recursive", True)

        # Find the enclosing CTE (if any) before detaching this WITH from the tree.
        parent_cte = node.find_ancestor(exp.CTE)
        inner_with = node.pop()

        ctes = list(top_level_with.expressions)
        # Match by identity: structurally equal CTEs must not be confused with each other.
        insert_at = next(
            (i for i, cte in enumerate(ctes) if cte is parent_cte),
            len(ctes),
        )
        ctes[insert_at:insert_at] = inner_with.expressions

        # `set` (unlike an in-place list mutation) re-parents the moved CTEs.
        top_level_with.set("expressions", ctes)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is passed to `ensure_bools` from `sqlglot.optimizer.canonicalize`
    together with a callback. NOTE(review): which operands that helper treats as conditions is
    defined there. Any operand handed to the callback that is a number, is typed as UNKNOWN or
    a numeric type, or is an untyped column, is replaced by `<operand> <> 0`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonical_ensure_bools

    def _coerce_to_bool(node: exp.Expression) -> None:
        looks_numeric = (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        )
        if looks_numeric:
            node.replace(node.neq(0))

    for node, *_ in expression.walk():
        _canonical_ensure_bools(node, _coerce_to_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip table, db and catalog qualifiers from every column in an expression tree.

    Each `Column` keeps only the last of its `parts` (its own name). Every preceding part is
    popped off in place, so e.g. `db.t.c` becomes `c`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    for col in expression.find_all(exp.Column):
        qualifiers = col.parts[:-1]  # everything except the column's own name
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL.

    After the transforms run, the SQL is produced by the first of these that applies:

    1. the generator's `<key>_sql` method for the resulting expression;
    2. otherwise, if the resulting type has an entry in `Generator.TRANSFORMS`:
       - when the transforms changed the expression's type, that entry is used;
       - when the type is unchanged, calling that entry would re-enter this function. A
         function expression then falls back to `function_fallback_sql`, and anything else
         raises `ValueError`.
    3. otherwise a `ValueError` is raised for the unsupported expression type.

    Args:
        transforms: sequence of transform functions. These will be called in order; an empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Looping over every transform (rather than indexing transforms[0]) also makes an
        # empty transform list valid instead of raising IndexError.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

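Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL. It uses the "_sql" method corresponding to the resulting expression when one exists. Otherwise it uses the appropriate Generator.TRANSFORMS function for the resulting expression's type, but only if the transformations changed that type. If the type is unchanged, function expressions fall back to generic function SQL and anything else raises an error.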

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.