Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot._typing import E
 12    from sqlglot.generator import Generator
 13
 14
 15def preprocess(
 16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 17) -> t.Callable[[Generator, exp.Expression], str]:
 18    """
 19    Creates a new transform by chaining a sequence of transformations and converts the resulting
 20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 22
 23    Args:
 24        transforms: sequence of transform functions. These will be called in order.
 25
 26    Returns:
 27        Function that can be used as a generator transform.
 28    """
 29
 30    def _to_sql(self, expression: exp.Expression) -> str:
 31        expression_type = type(expression)
 32
 33        try:
 34            expression = transforms[0](expression)
 35            for transform in transforms[1:]:
 36                expression = transform(expression)
 37        except UnsupportedError as unsupported_error:
 38            self.unsupported(str(unsupported_error))
 39
 40        _sql_handler = getattr(self, expression.key + "_sql", None)
 41        if _sql_handler:
 42            return _sql_handler(expression)
 43
 44        transforms_handler = self.TRANSFORMS.get(type(expression))
 45        if transforms_handler:
 46            if expression_type is type(expression):
 47                if isinstance(expression, exp.Func):
 48                    return self.function_fallback_sql(expression)
 49
 50                # Ensures we don't enter an infinite loop. This can happen when the original expression
 51                # has the same type as the final expression and there's no _sql method available for it,
 52                # because then it'd re-enter _to_sql.
 53                raise ValueError(
 54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 55                )
 56
 57            return transforms_handler(self, expression)
 58
 59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 60
 61    return _to_sql
 62
 63
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Replace `UNNEST(GENERATE_DATE_ARRAY(...))` sources of a SELECT with recursive CTEs.

    Only UNNEST nodes that sit directly under a FROM or JOIN, wrap exactly one
    `exp.GenerateDateArray`, and whose start/end are set and whose step is an `exp.Interval`
    are rewritten; everything else is left untouched.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression, with matching UNNESTs replaced by subqueries over generated CTEs.
    """
    if isinstance(expression, exp.Select):
        # Number of CTEs generated so far; used to keep the generated CTE names distinct.
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            # NOTE(review): this indexes alias.columns[0] without checking that the table alias
            # actually declares columns -- confirm callers always provide one.
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # First CTE is "_generated_dates", later ones get a "_<n>" suffix.
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member: the start date. Recursive member: previous date plus one step,
            # kept only while it does not exceed the end date.
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            # Swap the UNNEST for a subquery that reads the column from the CTE; the subquery
            # is aliased with the CTE name.
            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Reuse an existing WITH clause if there is one, mark it recursive and put the
            # generated CTEs ahead of the pre-existing ones.
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
119
120
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a GENERATE_SERIES / SEQUENCE table reference in an UNNEST.

    If the table reference is aliased, the UNNEST is aliased as "_u" and the original alias
    becomes its single column name. Any other expression is returned unchanged.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)
132
133
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap select-alias references in a GROUP BY clause for their 1-based projection position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map every aliased projection to its ordinal position in the SELECT list.
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Only bare (table-less) column references can be alias references.
        if not isinstance(term, exp.Column) or term.table:
            continue
        if term.name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias.get(term.name)))

    return expression
165
166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a filtered subquery that uses ROW_NUMBER().

    Intended for dialects that lack SELECT DISTINCT ON but do support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT node from the query; its ON columns become the window partition.
    distinct.pop()
    partition_cols = on.expressions

    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The query's ORDER BY moves into the window; without one, order by the ON columns.
    query_order = expression.args.get("order")
    window_order = (
        query_order.pop()
        if query_order
        else exp.Order(expressions=[col.copy() for col in partition_cols])
    )
    row_number.set("order", window_order)

    expression.select(exp.alias_(row_number, row_number_alias), copy=False)

    inner = expression.subquery("_t", copy=False)
    outer = exp.select(*projections, copy=False).from_(inner, copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
206
207
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        A new outer SELECT over the original query (aliased "_t") when a QUALIFY clause was
        present, otherwise the input expression unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c*" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer projection for an inner one; when the name comes from an
            # Identifier, carry its "quoted" flag over to the outer column.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are captured before any helper columns are added below.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection, only window functions are pulled into the subquery.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline projection aliases referenced inside the window.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh "_w*" alias and filter on that alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the entire QUALIFY predicate.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used only in QUALIFY must be projected by the subquery too.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
270
271
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip type parameters (e.g. precision) from every data type found in the expression.

    Some dialects only accept parameterized types in DDL, so elsewhere the `DataTypeParam`
    arguments are dropped while any other type arguments are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression
283
284
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Drop column qualifiers that point at UNNEST table aliases.

    Such qualifiers are added by the optimizer's qualify_columns step. For each column in a
    SELECT, the table part is cleared if it names an UNNEST alias; otherwise the db part is
    cleared if it does.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of UNNESTs used directly as FROM / JOIN sources in this scope.
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        # At most one qualifier is cleared per column, table taking precedence over db.
        for part in ("table", "db"):
            if getattr(column, part) in unnest_aliases:
                column.set(part, None)
                break

    return expression
303
304
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression that will be transformed (mutated in place).
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be rewritten
            using `ARRAYS_ZIP`. When False, such an UNNEST raises `UnsupportedError`.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: if an UNNEST has more than one input array and
            `unnest_using_arrays_zip` is False.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An offset takes precedence; otherwise zipped arrays use INLINE and single ones EXPLODE.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot of the joins: converted joins are removed from the live list
        # below, and removing from a list while iterating it would skip the following join.
        for join in list(expression.args.get("joins") or []):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                expression.args["joins"].remove(join)

                alias_cols = alias.columns if alias else []
                # One LATERAL VIEW per (expression, alias column) pair; zip stops at the shorter
                # sequence, so an unaliased UNNEST produces no lateral at all.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression
385
386
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first value of the generated position series; it also decides
            whether array sizes are decremented by one when compared against positions.

    Returns:
        A transform that rewrites EXPLODE-style projections of a SELECT into cross joins with
        UNNEST sources, mutating the SELECT in place. Other expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name not yet in `names` and reserve it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded array; used to bound the series below.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its "end" is filled in once all arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection to a single Alias node, remembering any
                    # user-provided value / position aliases.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For the OUTER variant, substitute a one-element array when the input
                        # array has size 0 (NULL is coalesced to an empty array first).
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only emitted on rows where the series position
                    # matches the UNNEST offset (the IF has no false branch).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode found: attach the shared series source exactly once.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions line up, plus -- for series positions past
                    # this array's end -- the row at the array's last position.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run up to the size of the largest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
536
537
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile function in a WITHIN GROUP (ORDER BY ...) clause.

    The percentile's first argument becomes the ORDER BY key and its second argument becomes
    the function's only argument. Percentiles already under a WITHIN GROUP, or without a
    second argument, are returned unchanged.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
551
552
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with `ApproxQuantile(x, q)`.

    Only applies when the WITHIN GROUP wraps a percentile function and its clause is an
    ORDER BY; anything else is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    # The first Ordered node found under the WITHIN GROUP carries the input value.
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(replacement)
565
566
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list when it has none.

    Column names are taken from the projections' output names; projections without one get a
    generated "_c_"-prefixed name. For set operations, the left-hand query's projections are used.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        cte_alias.set("columns", column_names)

    return expression
584
585
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite the operand of CAST('epoch' AS <temporal type>) to the literal
    '1970-01-01 00:00:00'. Applies to both CAST and TRY_CAST; the match on 'epoch' is
    case-insensitive.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
596
597
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced
    by a `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on>)` predicate added to the
    query's WHERE clause. Joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: converted joins are popped out of the query's join list
        # inside the loop, and mutating a list while iterating it can skip the next element.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
613
614
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        A UNION ALL of the LEFT-join and RIGHT-join variants when exactly one FULL join is
        present, otherwise the input expression unchanged.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # The copy becomes the RIGHT-join half; it is taken before the LIMIT is cleared on
            # the original, so only the copy keeps the LIMIT.
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            # Without an ON condition, build equality conditions from the USING columns.
            # NOTE(review): a FULL join with neither ON nor USING would fail here when
            # iterating None -- confirm such joins cannot reach this transform.
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # The RIGHT-join half only keeps rows for which no match exists in the FROM source,
            # i.e. the rows the LEFT-join half cannot produce.
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
651
652
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression, with nested WITH clauses merged into its top-level WITH.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH clause that already belongs to the top-level expression.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the first nested one as-is.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Looked up before popping, while the nested WITH is still attached to the tree.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # CTEs nested inside another CTE are inserted right before that CTE.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Otherwise they are appended after the existing top-level CTEs.
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
690
691
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used as conditions into explicit `<value> <> 0` comparisons.

    Every node of the tree is handed to the optimizer's canonicalize `ensure_bools` helper
    together with a callback that performs the rewrite.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Rewrite number literals, non-predicate nodes typed as numeric/unknown, and
        # columns that carry no type at all.
        needs_comparison = (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        )
        if needs_comparison:
            node.replace(node.neq(0))

    for candidate in expression.walk():
        ensure_bools(candidate, _ensure_bool)

    return expression
711
712
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Remove every qualifier from every column in the expression, keeping only its last part."""
    for column in expression.find_all(exp.Column):
        # All parts but the last one are qualifiers (table, db, catalog).
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
720
721
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove the parent node of every UNIQUE column constraint found in a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
729
730
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE ... AS <query> onto CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE statement to transform.
        tmp_storage_provider: called with the statement when it creates a temporary table
            without a query; its result is returned. Defaults to the identity function.

    Returns:
        A new TEMPORARY VIEW statement for temporary CTAS, the provider's result for other
        temporary tables, and the input unchanged otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
753
754
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move partition columns out of a CREATE statement's schema and into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When its
    value is a list of column names rather than a schema, the matching column definitions
    (compared case-insensitively by name) are removed from the schema and become the
    property's own schema.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema):
        return expression
    if expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this) or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {name.name.upper() for name in prop.this.expressions}
    partition_defs = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_defs = [column for column in schema.expressions if column not in partition_defs]

    schema.set("expressions", remaining_defs)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_defs)))
    expression.set("this", schema)

    return expression
776
777
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of column definitions that have a
    type, those definitions are appended to the CREATE statement's schema and the property is
    reduced to a tuple of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (partitioned_by and partitioned_by.this):
        return expression
    if not isinstance(partitioned_by.this, exp.Schema):
        return expression

    partition_defs = partitioned_by.this.expressions
    if not all(isinstance(col_def, exp.ColumnDef) and col_def.kind for col_def in partition_defs):
        return expression

    # Build the name-only tuple first, then hand the definitions over to the table schema
    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(col_def.this) for col_def in partition_defs]
    )
    table_schema = expression.this
    for col_def in partitioned_by.this.expressions:
        table_schema.append("expressions", col_def)
    partitioned_by.set("this", partition_names)

    return expression
801
802
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Rewrite key-value struct arguments as aliased values, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    aliased_args = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            # The value becomes the aliased expression and the key becomes its alias
            arg = exp.alias_(arg.expression, arg.this)
        aliased_args.append(arg)

    expression.set("expressions", aliased_args)
    return expression
815
816
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Only scopes that have both a WHERE clause and at least one join are rewritten. Each
    predicate that contains a marked column is moved out of the WHERE clause and into the ON
    condition of a LEFT join against the marked column's table.

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, is unqualified,
            if both sides of a predicate are marked, if a predicate marks columns of more than
            one table, or if no table is left to serve as the new FROM source.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Nothing to rewrite unless the scope has both a filter and joined sources
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Nearest enclosing predicate; the search stops at the SELECT boundary
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent before detaching, since pop() severs the link to it
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables that the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column of the loop above; all marked columns share one
            # table at this point. If that table is not among the joins, the FROM source is used.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                # Several predicates for the same table are AND-ed into a single ON condition
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM source itself became a LEFT join target, so another table must take its place
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): this replaces the whole join list, so original joins that did not
        # receive a marked predicate are no longer present afterwards -- confirm this is intended
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause entirely if every predicate was moved into a join
        if not where.this:
            where.pop()

    return expression
917
918
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform the ANY operator into Spark's EXISTS higher-order function.

    For every binary operation in a SELECT that has ANY(<array>) as a direct operand, the
    operation is replaced by EXISTS(<array>, x -> <operation>), where the ANY operand has been
    swapped for the lambda parameter `x`. For example, `5 > ANY(tbl.col)` becomes an EXISTS
    call over `tbl.col` whose lambda body is `5 > x`.

    Both ANY and EXISTS accept queries, but currently only array expressions are supported by
    this transformation; ANY over a query is left untouched.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_expr in expression.find_all(exp.Any):
        array = any_expr.this
        if isinstance(array, exp.Query):
            continue

        comparison = any_expr.parent
        if not isinstance(comparison, exp.Binary):
            continue

        # Swap ANY(...) for the lambda parameter, then wrap a copy of the comparison in a lambda
        param = exp.to_identifier("x")
        any_expr.replace(param)
        lambda_expr = exp.Lambda(this=comparison.copy(), expressions=[param])
        comparison.replace(exp.Exists(this=array.unnest(), expression=lambda_expr))

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has neither
            a "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remembered so that a transform chain which preserves the type can be detected below
        expression_type = type(expression)

        try:
            # Iterating directly (instead of indexing transforms[0]) makes an empty chain a
            # no-op rather than an IndexError
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Report the problem through the generator and continue with the expression as it
            # was before the failing transform
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap GENERATE_SERIES or SEQUENCE table references in an UNNEST.

    When the table reference has an alias, the UNNEST is aliased as `_u` and the original
    alias is reused as its single column name.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with their 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Maps each projection alias to its position in the SELECT list
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified column references can be pointing at a projection alias
        if not isinstance(term, exp.Column) or term.table:
            continue
        if term.name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias.get(term.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    The original query gains a ROW_NUMBER() projection partitioned by the DISTINCT ON columns
    and is wrapped in a subquery named `_t`; the outer query keeps only rows numbered 1.
    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    partition_cols = distinct.pop().args["on"].expressions
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    numbering = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The query's own ORDER BY moves into the window; without one, order by the DISTINCT ON columns
    order = expression.args.get("order")
    window_order = (
        order.pop() if order else exp.Order(expressions=[c.copy() for c in partition_cols])
    )
    numbering.set("order", window_order)

    expression.select(exp.alias_(numbering, row_number_alias), copy=False)

    outer_query = exp.select(*projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery `_t`) when a QUALIFY clause is
        present, otherwise the original expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Builds the outer projection for an inner one; a column node is returned when an
            # identifier is available so that its "quoted" flag is carried over
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are captured before any helper projection is added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are looked for; plain columns are not added
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Expand references to projection aliases inside the window
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query and refer to it by alias in the filter
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the entire QUALIFY condition
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used in the filter but not selected must be projected as well
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the parameters (such as precision) from every data type found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the `DataTypeParam` arguments from all data
    types in the given expression, mutating it in place.
    """
    for data_type in expression.find_all(exp.DataType):
        non_param_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", non_param_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    For each column in a SELECT, the `table` qualifier is cleared when it names an UNNEST that
    appears in the FROM clause or a join of that scope; otherwise the `db` qualifier is cleared
    when it does.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        # At most one qualifier is cleared per column, with `table` taking precedence
        for qualifier in ("table", "db"):
            if getattr(column, qualifier) in unnest_aliases:
                column.set(qualifier, None)
                break

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table wrapping the matching UDTF, and each
    joined UNNEST (optionally wrapped in LATERAL) is removed from the joins and re-added as
    LATERAL VIEW entries. The UDTF is POSEXPLODE when the UNNEST has an offset, INLINE when it
    has several input arrays, and EXPLODE otherwise.

    Args:
        expression: the expression to transform (mutated in place).
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be rewritten
            using INLINE(ARRAYS_ZIP(...)). It also decides whether each generated lateral view
            receives all alias columns (True) or only the column paired with its array (False).

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: if an UNNEST has several input arrays and `unnest_using_arrays_zip`
            is False.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # mutating a list while iterating over it would skip the join following each removal
        for join in list(expression.args.get("joins") or []):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL UNNEST the alias is read from the LATERAL node
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                expression.args["joins"].remove(join)

                # NOTE(review): without alias columns the zip below yields nothing, so the join
                # is removed and no lateral view is added -- confirm this is intended
                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series; it also adjusts the
            array-size bounds used below (presumably the target dialect's array index base
            -- TODO confirm).

    Returns:
        A transform that rewrites the explode projections of a SELECT into cross-joined
        UNNESTs aligned through a shared position series. Other expressions are returned
        unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't collide with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a free name based on `name` and reserves it in `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Array-size expressions of every exploded argument, used for the series' end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its "end" is only filled in once all sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is a single-alias node, remembering any
                    # user-provided aliases
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # The projection was copied while being aliased, so look the explode up again
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # When the array is empty or NULL, explode a one-element array built
                        # from a safe subscript instead of the original argument
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whatever aliases the user did not provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only emitted on the row where the series position
                    # matches the unnest's own position
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position projection right after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is attached once, when the first explode is found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Unless positions start at 1, the last position is size - 1
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, plus rows past this array's end
                    # paired with its last element
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must reach the end of the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a two-argument percentile call into its WITHIN GROUP form.

    Only nodes of the `exp.PERCENTILES` types that are not already wrapped in a
    `exp.WithinGroup` and that carry a second argument are rewritten. The first
    argument becomes the ORDER BY key and the second argument becomes the sole
    argument of the percentile function.

    Args:
        expression: The node to transform.

    Returns:
        A new `exp.WithinGroup` wrapping the (mutated) percentile node, or the
        input node unchanged when it does not qualify.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # Detach both arguments, then promote the second one to be the only argument
    order_key = expression.this.pop()
    remaining_arg = expression.expression.pop()
    expression.set("this", remaining_arg)

    ordering = exp.Order(expressions=[exp.Ordered(this=order_key)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replaces `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an `exp.ApproxQuantile` node.

    The rewrite applies only to a `exp.WithinGroup` whose `this` is one of the
    `exp.PERCENTILES` types and whose `expression` is an `exp.Order`.

    Args:
        expression: The node to transform.

    Returns:
        The `exp.ApproxQuantile` that replaced the node in its parent, or the input
        node unchanged when it does not qualify.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    percentile = expression.this
    # The first ORDER BY term supplies the value the quantile is computed over
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=first_ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Gives every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already lists columns are left alone. For the others, the
    column names are taken from the projections of the CTE's query (the left-most
    operand when the query is a set operation); projections without an output name
    receive a generated `_c_`-prefixed name.

    Args:
        expression: The node to transform.

    Returns:
        The same node, mutated in place when it is a recursive `exp.With`.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = []
        for projection in body.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Substitutes the literal '1970-01-01 00:00:00' for 'epoch' inside temporal casts.

    Applies to `exp.Cast` / `exp.TryCast` nodes whose operand name is "epoch"
    (case-insensitively) and whose target type is one of
    `exp.DataType.TEMPORAL_TYPES`.

    Args:
        expression: The node to transform.

    Returns:
        The same node, with its operand replaced in place when it qualifies.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI / ANTI join that has an ON condition is removed from the SELECT and
    replaced by an `EXISTS (SELECT 1 FROM <joined table> WHERE <on>)` predicate
    (negated for ANTI joins) that is AND-ed into the WHERE clause.

    Args:
        expression: The node to transform.

    Returns:
        The same node, mutated in place when it is an `exp.Select`.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` removes the join from the SELECT's
        # own "joins" list, and looping over that list directly would skip the
        # join that follows each removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a SELECT with exactly one FULL OUTER join as a UNION ALL of two queries.

    The left operand is the original query with the join turned into a LEFT join
    and with its LIMIT and ORDER BY removed. The right operand is a copy of the
    original query with the join turned into a RIGHT join, its CTEs removed and an
    extra `NOT EXISTS (SELECT 1 FROM <from> WHERE <join condition>)` filter. The
    join condition is the join's ON clause or, failing that, column equalities
    built from its USING list.

    Queries with zero or several FULL OUTER joins are returned untouched.

    Args:
        expression: The node to transform.

    Returns:
        The resulting union, or the input node when no rewrite applies.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(candidates) != 1:
        return expression

    # The copy must be taken before the original is stripped of its LIMIT
    right_query = expression.copy()
    expression.set("limit", None)
    position, full_join = candidates[0]

    left_table = expression.args["from"].alias_or_name
    right_table = full_join.alias_or_name

    join_conditions = full_join.args.get("on")
    if not join_conditions:
        using_equalities = [
            exp.column(col, left_table).eq(exp.column(col, right_table))
            for col in full_join.args.get("using")
        ]
        join_conditions = exp.and_(*using_equalities)

    full_join.set("side", "left")
    anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
    right_query.args["joins"][position].set("side", "right")
    right_query = right_query.where(exp.Exists(this=anti_join_clause).not_())
    right_query.args.pop("with", None)  # the RIGHT side must not repeat the CTEs
    expression.args.pop("order", None)  # the LEFT side must not carry the ORDER BY

    return exp.union(expression, right_query, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Hoists every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only accept CTEs at
    the top level, so a query such as

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    is invalid there. After this transformation all CTEs live at the top level.
    CTEs that were nested inside another CTE are inserted right before that CTE;
    the remaining ones are appended at the end. If any hoisted WITH is recursive,
    the top-level WITH is flagged as recursive as well.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: The root expression; it is mutated in place.

    Returns:
        The same expression.
    """
    top_level_with = expression.args.get("with")

    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one simply becomes it
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
            continue

        if inner_with.recursive:
            top_level_with.set("recursive", True)

        # Must be looked up before the WITH is detached from the tree
        enclosing_cte = inner_with.find_ancestor(exp.CTE)
        inner_with.pop()

        merged_ctes = top_level_with.expressions
        if enclosing_cte:
            insert_at = merged_ctes.index(enclosing_cte)
            merged_ctes[insert_at:insert_at] = inner_with.expressions
        else:
            merged_ctes = merged_ctes + inner_with.expressions

        top_level_with.set("expressions", merged_ctes)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turns numeric values used in conditions into explicit `<value> <> 0` comparisons.

    Every node of the tree is handed to `sqlglot.optimizer.canonicalize.ensure_bools`
    together with a callback. The callback rewrites a node when it is a number,
    when it is not a subquery predicate and its type is unknown or numeric, or when
    it is an untyped column.

    Args:
        expression: The root expression; it is mutated in place.

    Returns:
        The same expression.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        if node.is_number:
            needs_rewrite = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_rewrite = True
        else:
            needs_rewrite = isinstance(node, exp.Column) and not node.type

        if needs_rewrite:
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _compare_with_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strips the qualifiers from every column found in the expression.

    All parts of a column except the last one (i.e. the table, db and catalog
    args) are removed, leaving only the column's own name.

    Args:
        expression: The root expression; it is mutated in place.

    Returns:
        The same expression.
    """
    for column in expression.find_all(exp.Column):
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Removes the parent node of every `exp.UniqueColumnConstraint` in a CREATE statement.

    Args:
        expression: The `exp.Create` node to transform; it is mutated in place.

    Returns:
        The same expression.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Maps the creation of a temporary table onto `CREATE TEMPORARY VIEW`.

    A CREATE TABLE that carries an `exp.TemporaryProperty` is handled as follows:
    when it has a defining query (CTAS), a new `CREATE TEMPORARY VIEW` with the same
    target and query is returned; otherwise the statement is passed to
    `tmp_storage_provider` and its result is returned.

    Args:
        expression: The `exp.Create` node to transform.
        tmp_storage_provider: Callback applied to temporary CREATE TABLE statements
            that have no defining query. Defaults to the identity function.

    Returns:
        The transformed expression, or the input node when it is not a temporary
        CREATE TABLE.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared_properties = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared_properties)

    if not (expression.kind == "TABLE" and is_temporary):
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Moves partition columns out of the table schema and into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema.
    When a CREATE TABLE / VIEW has a schema and its PARTITIONED BY value is a list of
    column names rather than a schema, the matching column definitions (compared by
    upper-cased name) are removed from the table schema and become the schema of the
    PARTITIONED BY property.

    Args:
        expression: The `exp.Create` node to transform; it is mutated in place.

    Returns:
        The same expression.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {name.name.upper() for name in prop.this.expressions}

    # Split the schema's entries in a single pass, preserving their relative order
    partition_columns = []
    regular_columns = []
    for column_def in schema.expressions:
        if column_def.name.upper() in partition_names:
            partition_columns.append(column_def)
        else:
            regular_columns.append(column_def)

    schema.set("expressions", regular_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Moves typed PARTITIONED BY column definitions into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY value is a schema made up solely of column definitions
    that have a type, those definitions are appended to the CREATE statement's
    target and the property is reduced to a tuple of the column names.

    Args:
        expression: The `exp.Create` node to transform; it is mutated in place.

    Returns:
        The same expression.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(d, exp.ColumnDef) and d.kind for d in partition_defs):
        return expression

    partition_names = exp.Tuple(expressions=[exp.to_identifier(d.this) for d in partition_defs])

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)
    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites key-value struct arguments as aliased values, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument of an `exp.Struct` is replaced by its value
    aliased with its key; other arguments are kept as they are.

    Args:
        expression: The node to transform.

    Returns:
        The same node, mutated in place when it is an `exp.Struct`.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Joins that take no part in any marked predicate are kept (after the generated
    LEFT joins), and a scope that contains no marked column is left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, if the
            marker appears on both sides of a predicate, if a marked column is not
            qualified, if a predicate marks columns of several tables, or if no table
            is left to serve as the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                # Clearing the mark also prevents this column from being processed again
                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            marked_table = next(iter(marked_column_tables))
            join_this = old_joins.get(marked_table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if not new_joins:
            # No (+) marker was found in this scope, so its FROM / JOIN clauses must
            # stay exactly as they are instead of being overwritten with nothing.
            continue

        new_from_name = None
        if query_from.alias_or_name in new_joins:
            # Pick the first untouched join (in original order) so the choice is
            # deterministic rather than dependent on set iteration order.
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # Joins that were neither rewritten nor promoted to FROM must survive
        preserved_joins = [
            join
            for name, join in old_joins.items()
            if name not in new_joins and name != new_from_name
        ]
        query.set("joins", list(new_joins.values()) + preserved_joins)

        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites comparisons against ANY(<array>) as Spark's EXISTS(<array>, <lambda>).

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    The lambda body is a copy of the original comparison in which the ANY(...)
    operand has been replaced by the lambda parameter `x`.

    Both ANY and EXISTS accept queries, but currently only array expressions are
    supported for this transformation; ANY over a query is left untouched.

    Args:
        expression: The node to transform.

    Returns:
        The same node, mutated in place when it is an `exp.Select`.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_expr in expression.find_all(exp.Any):
        operand = any_expr.this
        if isinstance(operand, exp.Query):
            continue

        comparison = any_expr.parent
        if not isinstance(comparison, exp.Binary):
            continue

        lambda_arg = exp.to_identifier("x")
        any_expr.replace(lambda_arg)
        predicate = exp.Lambda(this=comparison.copy(), expressions=[lambda_arg])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=predicate))

    return expression

Transform ANY operator to Spark's EXISTS

For example:
  • Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
  • Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation