Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 97    """
 98    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 99        taken = set(expression.named_selects)
100        for select in expression.selects:
101            if not select.alias_or_name:
102                alias = find_new_name(taken, "_c")
103                select.replace(exp.alias_(select, alias))
104                taken.add(alias)
105
106        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
107        qualify_filters = expression.args["qualify"].pop().this
108
109        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
110        for expr in qualify_filters.find_all(select_candidates):
111            if isinstance(expr, exp.Window):
112                alias = find_new_name(expression.named_selects, "_w")
113                expression.select(exp.alias_(expr, alias), copy=False)
114                column = exp.column(alias)
115
116                if isinstance(expr.parent, exp.Qualify):
117                    qualify_filters = column
118                else:
119                    expr.replace(column)
120            elif expr.name not in expression.named_selects:
121                expression.select(expr.copy(), copy=False)
122
123        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
124            qualify_filters, copy=False
125        )
126
127    return expression
128
129
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        # Keep every child of the type except its DataTypeParam entries.
        kept = [child for child in data_type.expressions if not isinstance(child, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression
141
142
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one LATERAL VIEW
    per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST has an
    offset, EXPLODE otherwise. An UNNEST without an alias yields no lateral, because there are
    no alias columns to pair its expressions with.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it skips the element after each removal.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
166
167
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: number used as the start of the generated position series. Any value other
            than 1 also shifts the array-size bounds by one. Presumably this is the target
            dialect's array index base -- confirm against callers.

    Returns:
        A transform that rewrites the EXPLODE / POSEXPLODE projections of a SELECT in place and
        returns the expression it was given.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a fresh name based on `name` and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ArraySize expression per exploded array; used below for the series' end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Aliased UNNEST over a GENERATE_SERIES that starts at index_offset; its "end"
            # argument is only filled in after all projections have been processed.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) and pick up any
                    # user-provided names for the exploded value / position columns.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ is called without copy=False here, so the explode is looked
                        # up again inside the new alias node.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For the OUTER variant, an empty (or NULL, via COALESCE) array is
                        # replaced by a one-element array holding the array's first item,
                        # read with the "safe" and "offset" flags set on the bracket.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever column names the query did not provide.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(<series pos> = <unnest pos>, <unnest value>).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the value column.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is attached once, when the first explode is seen.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Only the local `size` is shifted; `arrays` keeps the unshifted ArraySize.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches the unnest position, or
                    # where the series has run past this array and the unnest is at `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for index_offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
317
318
# Percentile expression types handled by the WITHIN GROUP transforms in this module.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
320
321
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Only two-argument percentiles that are not already wrapped in a WITHIN GROUP are rewritten:
    the first argument becomes the ORDER BY of the new WITHIN GROUP and the second argument
    becomes the percentile's only argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    is_bare_percentile = isinstance(expression, PERCENTILES) and not isinstance(
        expression.parent, exp.WithinGroup
    )
    if not is_bare_percentile or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
335
336
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WITHIN GROUP that wraps a percentile and an ORDER BY is replaced by an ApproxQuantile
    whose input is the first ordered term found and whose quantile is the percentile's argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    if not isinstance(expression.this, PERCENTILES) or not isinstance(
        expression.expression, exp.Order
    ):
        return expression

    quantile = expression.this.this
    ordered_term = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered_term.this, quantile=quantile)

    return expression.replace(replacement)
349
350
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    CTEs that already list their columns are left alone. For a UNION, the names come from its
    left-hand query. Projections without an output name get a generated "_c_"-prefixed name.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    generated_name = name_sequence("_c_")

    for cte in expression.expressions:
        if cte.args["alias"].columns:
            continue

        body = cte.this
        if isinstance(body, exp.Union):
            body = body.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or generated_name())
            for projection in body.selects
        ]
        cte.args["alias"].set("columns", column_names)

    return expression
368
369
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST nodes whose operand is named "epoch" (case-insensitively) and
    whose target type is one of the temporal data types.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
380
381
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced by
    an EXISTS (NOT EXISTS for ANTI) predicate in its WHERE clause. Joins without an ON condition
    are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
397
398
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the two rewritten queries, or the unchanged expression when it is not a
        SELECT with exactly one FULL join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    # The copy is taken before the limit is cleared, so only the right-hand query keeps it.
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
423
424
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    outer_with = expression.args.get("with")

    for with_clause in expression.find_all(exp.With):
        if with_clause.parent is expression:
            continue

        detached = with_clause.pop()

        if outer_with:
            # Merge into the existing top-level WITH, propagating the RECURSIVE flag.
            if detached.recursive:
                outer_with.set("recursive", True)

            outer_with.expressions.extend(detached.expressions)
        else:
            # No top-level WITH yet: the first nested one found is promoted as-is.
            outer_with = detached
            expression.set("with", outer_with)

    return expression
453
454
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is passed to the optimizer's `ensure_bools` helper along with a
    callback that rewrites a node as `<node> <> 0` when it is a number, has an unknown or
    numeric type, or is an untyped column.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _compare_to_zero(candidate: exp.Expression) -> None:
        looks_numeric = (
            candidate.is_number
            or candidate.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(candidate, exp.Column) and not candidate.type)
        )
        if looks_numeric:
            candidate.replace(candidate.neq(0))

    for visited, *_ in expression.walk():
        _canonicalize_bools(visited, _compare_to_zero)

    return expression
471
472
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip every qualifier from the columns of an expression, keeping only their last part.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
480
481
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every unique column constraint found in a CREATE statement.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
489
490
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn a CREATE TEMPORARY TABLE ... AS <query> into a CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            that has no query; its result is returned. Defaults to the identity function.

    Returns:
        A new TEMPORARY VIEW create for temporary CTAS statements, the provider's result for
        other temporary tables, and the unchanged expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
513
514
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    # Column names are matched case-insensitively.
    partition_names = {name.name.upper() for name in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining = [column for column in schema.expressions if column not in partition_columns]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
536
537
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of typed column definitions, those definitions
    are appended to the table's schema and the property is reduced to a tuple of their names.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    # Every partition entry must be a column definition that carries a type ("kind").
    if not all(
        isinstance(entry, exp.ColumnDef) and entry.args.get("kind")
        for entry in prop.this.expressions
    ):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(entry.this) for entry in prop.this.expressions]
    )
    schema = expression.this
    for column_def in prop.this.expressions:
        schema.append("expressions", column_def)
    prop.set("this", partition_names)

    return expression
561
562
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remembered so we can tell whether the transforms changed the expression's type.
        expression_type = type(expression)

        # A plain loop (rather than indexing transforms[0]) also supports an empty sequence.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the ordinal of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    if not isinstance(expression.parent, exp.Select):
        return expression

    # 1-based position of every aliased projection; a repeated alias keeps its last position.
    alias_positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for term in expression.expressions:
        # Only bare (table-less) columns can be references to a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in alias_positions:
            term.replace(exp.Literal.number(alias_positions.get(term.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery filtered on a ROW_NUMBER() window.

    The DISTINCT ON columns become the window's partition. The query's ORDER BY, if any, is
    moved into the window; otherwise the window is ordered by copies of the DISTINCT ON columns.
    The outer query keeps the original projections and only retains rows numbered 1.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    partition_keys = distinct.pop().args["on"].expressions
    projections = expression.selects
    rank_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_keys)

    order = expression.args.get("order")
    if order:
        window_order = order.pop()
    else:
        window_order = exp.Order(expressions=[key.copy() for key in partition_keys])
    ranking.set("order", window_order)

    # The ranking column is appended to the inner query so the outer filter can refer to it.
    expression.select(exp.alias_(ranking, rank_alias), copy=False)

    outer_query = exp.select(*projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(rank_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 98    """
 99    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
100        taken = set(expression.named_selects)
101        for select in expression.selects:
102            if not select.alias_or_name:
103                alias = find_new_name(taken, "_c")
104                select.replace(exp.alias_(select, alias))
105                taken.add(alias)
106
107        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
108        qualify_filters = expression.args["qualify"].pop().this
109
110        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
111        for expr in qualify_filters.find_all(select_candidates):
112            if isinstance(expr, exp.Window):
113                alias = find_new_name(expression.named_selects, "_w")
114                expression.select(exp.alias_(expr, alias), copy=False)
115                column = exp.column(alias)
116
117                if isinstance(expr.parent, exp.Qualify):
118                    qualify_filters = column
119                else:
120                    expr.replace(column)
121            elif expr.name not in expression.named_selects:
122                expression.select(expr.copy(), copy=False)
123
124        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
125            qualify_filters, copy=False
126        )
127
128    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types by dropping
    every `exp.DataTypeParam` entry from the `exp.DataType` nodes found in the expression.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_params = []
        for param in data_type.expressions:
            if isinstance(param, exp.DataTypeParam):
                continue
            kept_params.append(param)

        data_type.set("expressions", kept_params)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    For every join of a SELECT whose target is an `exp.Unnest`, the join is removed and one
    `LATERAL VIEW` is appended per (unnested expression, alias column) pair. `exp.Posexplode` is
    used when the UNNEST carries an offset, `exp.Explode` otherwise.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it makes the loop skip the next element,
        # which would leave every other consecutive UNNEST join unconverted.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no alias nothing is zipped, so the join is
                # dropped without any lateral view replacing it -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series; it is also used to
            adjust the array sizes that positions are compared against.

    Returns:
        A transform that rewrites the [POS]EXPLODE projections of a SELECT in place and returns
        that same expression. Expressions that aren't SELECTs are returned untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            # Imported locally rather than at module level -- presumably to avoid a circular
            # import with the optimizer package; confirm before moving it.
            from sqlglot.optimizer.scope import Scope

            # Names that generated projection aliases / source aliases must not reuse.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Derive a name from `name` via find_new_name and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE expressions, one per explode found; they bound the series below.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, ...)) aliased as a one-column table. The
            # series' "end" is only set at the bottom, once every array size is known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # Single alias: reuse it for the exploded value.
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an alias whose name is set later.
                        # The explode node is looked up again through the new alias --
                        # presumably because exp.alias_ copies its input here; confirm.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arr, ARRAY())) = 0, ARRAY(<bracket>), arr).
                        # `explode_arg[0]` relies on the expression's indexing operator
                        # (presumably producing a subscript node -- confirm); the intent looks
                        # like keeping one row for empty/NULL arrays -- TODO confirm.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Fill in whichever aliases the projection didn't provide.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # IF(<series>.<pos> = <unnest>.<pos>, <unnest>.<value>) replaces the explode
                    # call itself; there is no false branch.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE yields a position too: insert an analogous IF projection
                        # for it right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, when the first explode is handled:
                    # cross joined if the SELECT already has a FROM, used as the FROM otherwise.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Only the local `size` is rebound here; the entry already appended to
                    # `arrays` stays the unadjusted ARRAY_SIZE.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches the unnest position, or where
                    # the series position is past `size` and the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the GREATEST of all array sizes, shifted by the offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile that has a second argument and isn't already wrapped in a WITHIN GROUP has its
    first argument moved into the ORDER BY of a new WITHIN GROUP, while its second argument
    becomes the percentile's only argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new `exp.WithinGroup` node, or the untouched expression when it doesn't apply.
    """
    if (
        not isinstance(expression, PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    # Detach the first argument before the second one is promoted into its slot.
    ordered_value = expression.this.pop()
    promoted = expression.expression.pop()
    expression.set("this", promoted)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WITHIN GROUP that wraps a percentile and an ORDER BY is replaced by an
    `exp.ApproxQuantile` built from the first ordered term and the percentile's argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement `exp.ApproxQuantile`, or the untouched expression when it doesn't apply.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (isinstance(percentile, PERCENTILES) and isinstance(expression.expression, exp.Order)):
        return expression

    ordered_term = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered_term.this, quantile=percentile.this)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH whose alias has no column list are touched. For a UNION, the
    names come from its left-hand query. Projections without an output name get a generated
    name from a "_c_" sequence.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    generated_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        projection_source = cte.this
        if isinstance(projection_source, exp.Union):
            projection_source = projection_source.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or generated_name())
            for projection in projection_source.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST nodes whose operand is named "epoch" (case-insensitively) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join that has an ON condition is removed from the SELECT and replaced by a
    `[NOT] EXISTS (SELECT 1 FROM <join target> WHERE <on condition>)` filter in its WHERE clause.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the LEFT and RIGHT variants, or the untouched expression when it isn't a
        SELECT with exactly one FULL join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []
    full_join_positions = [position for position, join in enumerate(joins) if join.side == "FULL"]

    if len(full_join_positions) != 1:
        return expression

    position = full_join_positions[0]

    # The copy must be taken before the original is modified: it keeps the limit and the
    # still-FULL join, which is then flipped to RIGHT.
    right_query = expression.copy()
    expression.set("limit", None)
    joins[position].set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    The first nested WITH becomes the top-level one when the expression has none; any other
    nested WITH has its CTEs appended to the top-level WITH, which is marked recursive if the
    nested one was.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    outer_with = expression.args.get("with")

    for with_clause in expression.find_all(exp.With):
        # The WITH that already belongs to the root expression stays where it is.
        if with_clause.parent is expression:
            continue

        detached = with_clause.pop()

        if outer_with:
            if detached.recursive:
                outer_with.set("recursive", True)

            outer_with.expressions.extend(detached.expressions)
        else:
            outer_with = detached
            expression.set("with", outer_with)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node yielded by `expression.walk()` is handed to the canonicalize helper together
    with a callback that replaces a node by `node <> 0` when the node is a number, has an
    unknown or numeric type, or is a column without type information.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    # Aliased so the helper doesn't shadow this function's own name inside its body.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(candidate: exp.Expression) -> None:
        needs_comparison = (
            candidate.is_number
            or candidate.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(candidate, exp.Column) and not candidate.type)
        )
        if needs_comparison:
            candidate.replace(candidate.neq(0))

    for visited, *_ in expression.walk():
        _canonicalize_ensure_bools(visited, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Remove the qualifiers of every column found in the expression.

    Args:
        expression: the expression that will be transformed (it is mutated in place).

    Returns:
        The transformed expression.
    """
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args, i.e. every part but the last
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every `exp.UniqueColumnConstraint` found in a CREATE statement.

    Args:
        expression: the `exp.Create` expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        holder = unique.parent
        if holder:
            # The node wrapping the constraint is removed, not just the constraint itself.
            holder.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn a temporary CREATE TABLE ... AS into a CREATE TEMPORARY VIEW.

    Args:
        expression: the `exp.Create` expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            that has no defining query; its result is returned. Defaults to the identity.

    Returns:
        A new TEMPORARY VIEW `exp.Create` for temporary CTAS statements, the provider's result
        for other temporary tables, and the untouched expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only TABLE and VIEW creations that have a
    schema are affected.

    Args:
        expression: the `exp.Create` expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_defs = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining_defs = [col for col in schema.expressions if col not in partition_defs]

    schema.set("expressions", remaining_defs)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_defs)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema whose entries are all column definitions
    with a `kind`, those definitions are appended to the CREATE statement's schema and the
    property's value is replaced by a tuple of their names.

    Args:
        expression: the `exp.Create` expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    fully_defined = all(
        isinstance(column_def, exp.ColumnDef) and column_def.args.get("kind")
        for column_def in partition_defs
    )
    if not fully_defined:
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remembered so we can tell whether the transforms changed the expression's type.
        expression_type = type(expression)

        # A single loop (instead of special-casing transforms[0]) also copes with an empty list.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.