
sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot._typing import E
 12    from sqlglot.generator import Generator
 13
 14
 15def preprocess(
 16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 17) -> t.Callable[[Generator, exp.Expression], str]:
 18    """
 19    Creates a new transform by chaining a sequence of transformations and converts the resulting
 20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 22
 23    Args:
 24        transforms: sequence of transform functions. These will be called in order.
 25
 26    Returns:
 27        Function that can be used as a generator transform.
 28    """
 29
 30    def _to_sql(self, expression: exp.Expression) -> str:
 31        expression_type = type(expression)
 32
 33        try:
 34            expression = transforms[0](expression)
 35            for transform in transforms[1:]:
 36                expression = transform(expression)
 37        except UnsupportedError as unsupported_error:
 38            self.unsupported(str(unsupported_error))
 39
 40        _sql_handler = getattr(self, expression.key + "_sql", None)
 41        if _sql_handler:
 42            return _sql_handler(expression)
 43
 44        transforms_handler = self.TRANSFORMS.get(type(expression))
 45        if transforms_handler:
 46            if expression_type is type(expression):
 47                if isinstance(expression, exp.Func):
 48                    return self.function_fallback_sql(expression)
 49
 50                # Ensures we don't enter an infinite loop. This can happen when the original expression
 51                # has the same type as the final expression and there's no _sql method available for it,
 52                # because then it'd re-enter _to_sql.
 53                raise ValueError(
 54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 55                )
 56
 57            return transforms_handler(self, expression)
 58
 59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 60
 61    return _to_sql
 62
 63
 64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 65    if isinstance(expression, exp.Select):
 66        count = 0
 67        recursive_ctes = []
 68
 69        for unnest in expression.find_all(exp.Unnest):
 70            if (
 71                not isinstance(unnest.parent, (exp.From, exp.Join))
 72                or len(unnest.expressions) != 1
 73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 74            ):
 75                continue
 76
 77            generate_date_array = unnest.expressions[0]
 78            start = generate_date_array.args.get("start")
 79            end = generate_date_array.args.get("end")
 80            step = generate_date_array.args.get("step")
 81
 82            if not start or not end or not isinstance(step, exp.Interval):
 83                continue
 84
 85            alias = unnest.args.get("alias")
 86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 87
 88            start = exp.cast(start, "date")
 89            date_add = exp.func(
 90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 91            )
 92            cast_date_add = exp.cast(date_add, "date")
 93
 94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 95
 96            base_query = exp.select(start.as_(column_name))
 97            recursive_query = (
 98                exp.select(cast_date_add)
 99                .from_(cte_name)
100                .where(cast_date_add <= exp.cast(end, "date"))
101            )
102            cte_query = base_query.union(recursive_query, distinct=False)
103
104            generate_dates_query = exp.select(column_name).from_(cte_name)
105            unnest.replace(generate_dates_query.subquery(cte_name))
106
107            recursive_ctes.append(
108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
109            )
110            count += 1
111
112        if recursive_ctes:
113            with_expression = expression.args.get("with") or exp.With()
114            with_expression.set("recursive", True)
115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
116            expression.set("with", with_expression)
117
118    return expression
119
120
121def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
122    """Unnests GENERATE_SERIES or SEQUENCE table references."""
123    this = expression.this
124    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
125        unnest = exp.Unnest(expressions=[this])
126        if expression.alias:
127            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
128
129        return unnest
130
131    return expression
132
133
134def unalias_group(expression: exp.Expression) -> exp.Expression:
135    """
136    Replace references to select aliases in GROUP BY clauses.
137
138    Example:
139        >>> import sqlglot
140        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
141        'SELECT a AS b FROM x GROUP BY 1'
142
143    Args:
144        expression: the expression that will be transformed.
145
146    Returns:
147        The transformed expression.
148    """
149    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
150        aliased_selects = {
151            e.alias: i
152            for i, e in enumerate(expression.parent.expressions, start=1)
153            if isinstance(e, exp.Alias)
154        }
155
156        for group_by in expression.expressions:
157            if (
158                isinstance(group_by, exp.Column)
159                and not group_by.table
160                and group_by.name in aliased_selects
161            ):
162                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
163
164    return expression
165
166
167def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
168    """
169    Convert SELECT DISTINCT ON statements to a subquery with a window function.
170
171    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
172
173    Args:
174        expression: the expression that will be transformed.
175
176    Returns:
177        The transformed expression.
178    """
179    if (
180        isinstance(expression, exp.Select)
181        and expression.args.get("distinct")
182        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
183    ):
184        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
185
186        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
187        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
188
189        order = expression.args.get("order")
190        if order:
191            window.set("order", order.pop())
192        else:
193            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
194
195        window = exp.alias_(window, row_number_window_alias)
196        expression.select(window, copy=False)
197
198        # We add aliases to the projections so that we can safely reference them in the outer query
199        new_selects = []
200        taken_names = {row_number_window_alias}
201        for select in expression.selects[:-1]:
202            if select.is_star:
203                new_selects = [exp.Star()]
204                break
205
206            if not isinstance(select, exp.Alias):
207                alias = find_new_name(taken_names, select.output_name or "_col")
208                select = select.replace(exp.alias_(select, alias))
209
210            taken_names.add(select.output_name)
211            new_selects.append(select.args["alias"])
212
213        return (
214            exp.select(*new_selects, copy=False)
215            .from_(expression.subquery("_t", copy=False), copy=False)
216            .where(exp.column(row_number_window_alias).eq(1), copy=False)
217        )
218
219    return expression
220
221
222def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
223    """
224    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
225
226    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
227    https://docs.snowflake.com/en/sql-reference/constructs/qualify
228
229    Some dialects don't support window functions in the WHERE clause, so we need to include them as
230    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
231    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
232    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
233    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
234    corresponding expression to avoid creating invalid column references.
235    """
236    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
237        taken = set(expression.named_selects)
238        for select in expression.selects:
239            if not select.alias_or_name:
240                alias = find_new_name(taken, "_c")
241                select.replace(exp.alias_(select, alias))
242                taken.add(alias)
243
244        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
245            alias_or_name = select.alias_or_name
246            identifier = select.args.get("alias") or select.this
247            if isinstance(identifier, exp.Identifier):
248                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
249            return alias_or_name
250
251        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
252        qualify_filters = expression.args["qualify"].pop().this
253        expression_by_alias = {
254            select.alias: select.this
255            for select in expression.selects
256            if isinstance(select, exp.Alias)
257        }
258
259        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
260        for select_candidate in qualify_filters.find_all(select_candidates):
261            if isinstance(select_candidate, exp.Window):
262                if expression_by_alias:
263                    for column in select_candidate.find_all(exp.Column):
264                        expr = expression_by_alias.get(column.name)
265                        if expr:
266                            column.replace(expr)
267
268                alias = find_new_name(expression.named_selects, "_w")
269                expression.select(exp.alias_(select_candidate, alias), copy=False)
270                column = exp.column(alias)
271
272                if isinstance(select_candidate.parent, exp.Qualify):
273                    qualify_filters = column
274                else:
275                    select_candidate.replace(column)
276            elif select_candidate.name not in expression.named_selects:
277                expression.select(select_candidate.copy(), copy=False)
278
279        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
280            qualify_filters, copy=False
281        )
282
283    return expression
284
285
286def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
287    """
288    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
289    other expressions. This transform removes the precision from parameterized types in expressions.
290    """
291    for node in expression.find_all(exp.DataType):
292        node.set(
293            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
294        )
295
296    return expression
297
298
299def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
300    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
301    from sqlglot.optimizer.scope import find_all_in_scope
302
303    if isinstance(expression, exp.Select):
304        unnest_aliases = {
305            unnest.alias
306            for unnest in find_all_in_scope(expression, exp.Unnest)
307            if isinstance(unnest.parent, (exp.From, exp.Join))
308        }
309        if unnest_aliases:
310            for column in expression.find_all(exp.Column):
311                if column.table in unnest_aliases:
312                    column.set("table", None)
313                elif column.db in unnest_aliases:
314                    column.set("db", None)
315
316    return expression
317
318
319def unnest_to_explode(
320    expression: exp.Expression,
321    unnest_using_arrays_zip: bool = True,
322) -> exp.Expression:
323    """Convert cross join unnest into lateral view explode."""
324
325    def _unnest_zip_exprs(
326        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
327    ) -> t.List[exp.Expression]:
328        if has_multi_expr:
329            if not unnest_using_arrays_zip:
330                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")
331
332            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
333            zip_exprs: t.List[exp.Expression] = [
334                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
335            ]
336            u.set("expressions", zip_exprs)
337            return zip_exprs
338        return unnest_exprs
339
340    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
341        if u.args.get("offset"):
342            return exp.Posexplode
343        return exp.Inline if has_multi_expr else exp.Explode
344
345    if isinstance(expression, exp.Select):
346        from_ = expression.args.get("from")
347
348        if from_ and isinstance(from_.this, exp.Unnest):
349            unnest = from_.this
350            alias = unnest.args.get("alias")
351            exprs = unnest.expressions
352            has_multi_expr = len(exprs) > 1
353            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
354
355            unnest.replace(
356                exp.Table(
357                    this=_udtf_type(unnest, has_multi_expr)(
358                        this=this,
359                        expressions=expressions,
360                    ),
361                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
362                )
363            )
364
365        joins = expression.args.get("joins") or []
366        for join in list(joins):
367            join_expr = join.this
368
369            is_lateral = isinstance(join_expr, exp.Lateral)
370
371            unnest = join_expr.this if is_lateral else join_expr
372
373            if isinstance(unnest, exp.Unnest):
374                if is_lateral:
375                    alias = join_expr.args.get("alias")
376                else:
377                    alias = unnest.args.get("alias")
378                exprs = unnest.expressions
379                # The number of unnest.expressions will be changed by _unnest_zip_exprs, so we need to record it here
380                has_multi_expr = len(exprs) > 1
381                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
382
383                joins.remove(join)
384
385                alias_cols = alias.columns if alias else []
386
387                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there are 0 or > 2 aliases.
388                # Spark's LATERAL VIEW EXPLODE requires a single alias for arrays/structs and two for maps, unlike UNNEST in Trino/Presto, which accepts an arbitrary number.
389                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
390
391                if not has_multi_expr and len(alias_cols) not in (1, 2):
392                    raise UnsupportedError(
393                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
394                    )
395
396                for e, column in zip(exprs, alias_cols):
397                    expression.append(
398                        "laterals",
399                        exp.Lateral(
400                            this=_udtf_type(unnest, has_multi_expr)(this=e),
401                            view=True,
402                            alias=exp.TableAlias(
403                                this=alias.this,  # type: ignore
404                                columns=alias_cols,
405                            ),
406                        ),
407                    )
408
409    return expression
410
411
412def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
413    """Convert explode/posexplode into unnest."""
414
415    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
416        if isinstance(expression, exp.Select):
417            from sqlglot.optimizer.scope import Scope
418
419            taken_select_names = set(expression.named_selects)
420            taken_source_names = {name for name, _ in Scope(expression).references}
421
422            def new_name(names: t.Set[str], name: str) -> str:
423                name = find_new_name(names, name)
424                names.add(name)
425                return name
426
427            arrays: t.List[exp.Condition] = []
428            series_alias = new_name(taken_select_names, "pos")
429            series = exp.alias_(
430                exp.Unnest(
431                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
432                ),
433                new_name(taken_source_names, "_u"),
434                table=[series_alias],
435            )
436
437            # we use list here because expression.selects is mutated inside the loop
438            for select in list(expression.selects):
439                explode = select.find(exp.Explode)
440
441                if explode:
442                    pos_alias = ""
443                    explode_alias = ""
444
445                    if isinstance(select, exp.Alias):
446                        explode_alias = select.args["alias"]
447                        alias = select
448                    elif isinstance(select, exp.Aliases):
449                        pos_alias = select.aliases[0]
450                        explode_alias = select.aliases[1]
451                        alias = select.replace(exp.alias_(select.this, "", copy=False))
452                    else:
453                        alias = select.replace(exp.alias_(select, ""))
454                        explode = alias.find(exp.Explode)
455                        assert explode
456
457                    is_posexplode = isinstance(explode, exp.Posexplode)
458                    explode_arg = explode.this
459
460                    if isinstance(explode, exp.ExplodeOuter):
461                        bracket = explode_arg[0]
462                        bracket.set("safe", True)
463                        bracket.set("offset", True)
464                        explode_arg = exp.func(
465                            "IF",
466                            exp.func(
467                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
468                            ).eq(0),
469                            exp.array(bracket, copy=False),
470                            explode_arg,
471                        )
472
473                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
474                    if isinstance(explode_arg, exp.Column):
475                        taken_select_names.add(explode_arg.output_name)
476
477                    unnest_source_alias = new_name(taken_source_names, "_u")
478
479                    if not explode_alias:
480                        explode_alias = new_name(taken_select_names, "col")
481
482                        if is_posexplode:
483                            pos_alias = new_name(taken_select_names, "pos")
484
485                    if not pos_alias:
486                        pos_alias = new_name(taken_select_names, "pos")
487
488                    alias.set("alias", exp.to_identifier(explode_alias))
489
490                    series_table_alias = series.args["alias"].this
491                    column = exp.If(
492                        this=exp.column(series_alias, table=series_table_alias).eq(
493                            exp.column(pos_alias, table=unnest_source_alias)
494                        ),
495                        true=exp.column(explode_alias, table=unnest_source_alias),
496                    )
497
498                    explode.replace(column)
499
500                    if is_posexplode:
501                        expressions = expression.expressions
502                        expressions.insert(
503                            expressions.index(alias) + 1,
504                            exp.If(
505                                this=exp.column(series_alias, table=series_table_alias).eq(
506                                    exp.column(pos_alias, table=unnest_source_alias)
507                                ),
508                                true=exp.column(pos_alias, table=unnest_source_alias),
509                            ).as_(pos_alias),
510                        )
511                        expression.set("expressions", expressions)
512
513                    if not arrays:
514                        if expression.args.get("from"):
515                            expression.join(series, copy=False, join_type="CROSS")
516                        else:
517                            expression.from_(series, copy=False)
518
519                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
520                    arrays.append(size)
521
522                    # trino doesn't support left join unnest with on conditions
523                    # if it did, this would be much simpler
524                    expression.join(
525                        exp.alias_(
526                            exp.Unnest(
527                                expressions=[explode_arg.copy()],
528                                offset=exp.to_identifier(pos_alias),
529                            ),
530                            unnest_source_alias,
531                            table=[explode_alias],
532                        ),
533                        join_type="CROSS",
534                        copy=False,
535                    )
536
537                    if index_offset != 1:
538                        size = size - 1
539
540                    expression.where(
541                        exp.column(series_alias, table=series_table_alias)
542                        .eq(exp.column(pos_alias, table=unnest_source_alias))
543                        .or_(
544                            (exp.column(series_alias, table=series_table_alias) > size).and_(
545                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
546                            )
547                        ),
548                        copy=False,
549                    )
550
551            if arrays:
552                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])
553
554                if index_offset != 1:
555                    end = end - (1 - index_offset)
556                series.expressions[0].set("end", end)
557
558        return expression
559
560    return _explode_to_unnest
561
562
563def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
564    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
565    if (
566        isinstance(expression, exp.PERCENTILES)
567        and not isinstance(expression.parent, exp.WithinGroup)
568        and expression.expression
569    ):
570        column = expression.this.pop()
571        expression.set("this", expression.expression.pop())
572        order = exp.Order(expressions=[exp.Ordered(this=column)])
573        expression = exp.WithinGroup(this=expression, expression=order)
574
575    return expression
576
577
578def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
579    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
580    if (
581        isinstance(expression, exp.WithinGroup)
582        and isinstance(expression.this, exp.PERCENTILES)
583        and isinstance(expression.expression, exp.Order)
584    ):
585        quantile = expression.this.this
586        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
587        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
588
589    return expression
590
591
592def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
593    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
594    if isinstance(expression, exp.With) and expression.recursive:
595        next_name = name_sequence("_c_")
596
597        for cte in expression.expressions:
598            if not cte.args["alias"].columns:
599                query = cte.this
600                if isinstance(query, exp.SetOperation):
601                    query = query.this
602
603                cte.args["alias"].set(
604                    "columns",
605                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
606                )
607
608    return expression
609
610
611def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
612    """Replace 'epoch' in casts by the equivalent date literal."""
613    if (
614        isinstance(expression, (exp.Cast, exp.TryCast))
615        and expression.name.lower() == "epoch"
616        and expression.to.this in exp.DataType.TEMPORAL_TYPES
617    ):
618        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
619
620    return expression
621
622
623def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
624    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
625    if isinstance(expression, exp.Select):
626        for join in expression.args.get("joins") or []:
627            on = join.args.get("on")
628            if on and join.kind in ("SEMI", "ANTI"):
629                subquery = exp.select("1").from_(join.this).where(on)
630                exists = exp.Exists(this=subquery)
631                if join.kind == "ANTI":
632                    exists = exists.not_(copy=False)
633
634                join.pop()
635                expression.where(exists, copy=False)
636
637    return expression
638
639
640def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
641    """
642    Converts a query with a FULL OUTER join to a union of identical queries that
643    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
644    for queries that have a single FULL OUTER join.
645    """
646    if isinstance(expression, exp.Select):
647        full_outer_joins = [
648            (index, join)
649            for index, join in enumerate(expression.args.get("joins") or [])
650            if join.side == "FULL"
651        ]
652
653        if len(full_outer_joins) == 1:
654            expression_copy = expression.copy()
655            expression.set("limit", None)
656            index, full_outer_join = full_outer_joins[0]
657
658            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
659            join_conditions = full_outer_join.args.get("on") or exp.and_(
660                *[
661                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
662                    for col in full_outer_join.args.get("using")
663                ]
664            )
665
666            full_outer_join.set("side", "left")
667            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
668            expression_copy.args["joins"][index].set("side", "right")
669            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
670            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
671            expression.args.pop("order", None)  # remove order by from LEFT side
672
673            return exp.union(expression, expression_copy, copy=False, distinct=False)
674
675    return expression
676
677
678def move_ctes_to_top_level(expression: E) -> E:
679    """
680    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
681    defined at the top-level, so for example queries like:
682
683        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
684
685    are invalid in those dialects. This transformation can be used to ensure all CTEs are
686    moved to the top level so that the final SQL code is valid from a syntax standpoint.
687
688    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
689    """
690    top_level_with = expression.args.get("with")
691    for inner_with in expression.find_all(exp.With):
692        if inner_with.parent is expression:
693            continue
694
695        if not top_level_with:
696            top_level_with = inner_with.pop()
697            expression.set("with", top_level_with)
698        else:
699            if inner_with.recursive:
700                top_level_with.set("recursive", True)
701
702            parent_cte = inner_with.find_ancestor(exp.CTE)
703            inner_with.pop()
704
705            if parent_cte:
706                i = top_level_with.expressions.index(parent_cte)
707                top_level_with.expressions[i:i] = inner_with.expressions
708                top_level_with.set("expressions", top_level_with.expressions)
709            else:
710                top_level_with.set(
711                    "expressions", top_level_with.expressions + inner_with.expressions
712                )
713
714    return expression
715
716
717def ensure_bools(expression: exp.Expression) -> exp.Expression:
718    """Converts numeric values used in conditions into explicit boolean expressions."""
719    from sqlglot.optimizer.canonicalize import ensure_bools
720
721    def _ensure_bool(node: exp.Expression) -> None:
722        if (
723            node.is_number
724            or (
725                not isinstance(node, exp.SubqueryPredicate)
726                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
727            )
728            or (isinstance(node, exp.Column) and not node.type)
729        ):
730            node.replace(node.neq(0))
731
732    for node in expression.walk():
733        ensure_bools(node, _ensure_bool)
734
735    return expression
736
737
738def unqualify_columns(expression: exp.Expression) -> exp.Expression:
739    for column in expression.find_all(exp.Column):
740        # We only wanna pop off the table, db, catalog args
741        for part in column.parts[:-1]:
742            part.pop()
743
744    return expression
745
746
747def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
748    assert isinstance(expression, exp.Create)
749    for constraint in expression.find_all(exp.UniqueColumnConstraint):
750        if constraint.parent:
751            constraint.parent.pop()
752
753    return expression
754
755
756def ctas_with_tmp_tables_to_create_tmp_view(
757    expression: exp.Expression,
758    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
759) -> exp.Expression:
760    assert isinstance(expression, exp.Create)
761    properties = expression.args.get("properties")
762    temporary = any(
763        isinstance(prop, exp.TemporaryProperty)
764        for prop in (properties.expressions if properties else [])
765    )
766
767    # CTAS with temp tables map to CREATE TEMPORARY VIEW
768    if expression.kind == "TABLE" and temporary:
769        if expression.expression:
770            return exp.Create(
771                kind="TEMPORARY VIEW",
772                this=expression.this,
773                expression=expression.expression,
774            )
775        return tmp_storage_provider(expression)
776
777    return expression
778
779
780def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
781    """
782    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
783    PARTITIONED BY value is an array of column names, they are transformed into a schema.
784    The corresponding columns are removed from the create statement.
785    """
786    assert isinstance(expression, exp.Create)
787    has_schema = isinstance(expression.this, exp.Schema)
788    is_partitionable = expression.kind in {"TABLE", "VIEW"}
789
790    if has_schema and is_partitionable:
791        prop = expression.find(exp.PartitionedByProperty)
792        if prop and prop.this and not isinstance(prop.this, exp.Schema):
793            schema = expression.this
794            columns = {v.name.upper() for v in prop.this.expressions}
795            partitions = [col for col in schema.expressions if col.name.upper() in columns]
796            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
797            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
798            expression.set("this", schema)
799
800    return expression
801
802
803def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
804    """
805    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
806
807    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
808    """
809    assert isinstance(expression, exp.Create)
810    prop = expression.find(exp.PartitionedByProperty)
811    if (
812        prop
813        and prop.this
814        and isinstance(prop.this, exp.Schema)
815        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
816    ):
817        prop_this = exp.Tuple(
818            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
819        )
820        schema = expression.this
821        for e in prop.this.expressions:
822            schema.append("expressions", e)
823        prop.set("this", prop_this)
824
825    return expression
826
827
828def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
829    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
830    if isinstance(expression, exp.Struct):
831        expression.set(
832            "expressions",
833            [
834                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
835                for e in expression.expressions
836            ],
837        )
838
839    return expression
840
841
842def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
843    """
844    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
845    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
846
847    For example,
848        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
849        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
850
851    Args:
852        expression: The AST to remove join marks from.
853
854    Returns:
855       The AST with join marks removed.
856    """
857    from sqlglot.optimizer.scope import traverse_scope
858
859    for scope in traverse_scope(expression):
860        query = scope.expression
861
862        where = query.args.get("where")
863        joins = query.args.get("joins")
864
865        if not where or not joins:
866            continue
867
868        query_from = query.args["from"]
869
870        # These keep track of the joins to be replaced
871        new_joins: t.Dict[str, exp.Join] = {}
872        old_joins = {join.alias_or_name: join for join in joins}
873
874        for column in scope.columns:
875            if not column.args.get("join_mark"):
876                continue
877
878            predicate = column.find_ancestor(exp.Predicate, exp.Select)
879            assert isinstance(
880                predicate, exp.Binary
881            ), "Columns can only be marked with (+) when involved in a binary operation"
882
883            predicate_parent = predicate.parent
884            join_predicate = predicate.pop()
885
886            left_columns = [
887                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
888            ]
889            right_columns = [
890                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
891            ]
892
893            assert not (
894                left_columns and right_columns
895            ), "The (+) marker cannot appear in both sides of a binary predicate"
896
897            marked_column_tables = set()
898            for col in left_columns or right_columns:
899                table = col.table
900                assert table, f"Column {col} needs to be qualified with a table"
901
902                col.set("join_mark", False)
903                marked_column_tables.add(table)
904
905            assert (
906                len(marked_column_tables) == 1
907            ), "Columns of only a single table can be marked with (+) in a given binary predicate"
908
909            join_this = old_joins.get(col.table, query_from).this
910            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")
911
912            # Upsert new_join into new_joins dictionary
913            new_join_alias_or_name = new_join.alias_or_name
914            existing_join = new_joins.get(new_join_alias_or_name)
915            if existing_join:
916                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
917            else:
918                new_joins[new_join_alias_or_name] = new_join
919
920            # If the parent of the target predicate is a binary node, then it now has only one child
921            if isinstance(predicate_parent, exp.Binary):
922                if predicate_parent.left is None:
923                    predicate_parent.replace(predicate_parent.right)
924                else:
925                    predicate_parent.replace(predicate_parent.left)
926
927        if query_from.alias_or_name in new_joins:
928            only_old_joins = old_joins.keys() - new_joins.keys()
929            assert (
930                len(only_old_joins) >= 1
931            ), "Cannot determine which table to use in the new FROM clause"
932
933            new_from_name = list(only_old_joins)[0]
934            query.set("from", exp.From(this=old_joins[new_from_name].this))
935
936        query.set("joins", list(new_joins.values()))
937
938        if not where.this:
939            where.pop()
940
941    return expression
942
943
944def any_to_exists(expression: exp.Expression) -> exp.Expression:
945    """
946    Transform ANY operator to Spark's EXISTS
947
948    For example,
949        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
950        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
951
952    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
953    transformation
954    """
955    if isinstance(expression, exp.Select):
956        for any in expression.find_all(exp.Any):
957            this = any.this
958            if isinstance(this, exp.Query):
959                continue
960
961            binop = any.parent
962            if isinstance(binop, exp.Binary):
963                lambda_arg = exp.to_identifier("x")
964                any.replace(lambda_arg)
965                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
966                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
967
968    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.
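
In practice, dialect generators register the callable returned by preprocess in their TRANSFORMS mapping, which is how the transforms in this module get applied during SQL generation. A minimal sketch of that pattern (MyDialect is an illustrative name, not a real sqlglot dialect):

    from sqlglot import exp, transforms
    from sqlglot.dialects.dialect import Dialect
    from sqlglot.generator import Generator as BaseGenerator

    class MyDialect(Dialect):
        class Generator(BaseGenerator):
            TRANSFORMS = {
                **BaseGenerator.TRANSFORMS,
                # Rewrite QUALIFY into a filtered subquery before SELECT statements are generated.
                exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
            }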

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Unnests GENERATE_SERIES or SEQUENCE table references.
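
For example, applied to a Postgres-style table function reference (a sketch: the rewrite only fires when GENERATE_SERIES is parsed as a table, i.e. an exp.Table wrapping exp.GenerateSeries; otherwise the expression is returned unchanged):

    import sqlglot
    from sqlglot.transforms import unnest_generate_series

    tree = sqlglot.parse_one("SELECT * FROM GENERATE_SERIES(1, 5) AS s", read="postgres")
    print(tree.transform(unnest_generate_series).sql())
    # Roughly: SELECT * FROM UNNEST(GENERATE_SERIES(1, 5)) AS _u(s)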

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.
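
A hedged usage sketch (the subquery alias _t and the _row_number alias come from the logic above):

    import sqlglot
    from sqlglot.transforms import eliminate_distinct_on

    sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, c DESC"
    tree = sqlglot.parse_one(sql, read="postgres")
    print(tree.transform(eliminate_distinct_on).sql())
    # The DISTINCT ON is replaced by a ROW_NUMBER() window in a subquery
    # (aliased _t) that the outer query filters with _row_number = 1.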

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
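
For instance (a sketch; the subquery alias _t is introduced by this transform):

    import sqlglot
    from sqlglot.transforms import eliminate_qualify

    sql = "SELECT a, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t QUALIFY rn = 1"
    tree = sqlglot.parse_one(sql, read="snowflake")
    print(tree.transform(eliminate_qualify).sql())
    # The QUALIFY filter becomes a WHERE clause on a subquery aliased _t,
    # so the window function is evaluated in the inner query.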

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
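
For example (a minimal sketch; VARCHAR(20) is just a representative parameterized type):

    import sqlglot
    from sqlglot.transforms import remove_precision_parameterized_types

    tree = sqlglot.parse_one("SELECT CAST(x AS VARCHAR(20)) FROM t")
    print(tree.transform(remove_precision_parameterized_types).sql())
    # The precision parameter is dropped, e.g. SELECT CAST(x AS VARCHAR) FROM t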

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
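
In practice this runs after the optimizer's qualify_columns step; for illustration, a hand-written query with a qualified unnest column (assuming Trino-style UNNEST aliases parse as shown):

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one(
    "SELECT t.col FROM UNNEST(ARRAY[1, 2, 3]) AS t(col)", read="presto"
)
print(transforms.unqualify_unnest(expression).sql(dialect="presto"))
# roughly: SELECT col FROM UNNEST(ARRAY[1, 2, 3]) AS t(col)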

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
320def unnest_to_explode(
321    expression: exp.Expression,
322    unnest_using_arrays_zip: bool = True,
323) -> exp.Expression:
324    """Convert cross join unnest into lateral view explode."""
325
326    def _unnest_zip_exprs(
327        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
328    ) -> t.List[exp.Expression]:
329        if has_multi_expr:
330            if not unnest_using_arrays_zip:
331                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")
332
333            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
334            zip_exprs: t.List[exp.Expression] = [
335                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
336            ]
337            u.set("expressions", zip_exprs)
338            return zip_exprs
339        return unnest_exprs
340
341    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
342        if u.args.get("offset"):
343            return exp.Posexplode
344        return exp.Inline if has_multi_expr else exp.Explode
345
346    if isinstance(expression, exp.Select):
347        from_ = expression.args.get("from")
348
349        if from_ and isinstance(from_.this, exp.Unnest):
350            unnest = from_.this
351            alias = unnest.args.get("alias")
352            exprs = unnest.expressions
353            has_multi_expr = len(exprs) > 1
354            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
355
356            unnest.replace(
357                exp.Table(
358                    this=_udtf_type(unnest, has_multi_expr)(
359                        this=this,
360                        expressions=expressions,
361                    ),
362                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
363                )
364            )
365
366        joins = expression.args.get("joins") or []
367        for join in list(joins):
368            join_expr = join.this
369
370            is_lateral = isinstance(join_expr, exp.Lateral)
371
372            unnest = join_expr.this if is_lateral else join_expr
373
374            if isinstance(unnest, exp.Unnest):
375                if is_lateral:
376                    alias = join_expr.args.get("alias")
377                else:
378                    alias = unnest.args.get("alias")
379                exprs = unnest.expressions
380                # The number of unnest.expressions will be changed by _unnest_zip_exprs, so we record it here
381                has_multi_expr = len(exprs) > 1
382                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
383
384                joins.remove(join)
385
386                alias_cols = alias.columns if alias else []
387
388                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there are 0 or > 2 aliases.
389                # Spark's LATERAL VIEW EXPLODE requires a single alias for array/struct columns and two for map columns, unlike UNNEST in Trino/Presto, which accepts an arbitrary number.
390                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
391
392                if not has_multi_expr and len(alias_cols) not in (1, 2):
393                    raise UnsupportedError(
394                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
395                    )
396
397                for e, column in zip(exprs, alias_cols):
398                    expression.append(
399                        "laterals",
400                        exp.Lateral(
401                            this=_udtf_type(unnest, has_multi_expr)(this=e),
402                            view=True,
403                            alias=exp.TableAlias(
404                                this=alias.this,  # type: ignore
405                                columns=alias_cols,
406                            ),
407                        ),
408                    )
409
410    return expression

Convert cross join unnest into lateral view explode.
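
A rough sketch of converting a Trino-style CROSS JOIN UNNEST into Spark's LATERAL VIEW EXPLODE (table and column names are made up; exact output formatting may differ):

import sqlglot
from sqlglot import transforms

sql = "SELECT col FROM tbl CROSS JOIN UNNEST(tbl.arr) AS t(col)"
expression = sqlglot.parse_one(sql, read="presto")
print(transforms.unnest_to_explode(expression).sql(dialect="spark"))
# roughly: SELECT col FROM tbl LATERAL VIEW EXPLODE(tbl.arr) t AS col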

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
413def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
414    """Convert explode/posexplode into unnest."""
415
416    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
417        if isinstance(expression, exp.Select):
418            from sqlglot.optimizer.scope import Scope
419
420            taken_select_names = set(expression.named_selects)
421            taken_source_names = {name for name, _ in Scope(expression).references}
422
423            def new_name(names: t.Set[str], name: str) -> str:
424                name = find_new_name(names, name)
425                names.add(name)
426                return name
427
428            arrays: t.List[exp.Condition] = []
429            series_alias = new_name(taken_select_names, "pos")
430            series = exp.alias_(
431                exp.Unnest(
432                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
433                ),
434                new_name(taken_source_names, "_u"),
435                table=[series_alias],
436            )
437
438            # we use list here because expression.selects is mutated inside the loop
439            for select in list(expression.selects):
440                explode = select.find(exp.Explode)
441
442                if explode:
443                    pos_alias = ""
444                    explode_alias = ""
445
446                    if isinstance(select, exp.Alias):
447                        explode_alias = select.args["alias"]
448                        alias = select
449                    elif isinstance(select, exp.Aliases):
450                        pos_alias = select.aliases[0]
451                        explode_alias = select.aliases[1]
452                        alias = select.replace(exp.alias_(select.this, "", copy=False))
453                    else:
454                        alias = select.replace(exp.alias_(select, ""))
455                        explode = alias.find(exp.Explode)
456                        assert explode
457
458                    is_posexplode = isinstance(explode, exp.Posexplode)
459                    explode_arg = explode.this
460
461                    if isinstance(explode, exp.ExplodeOuter):
462                        bracket = explode_arg[0]
463                        bracket.set("safe", True)
464                        bracket.set("offset", True)
465                        explode_arg = exp.func(
466                            "IF",
467                            exp.func(
468                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
469                            ).eq(0),
470                            exp.array(bracket, copy=False),
471                            explode_arg,
472                        )
473
474                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
475                    if isinstance(explode_arg, exp.Column):
476                        taken_select_names.add(explode_arg.output_name)
477
478                    unnest_source_alias = new_name(taken_source_names, "_u")
479
480                    if not explode_alias:
481                        explode_alias = new_name(taken_select_names, "col")
482
483                        if is_posexplode:
484                            pos_alias = new_name(taken_select_names, "pos")
485
486                    if not pos_alias:
487                        pos_alias = new_name(taken_select_names, "pos")
488
489                    alias.set("alias", exp.to_identifier(explode_alias))
490
491                    series_table_alias = series.args["alias"].this
492                    column = exp.If(
493                        this=exp.column(series_alias, table=series_table_alias).eq(
494                            exp.column(pos_alias, table=unnest_source_alias)
495                        ),
496                        true=exp.column(explode_alias, table=unnest_source_alias),
497                    )
498
499                    explode.replace(column)
500
501                    if is_posexplode:
502                        expressions = expression.expressions
503                        expressions.insert(
504                            expressions.index(alias) + 1,
505                            exp.If(
506                                this=exp.column(series_alias, table=series_table_alias).eq(
507                                    exp.column(pos_alias, table=unnest_source_alias)
508                                ),
509                                true=exp.column(pos_alias, table=unnest_source_alias),
510                            ).as_(pos_alias),
511                        )
512                        expression.set("expressions", expressions)
513
514                    if not arrays:
515                        if expression.args.get("from"):
516                            expression.join(series, copy=False, join_type="CROSS")
517                        else:
518                            expression.from_(series, copy=False)
519
520                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
521                    arrays.append(size)
522
523                    # trino doesn't support left join unnest with on conditions
524                    # if it did, this would be much simpler
525                    expression.join(
526                        exp.alias_(
527                            exp.Unnest(
528                                expressions=[explode_arg.copy()],
529                                offset=exp.to_identifier(pos_alias),
530                            ),
531                            unnest_source_alias,
532                            table=[explode_alias],
533                        ),
534                        join_type="CROSS",
535                        copy=False,
536                    )
537
538                    if index_offset != 1:
539                        size = size - 1
540
541                    expression.where(
542                        exp.column(series_alias, table=series_table_alias)
543                        .eq(exp.column(pos_alias, table=unnest_source_alias))
544                        .or_(
545                            (exp.column(series_alias, table=series_table_alias) > size).and_(
546                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
547                            )
548                        ),
549                        copy=False,
550                    )
551
552            if arrays:
553                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])
554
555                if index_offset != 1:
556                    end = end - (1 - index_offset)
557                series.expressions[0].set("end", end)
558
559        return expression
560
561    return _explode_to_unnest

Convert explode/posexplode into unnest.
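
Unlike most transforms in this module, explode_to_unnest is a factory: it returns the actual transform, parameterized by the target dialect's array index offset. A sketch for a Spark-to-Presto style rewrite (names are made up; the positional aliases and the final WHERE condition are generated by the transform itself):

import sqlglot
from sqlglot import transforms

rewrite = transforms.explode_to_unnest(index_offset=1)  # Presto arrays are 1-indexed

expression = sqlglot.parse_one("SELECT EXPLODE(arr) FROM tbl", read="spark")
print(rewrite(expression).sql(dialect="presto", pretty=True))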

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
564def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
565    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
566    if (
567        isinstance(expression, exp.PERCENTILES)
568        and not isinstance(expression.parent, exp.WithinGroup)
569        and expression.expression
570    ):
571        column = expression.this.pop()
572        expression.set("this", expression.expression.pop())
573        order = exp.Order(expressions=[exp.Ordered(this=column)])
574        expression = exp.WithinGroup(this=expression, expression=order)
575
576    return expression

Transforms percentiles by adding a WITHIN GROUP clause to them.
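
A sketch that builds the percentile node directly, since which dialects parse two-argument percentile calls into this shape varies; constructing the AST avoids that assumption:

from sqlglot import exp, transforms

# A percentile call with the measured column in `this` and the quantile in `expression`,
# which is the shape this transform expects.
percentile = exp.PercentileCont(this=exp.column("x"), expression=exp.Literal.number(0.5))
select = exp.select(percentile).from_("t")

print(select.transform(transforms.add_within_group_for_percentiles).sql())
# roughly: SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t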

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
579def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
580    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
581    if (
582        isinstance(expression, exp.WithinGroup)
583        and isinstance(expression.this, exp.PERCENTILES)
584        and isinstance(expression.expression, exp.Order)
585    ):
586        quantile = expression.this.this
587        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
588        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
589
590    return expression

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
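
The inverse direction, sketched for a target dialect that expresses approximate percentiles as a two-argument function (the exact output may differ by dialect):

import sqlglot
from sqlglot import transforms

sql = "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t"
expression = sqlglot.parse_one(sql)
print(expression.transform(transforms.remove_within_group_for_percentiles).sql(dialect="duckdb"))
# roughly: SELECT APPROX_QUANTILE(x, 0.5) FROM t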

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
593def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
594    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
595    if isinstance(expression, exp.With) and expression.recursive:
596        next_name = name_sequence("_c_")
597
598        for cte in expression.expressions:
599            if not cte.args["alias"].columns:
600                query = cte.this
601                if isinstance(query, exp.SetOperation):
602                    query = query.this
603
604                cte.args["alias"].set(
605                    "columns",
606                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
607                )
608
609    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.
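
A sketch using a small recursive CTE whose column list is inferred from the anchor member's projection names:

import sqlglot
from sqlglot import transforms

sql = """
WITH RECURSIVE t AS (
  SELECT 1 AS n
  UNION ALL
  SELECT n + 1 FROM t WHERE n < 3
)
SELECT * FROM t
"""

expression = sqlglot.parse_one(sql)
print(expression.transform(transforms.add_recursive_cte_column_names).sql())
# roughly: WITH RECURSIVE t(n) AS (...) SELECT * FROM t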

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
612def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
613    """Replace 'epoch' in casts by the equivalent date literal."""
614    if (
615        isinstance(expression, (exp.Cast, exp.TryCast))
616        and expression.name.lower() == "epoch"
617        and expression.to.this in exp.DataType.TEMPORAL_TYPES
618    ):
619        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
620
621    return expression

Replace 'epoch' in casts by the equivalent date literal.
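
A minimal sketch; the replacement literal is the Unix epoch timestamp:

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one("SELECT CAST('epoch' AS TIMESTAMP) AS ts")
print(expression.transform(transforms.epoch_cast_to_ts).sql())
# roughly: SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS ts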

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
624def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
625    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
626    if isinstance(expression, exp.Select):
627        for join in expression.args.get("joins") or []:
628            on = join.args.get("on")
629            if on and join.kind in ("SEMI", "ANTI"):
630                subquery = exp.select("1").from_(join.this).where(on)
631                exists = exp.Exists(this=subquery)
632                if join.kind == "ANTI":
633                    exists = exists.not_(copy=False)
634
635                join.pop()
636                expression.where(exists, copy=False)
637
638    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
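
A sketch that rewrites a Spark-style LEFT SEMI JOIN into an EXISTS filter (table names are made up):

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one("SELECT * FROM a LEFT SEMI JOIN b ON a.id = b.id", read="spark")
print(transforms.eliminate_semi_and_anti_joins(expression).sql(dialect="duckdb"))
# roughly: SELECT * FROM a WHERE EXISTS(SELECT 1 FROM b WHERE a.id = b.id)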

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
641def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
642    """
643    Converts a query with a FULL OUTER join to a union of identical queries that
644    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
645    for queries that have a single FULL OUTER join.
646    """
647    if isinstance(expression, exp.Select):
648        full_outer_joins = [
649            (index, join)
650            for index, join in enumerate(expression.args.get("joins") or [])
651            if join.side == "FULL"
652        ]
653
654        if len(full_outer_joins) == 1:
655            expression_copy = expression.copy()
656            expression.set("limit", None)
657            index, full_outer_join = full_outer_joins[0]
658
659            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
660            join_conditions = full_outer_join.args.get("on") or exp.and_(
661                *[
662                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
663                    for col in full_outer_join.args.get("using")
664                ]
665            )
666
667            full_outer_join.set("side", "left")
668            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
669            expression_copy.args["joins"][index].set("side", "right")
670            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
671            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
672            expression.args.pop("order", None)  # remove order by from LEFT side
673
674            return exp.union(expression, expression_copy, copy=False, distinct=False)
675
676    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
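
A sketch of the rewrite for a single FULL OUTER JOIN; the result is a LEFT JOIN branch unioned with a RIGHT JOIN branch that excludes already-matched rows:

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one("SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id")
print(transforms.eliminate_full_outer_join(expression).sql(pretty=True))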

def move_ctes_to_top_level(expression: ~E) -> ~E:
679def move_ctes_to_top_level(expression: E) -> E:
680    """
681    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
682    defined at the top-level, so for example queries like:
683
684        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
685
686    are invalid in those dialects. This transformation can be used to ensure all CTEs are
687    moved to the top level so that the final SQL code is valid from a syntax standpoint.
688
689    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
690    """
691    top_level_with = expression.args.get("with")
692    for inner_with in expression.find_all(exp.With):
693        if inner_with.parent is expression:
694            continue
695
696        if not top_level_with:
697            top_level_with = inner_with.pop()
698            expression.set("with", top_level_with)
699        else:
700            if inner_with.recursive:
701                top_level_with.set("recursive", True)
702
703            parent_cte = inner_with.find_ancestor(exp.CTE)
704            inner_with.pop()
705
706            if parent_cte:
707                i = top_level_with.expressions.index(parent_cte)
708                top_level_with.expressions[i:i] = inner_with.expressions
709                top_level_with.set("expressions", top_level_with.expressions)
710            else:
711                top_level_with.set(
712                    "expressions", top_level_with.expressions + inner_with.expressions
713                )
714
715    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
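
A sketch using the docstring's own example query:

import sqlglot
from sqlglot import transforms

sql = "SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq"
expression = sqlglot.parse_one(sql)
print(transforms.move_ctes_to_top_level(expression).sql())
# roughly: WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT * FROM t) AS subq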

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
718def ensure_bools(expression: exp.Expression) -> exp.Expression:
719    """Converts numeric values used in conditions into explicit boolean expressions."""
720    from sqlglot.optimizer.canonicalize import ensure_bools
721
722    def _ensure_bool(node: exp.Expression) -> None:
723        if (
724            node.is_number
725            or (
726                not isinstance(node, exp.SubqueryPredicate)
727                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
728            )
729            or (isinstance(node, exp.Column) and not node.type)
730        ):
731            node.replace(node.neq(0))
732
733    for node in expression.walk():
734        ensure_bools(node, _ensure_bool)
735
736    return expression

Converts numeric values used in conditions into explicit boolean expressions.
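
A sketch with a numeric condition; the transform rewrites it as an explicit comparison against zero:

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one("SELECT * FROM t WHERE 1")
print(transforms.ensure_bools(expression).sql())
# roughly: SELECT * FROM t WHERE 1 <> 0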

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
739def unqualify_columns(expression: exp.Expression) -> exp.Expression:
740    for column in expression.find_all(exp.Column):
741        # We only want to pop off the table, db and catalog args
742        for part in column.parts[:-1]:
743            part.pop()
744
745    return expression

def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
748def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
749    assert isinstance(expression, exp.Create)
750    for constraint in expression.find_all(exp.UniqueColumnConstraint):
751        if constraint.parent:
752            constraint.parent.pop()
753
754    return expression

def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
757def ctas_with_tmp_tables_to_create_tmp_view(
758    expression: exp.Expression,
759    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
760) -> exp.Expression:
761    assert isinstance(expression, exp.Create)
762    properties = expression.args.get("properties")
763    temporary = any(
764        isinstance(prop, exp.TemporaryProperty)
765        for prop in (properties.expressions if properties else [])
766    )
767
768    # CTAS with temp tables map to CREATE TEMPORARY VIEW
769    if expression.kind == "TABLE" and temporary:
770        if expression.expression:
771            return exp.Create(
772                kind="TEMPORARY VIEW",
773                this=expression.this,
774                expression=expression.expression,
775            )
776        return tmp_storage_provider(expression)
777
778    return expression

def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
781def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
782    """
783    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
784    PARTITIONED BY value is an array of column names, they are transformed into a schema.
785    The corresponding columns are removed from the create statement.
786    """
787    assert isinstance(expression, exp.Create)
788    has_schema = isinstance(expression.this, exp.Schema)
789    is_partitionable = expression.kind in {"TABLE", "VIEW"}
790
791    if has_schema and is_partitionable:
792        prop = expression.find(exp.PartitionedByProperty)
793        if prop and prop.this and not isinstance(prop.this, exp.Schema):
794            schema = expression.this
795            columns = {v.name.upper() for v in prop.this.expressions}
796            partitions = [col for col in schema.expressions if col.name.upper() in columns]
797            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
798            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
799            expression.set("this", schema)
800
801    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
804def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
805    """
806    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
807
808    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
809    """
810    assert isinstance(expression, exp.Create)
811    prop = expression.find(exp.PartitionedByProperty)
812    if (
813        prop
814        and prop.this
815        and isinstance(prop.this, exp.Schema)
816        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
817    ):
818        prop_this = exp.Tuple(
819            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
820        )
821        schema = expression.this
822        for e in prop.this.expressions:
823            schema.append("expressions", e)
824        prop.set("this", prop_this)
825
826    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.
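
A sketch going from the Hive layout (partition columns declared with types outside the schema) to the Spark 3 DATASOURCE layout; column names and types are made up, and exact property placement depends on the generator:

import sqlglot
from sqlglot import transforms

sql = "CREATE TABLE t (a INT) PARTITIONED BY (dt STRING)"
expression = sqlglot.parse_one(sql, read="hive")
print(transforms.move_partitioned_by_to_schema_columns(expression).sql(dialect="spark"))
# roughly: CREATE TABLE t (a INT, dt STRING) PARTITIONED BY (dt)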

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
829def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
830    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
831    if isinstance(expression, exp.Struct):
832        expression.set(
833            "expressions",
834            [
835                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
836                for e in expression.expressions
837            ],
838        )
839
840    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
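
A sketch that builds a key-value struct node directly and renders it with aliases; constructing the AST sidesteps assumptions about which dialects parse into PropertyEQ:

from sqlglot import exp, transforms

struct = exp.Struct(
    expressions=[exp.PropertyEQ(this=exp.to_identifier("y"), expression=exp.Literal.number(1))]
)
print(transforms.struct_kv_to_alias(struct).sql(dialect="bigquery"))
# roughly: STRUCT(1 AS y)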

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
843def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
844    """
845    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
846    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
847
848    For example,
849        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
850        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
851
852    Args:
853        expression: The AST to remove join marks from.
854
855    Returns:
856       The AST with join marks removed.
857    """
858    from sqlglot.optimizer.scope import traverse_scope
859
860    for scope in traverse_scope(expression):
861        query = scope.expression
862
863        where = query.args.get("where")
864        joins = query.args.get("joins")
865
866        if not where or not joins:
867            continue
868
869        query_from = query.args["from"]
870
871        # These keep track of the joins to be replaced
872        new_joins: t.Dict[str, exp.Join] = {}
873        old_joins = {join.alias_or_name: join for join in joins}
874
875        for column in scope.columns:
876            if not column.args.get("join_mark"):
877                continue
878
879            predicate = column.find_ancestor(exp.Predicate, exp.Select)
880            assert isinstance(
881                predicate, exp.Binary
882            ), "Columns can only be marked with (+) when involved in a binary operation"
883
884            predicate_parent = predicate.parent
885            join_predicate = predicate.pop()
886
887            left_columns = [
888                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
889            ]
890            right_columns = [
891                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
892            ]
893
894            assert not (
895                left_columns and right_columns
896            ), "The (+) marker cannot appear in both sides of a binary predicate"
897
898            marked_column_tables = set()
899            for col in left_columns or right_columns:
900                table = col.table
901                assert table, f"Column {col} needs to be qualified with a table"
902
903                col.set("join_mark", False)
904                marked_column_tables.add(table)
905
906            assert (
907                len(marked_column_tables) == 1
908            ), "Columns of only a single table can be marked with (+) in a given binary predicate"
909
910            join_this = old_joins.get(col.table, query_from).this
911            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")
912
913            # Upsert new_join into new_joins dictionary
914            new_join_alias_or_name = new_join.alias_or_name
915            existing_join = new_joins.get(new_join_alias_or_name)
916            if existing_join:
917                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
918            else:
919                new_joins[new_join_alias_or_name] = new_join
920
921            # If the parent of the target predicate is a binary node, then it now has only one child
922            if isinstance(predicate_parent, exp.Binary):
923                if predicate_parent.left is None:
924                    predicate_parent.replace(predicate_parent.right)
925                else:
926                    predicate_parent.replace(predicate_parent.left)
927
928        if query_from.alias_or_name in new_joins:
929            only_old_joins = old_joins.keys() - new_joins.keys()
930            assert (
931                len(only_old_joins) >= 1
932            ), "Cannot determine which table to use in the new FROM clause"
933
934            new_from_name = list(only_old_joins)[0]
935            query.set("from", exp.From(this=old_joins[new_from_name].this))
936
937        query.set("joins", list(new_joins.values()))
938
939        if not where.this:
940            where.pop()
941
942    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example:

SELECT * FROM a, b WHERE a.id = b.id(+)     -- is converted to
SELECT * FROM a LEFT JOIN b ON a.id = b.id

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.
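
A sketch using the docstring's Oracle example:

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
print(transforms.eliminate_join_marks(expression).sql(dialect="oracle"))
# roughly: SELECT * FROM a LEFT JOIN b ON a.id = b.id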

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
945def any_to_exists(expression: exp.Expression) -> exp.Expression:
946    """
947    Transform ANY operator to Spark's EXISTS
948
949    For example,
950        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
951        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
952
953    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
954    transformation
955    """
956    if isinstance(expression, exp.Select):
957        for any in expression.find_all(exp.Any):
958            this = any.this
959            if isinstance(this, exp.Query):
960                continue
961
962            binop = any.parent
963            if isinstance(binop, exp.Binary):
964                lambda_arg = exp.to_identifier("x")
965                any.replace(lambda_arg)
966                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
967                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
968
969    return expression

Transform the ANY operator to Spark's EXISTS.

For example:

  • Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
  • Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries, but currently only array expressions are supported for this transformation.
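
A sketch following the docstring's Postgres-to-Spark example (the generated lambda body may appear with the comparison flipped, e.g. 5 > x):

import sqlglot
from sqlglot import transforms

expression = sqlglot.parse_one("SELECT * FROM tbl WHERE 5 > ANY(tbl.col)", read="postgres")
print(transforms.any_to_exists(expression).sql(dialect="spark"))
# roughly: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)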