Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop().copy())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects)
 79            .from_(expression.subquery("_t"))
 80            .where(exp.column(row_number).eq(1))
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 97    """
 98    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 99        taken = set(expression.named_selects)
100        for select in expression.selects:
101            if not select.alias_or_name:
102                alias = find_new_name(taken, "_c")
103                select.replace(exp.alias_(select, alias))
104                taken.add(alias)
105
106        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
107        qualify_filters = expression.args["qualify"].pop().this
108
109        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
110        for expr in qualify_filters.find_all(select_candidates):
111            if isinstance(expr, exp.Window):
112                alias = find_new_name(expression.named_selects, "_w")
113                expression.select(exp.alias_(expr, alias), copy=False)
114                column = exp.column(alias)
115
116                if isinstance(expr.parent, exp.Qualify):
117                    qualify_filters = column
118                else:
119                    expr.replace(column)
120            elif expr.name not in expression.named_selects:
121                expression.select(expr.copy(), copy=False)
122
123        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
124
125    return expression
126
127
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the `DataTypeParam` arguments from every `DataType` node found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression
139
140
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the join list and replaced
    by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used when
    the UNNEST has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it skips the element after each removal.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # NOTE: an UNNEST without a table alias yields no pairs here, so no lateral
                # view is produced for it.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
164
165
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that rewrites EXPLODE / POSEXPLODE projections as UNNEST joins.

    Args:
        index_offset: start value of the generated position series. When it is not 1, the
            array sizes used in the filter are reduced by one and the end of the series is
            shifted by `1 - index_offset`.

    Returns:
        The transform function, which takes an expression and returns it transformed.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """Rewrite explode/posexplode projections of a SELECT as unnest (used in hive -> presto)."""
        if not isinstance(expression, exp.Select):
            return expression

        from sqlglot.optimizer.scope import Scope

        select_names = set(expression.named_selects)
        source_names = {name for name, _ in Scope(expression).references}

        def reserve(names: t.Set[str], base: str) -> str:
            # Pick a name based on `base` that is not in `names`, and record it as taken.
            fresh = find_new_name(names, base)
            names.add(fresh)
            return fresh

        array_sizes: t.List[exp.Condition] = []
        series_alias = reserve(select_names, "pos")
        series = exp.alias_(
            exp.Unnest(
                expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
            ),
            reserve(source_names, "_u"),
            table=[series_alias],
        )

        # Snapshot the projections because the select list is mutated inside the loop.
        for projection in list(expression.selects):
            target = projection
            pos_alias = ""
            explode_alias = ""

            if isinstance(projection, exp.Alias):
                explode_alias = projection.alias
                projection = projection.this
            elif isinstance(projection, exp.Aliases):
                pos_alias = projection.aliases[0].name
                explode_alias = projection.aliases[1].name
                projection = projection.this

            if not isinstance(projection, (exp.Explode, exp.Posexplode)):
                continue

            is_posexplode = isinstance(projection, exp.Posexplode)
            array = projection.this

            # This ensures that we won't use [POS]EXPLODE's argument as a new selection
            if isinstance(array, exp.Column):
                select_names.add(array.output_name)

            unnest_alias = reserve(source_names, "_u")

            if not explode_alias:
                explode_alias = reserve(select_names, "col")

                if is_posexplode:
                    pos_alias = reserve(select_names, "pos")

            if not pos_alias:
                pos_alias = reserve(select_names, "pos")

            value_projection = exp.If(
                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                true=exp.column(explode_alias),
            ).as_(explode_alias)

            if is_posexplode:
                # POSEXPLODE yields two outputs: swap the single original projection for the
                # value projection followed by the position projection.
                projections = expression.expressions
                at = projections.index(target)
                pos_projection = exp.If(
                    this=exp.column(series_alias).eq(exp.column(pos_alias)),
                    true=exp.column(pos_alias),
                ).as_(pos_alias)
                projections[at : at + 1] = [value_projection, pos_projection]
                expression.set("expressions", projections)
            else:
                target.replace(value_projection)

            # The position series is attached once, when the first explode is found.
            if not array_sizes:
                if expression.args.get("from"):
                    expression.join(series, copy=False)
                else:
                    expression.from_(series, copy=False)

            size: exp.Condition = exp.ArraySize(this=array.copy())
            array_sizes.append(size)

            # trino doesn't support left join unnest with on conditions
            # if it did, this would be much simpler
            expression.join(
                exp.alias_(
                    exp.Unnest(
                        expressions=[array.copy()],
                        offset=exp.to_identifier(pos_alias),
                    ),
                    unnest_alias,
                    table=[explode_alias],
                ),
                join_type="CROSS",
                copy=False,
            )

            if index_offset != 1:
                size = size - 1

            expression.where(
                exp.column(series_alias)
                .eq(exp.column(pos_alias))
                .or_((exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))),
                copy=False,
            )

        if array_sizes:
            # The series runs up to the largest of the exploded arrays' sizes.
            end: exp.Condition = exp.Greatest(this=array_sizes[0], expressions=array_sizes[1:])

            if index_offset != 1:
                end = end - (1 - index_offset)
            series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
290
291
# Percentile expression types recognised by the WITHIN GROUP transforms in this module.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
293
294
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP (ORDER BY ...) expression.

    The call's first argument becomes the ORDER BY key and its second argument becomes the
    call's only argument. Percentiles already inside a WITHIN GROUP, or without a second
    argument, are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
307
308
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `percentile(q) WITHIN GROUP (ORDER BY x)` with an ApproxQuantile of `x` and `q`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement expression, or the input unchanged when it does not match.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    percentile = expression.this
    quantile = percentile.this
    # The value being aggregated is the operand of the first ordered term found.
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(approx)
320
321
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list when it has none.

    Column names are taken from the projections of the CTE's query (the left-hand query when
    the CTE is a union); projections without a name receive generated `_c_`-prefixed names.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        columns = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", columns)

    return expression
338
339
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of a cast of 'epoch' to a temporal type with '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with its operand replaced when the cast matches.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
349
350
def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    """
    Turn a Timestamp expression that has no second argument into a cast to TIMESTAMP.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The cast expression, or the input unchanged when it does not match.
    """
    if not isinstance(expression, exp.Timestamp) or expression.expression:
        return expression

    return exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
358
359
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SEMI / ANTI joins that have an ON condition as [NOT] EXISTS filters.

    Each matching join is removed and a `SELECT 1 FROM <join target> WHERE <on>` subquery is
    added to the query's WHERE clause, wrapped in EXISTS (negated for ANTI joins).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        exists = exp.Exists(this=probe)
        if join.kind == "ANTI":
            exists = exists.not_(copy=False)

        join.pop()
        expression.where(exists, copy=False)

    return expression
374
375
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    The input expression is copied before the first transform runs, so the transforms operate
    on the copy. An empty `transforms` sequence is allowed; the copy is then generated as-is.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Copy once up front; every transform then works on (and may mutate) the copy.
        # The loop variable is deliberately not named `t`, which would shadow the module's
        # `typing` alias inside this function.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method for the resulting expression.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY items that name a SELECT alias as that projection's ordinal position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map every aliased projection to its 1-based position in the SELECT list.
    # If an alias appears more than once, the last occurrence wins.
    positions: t.Dict[str, int] = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for item in expression.expressions:
        # Only unqualified column references are treated as alias references.
        if not isinstance(item, exp.Column) or item.table:
            continue

        position = positions.get(item.name)
        if position is not None:
            item.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a filtered subquery that uses ROW_NUMBER().

    The DISTINCT ON columns become the window's PARTITION BY; the query's ORDER BY (or, when
    there is none, the DISTINCT ON columns themselves) becomes the window's ORDER BY. The outer
    query keeps only the rows whose row number equals 1.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT node from the query and keep its ON columns for the window.
    partition_cols = distinct.pop().args["on"].expressions
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranked = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order = expression.args.get("order")
    if order:
        # The ORDER BY moves from the query into the window.
        window_order = order.pop().copy()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    ranked.set("order", window_order)

    expression.select(exp.alias_(ranked, row_number_alias), copy=False)

    outer = exp.select(*projections)
    outer = outer.from_(expression.subquery("_t"))
    return outer.where(exp.column(row_number_alias).eq(1))

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 98    """
 99    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
100        taken = set(expression.named_selects)
101        for select in expression.selects:
102            if not select.alias_or_name:
103                alias = find_new_name(taken, "_c")
104                select.replace(exp.alias_(select, alias))
105                taken.add(alias)
106
107        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
108        qualify_filters = expression.args["qualify"].pop().this
109
110        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
111        for expr in qualify_filters.find_all(select_candidates):
112            if isinstance(expr, exp.Window):
113                alias = find_new_name(expression.named_selects, "_w")
114                expression.select(exp.alias_(expr, alias), copy=False)
115                column = exp.column(alias)
116
117                if isinstance(expr.parent, exp.Qualify):
118                    qualify_filters = column
119                else:
120                    expr.replace(column)
121            elif expr.name not in expression.named_selects:
122                expression.select(expr.copy(), copy=False)
123
124        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
125
126    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the `DataTypeParam` arguments from every `DataType` node found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the join list and replaced
    by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used when
    the UNNEST has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it skips the element after each removal.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # NOTE: an UNNEST without a table alias yields no pairs here, so no lateral
                # view is produced for it.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that rewrites EXPLODE / POSEXPLODE projections as UNNEST joins.

    Args:
        index_offset: start value of the generated position series. When it is not 1, the
            array sizes used in the filter are reduced by one and the end of the series is
            shifted by `1 - index_offset`.

    Returns:
        The transform function, which takes an expression and returns it transformed.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """Rewrite explode/posexplode projections of a SELECT as unnest (used in hive -> presto)."""
        if not isinstance(expression, exp.Select):
            return expression

        from sqlglot.optimizer.scope import Scope

        select_names = set(expression.named_selects)
        source_names = {name for name, _ in Scope(expression).references}

        def reserve(names: t.Set[str], base: str) -> str:
            # Pick a name based on `base` that is not in `names`, and record it as taken.
            fresh = find_new_name(names, base)
            names.add(fresh)
            return fresh

        array_sizes: t.List[exp.Condition] = []
        series_alias = reserve(select_names, "pos")
        series = exp.alias_(
            exp.Unnest(
                expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
            ),
            reserve(source_names, "_u"),
            table=[series_alias],
        )

        # Snapshot the projections because the select list is mutated inside the loop.
        for projection in list(expression.selects):
            target = projection
            pos_alias = ""
            explode_alias = ""

            if isinstance(projection, exp.Alias):
                explode_alias = projection.alias
                projection = projection.this
            elif isinstance(projection, exp.Aliases):
                pos_alias = projection.aliases[0].name
                explode_alias = projection.aliases[1].name
                projection = projection.this

            if not isinstance(projection, (exp.Explode, exp.Posexplode)):
                continue

            is_posexplode = isinstance(projection, exp.Posexplode)
            array = projection.this

            # This ensures that we won't use [POS]EXPLODE's argument as a new selection
            if isinstance(array, exp.Column):
                select_names.add(array.output_name)

            unnest_alias = reserve(source_names, "_u")

            if not explode_alias:
                explode_alias = reserve(select_names, "col")

                if is_posexplode:
                    pos_alias = reserve(select_names, "pos")

            if not pos_alias:
                pos_alias = reserve(select_names, "pos")

            value_projection = exp.If(
                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                true=exp.column(explode_alias),
            ).as_(explode_alias)

            if is_posexplode:
                # POSEXPLODE yields two outputs: swap the single original projection for the
                # value projection followed by the position projection.
                projections = expression.expressions
                at = projections.index(target)
                pos_projection = exp.If(
                    this=exp.column(series_alias).eq(exp.column(pos_alias)),
                    true=exp.column(pos_alias),
                ).as_(pos_alias)
                projections[at : at + 1] = [value_projection, pos_projection]
                expression.set("expressions", projections)
            else:
                target.replace(value_projection)

            # The position series is attached once, when the first explode is found.
            if not array_sizes:
                if expression.args.get("from"):
                    expression.join(series, copy=False)
                else:
                    expression.from_(series, copy=False)

            size: exp.Condition = exp.ArraySize(this=array.copy())
            array_sizes.append(size)

            # trino doesn't support left join unnest with on conditions
            # if it did, this would be much simpler
            expression.join(
                exp.alias_(
                    exp.Unnest(
                        expressions=[array.copy()],
                        offset=exp.to_identifier(pos_alias),
                    ),
                    unnest_alias,
                    table=[explode_alias],
                ),
                join_type="CROSS",
                copy=False,
            )

            if index_offset != 1:
                size = size - 1

            expression.where(
                exp.column(series_alias)
                .eq(exp.column(pos_alias))
                .or_((exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))),
                copy=False,
            )

        if array_sizes:
            # The series runs up to the largest of the exploded arrays' sizes.
            end: exp.Condition = exp.Greatest(this=array_sizes[0], expressions=array_sizes[1:])

            if index_offset != 1:
                end = end - (1 - index_offset)
            series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP (ORDER BY ...) expression.

    The call's first argument becomes the ORDER BY key and its second argument becomes the
    call's only argument. Percentiles already inside a WITHIN GROUP, or without a second
    argument, are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Collapse `<percentile>(q) WITHIN GROUP (ORDER BY x)` into an `ApproxQuantile` node.

    Applies only when `expression` is a `WithinGroup` whose `this` is one of the
    `PERCENTILES` types and whose `expression` is an `Order`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` replacement, or the unchanged expression when it does not apply.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    # The percentile function's own argument is the requested quantile.
    quantile = expression.this.this
    # The first Ordered node found under the WITHIN GROUP holds the input value.
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    input_value = first_ordered.this

    approx_quantile = exp.ApproxQuantile(this=input_value, quantile=quantile)
    return expression.replace(approx_quantile)
def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already has columns are left alone. For the others, the column names
    are taken from the selects of the CTE's query (or of the left-hand side when the query
    is a `Union`); a select without an alias or name gets one from `name_sequence("_c_")`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a recursive `With`.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # A single generator is shared by all CTEs of this WITH clause.
    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        source = cte.this
        if isinstance(source, exp.Union):
            source = source.this

        column_names = []
        for projection in source.selects:
            # Only draw a generated name when the projection has none of its own.
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite the operand of a cast of "epoch" to a temporal type as an explicit timestamp string.

    When `expression` is a `Cast`/`TryCast` whose name is "epoch" (case-insensitive) and whose
    target type is one of `exp.DataType.TEMPORAL_TYPES`, the cast operand is replaced with
    the string literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with its operand replaced when the conditions hold.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    operand_is_epoch = expression.name.lower() == "epoch"
    if operand_is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
def timestamp_to_cast( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    """
    Turn a single-argument `Timestamp` node into a cast to TIMESTAMP.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A cast of `expression.this` to `TIMESTAMP` when `expression` is a `Timestamp` with
        no second argument; otherwise the unchanged expression.
    """
    if not isinstance(expression, exp.Timestamp) or expression.expression:
        return expression

    return exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent [NOT] EXISTS predicates.

    Every join of kind SEMI or ANTI that has an ON condition is removed from the SELECT,
    and an `EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate
    (negated for ANTI joins) is added to the SELECT through `where`. Joins without an ON
    condition, and joins of any other kind, are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a `Select`.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches the join from this SELECT, and
        # walking the live list while it may shrink could skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (only when the resulting expression's
    type differs from the original one -- see the comments in `_to_sql`).

    The transforms are applied to a copy of the expression, so the tree handed to the
    generator is not mutated. An empty `transforms` list is accepted; the copy is then
    rendered as is.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has neither
            a "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Copy up front rather than inside the first call, so that an empty `transforms`
        # list does not fail on `transforms[0]`. The loop variable is deliberately not
        # named `t`, which is this module's alias for `typing`.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method when there is one.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (used only when the resulting expression's type differs from the original expression's type, since otherwise the handler would re-enter this transform).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.