Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def preprocess(
 13    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 14) -> t.Callable[[Generator, exp.Expression], str]:
 15    """
 16    Creates a new transform by chaining a sequence of transformations and converts the resulting
 17    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 18    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 19
 20    Args:
 21        transforms: sequence of transform functions. These will be called in order.
 22
 23    Returns:
 24        Function that can be used as a generator transform.
 25    """
 26
 27    def _to_sql(self, expression: exp.Expression) -> str:
 28        expression_type = type(expression)
 29
 30        expression = transforms[0](expression)
 31        for transform in transforms[1:]:
 32            expression = transform(expression)
 33
 34        _sql_handler = getattr(self, expression.key + "_sql", None)
 35        if _sql_handler:
 36            return _sql_handler(expression)
 37
 38        transforms_handler = self.TRANSFORMS.get(type(expression))
 39        if transforms_handler:
 40            if expression_type is type(expression):
 41                if isinstance(expression, exp.Func):
 42                    return self.function_fallback_sql(expression)
 43
 44                # Ensures we don't enter an infinite loop. This can happen when the original expression
 45                # has the same type as the final expression and there's no _sql method available for it,
 46                # because then it'd re-enter _to_sql.
 47                raise ValueError(
 48                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 49                )
 50
 51            return transforms_handler(self, expression)
 52
 53        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 54
 55    return _to_sql
 56
 57
 58def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 59    if isinstance(expression, exp.Select):
 60        count = 0
 61        recursive_ctes = []
 62
 63        for unnest in expression.find_all(exp.Unnest):
 64            if (
 65                not isinstance(unnest.parent, (exp.From, exp.Join))
 66                or len(unnest.expressions) != 1
 67                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 68            ):
 69                continue
 70
 71            generate_date_array = unnest.expressions[0]
 72            start = generate_date_array.args.get("start")
 73            end = generate_date_array.args.get("end")
 74            step = generate_date_array.args.get("step")
 75
 76            if not start or not end or not isinstance(step, exp.Interval):
 77                continue
 78
 79            alias = unnest.args.get("alias")
 80            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 81
 82            start = exp.cast(start, "date")
 83            date_add = exp.func(
 84                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 85            )
 86            cast_date_add = exp.cast(date_add, "date")
 87
 88            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 89
 90            base_query = exp.select(start.as_(column_name))
 91            recursive_query = (
 92                exp.select(cast_date_add)
 93                .from_(cte_name)
 94                .where(cast_date_add <= exp.cast(end, "date"))
 95            )
 96            cte_query = base_query.union(recursive_query, distinct=False)
 97
 98            generate_dates_query = exp.select(column_name).from_(cte_name)
 99            unnest.replace(generate_dates_query.subquery(cte_name))
100
101            recursive_ctes.append(
102                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
103            )
104            count += 1
105
106        if recursive_ctes:
107            with_expression = expression.args.get("with") or exp.With()
108            with_expression.set("recursive", True)
109            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
110            expression.set("with", with_expression)
111
112    return expression
113
114
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a GENERATE_SERIES (or SEQUENCE) table reference in an UNNEST.

    When the table reference is aliased, the UNNEST is aliased as `_u` and the original alias
    becomes its single column name.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
126
127
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # 1-based position of every aliased projection in the parent SELECT
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified columns can refer to a select alias
        if not isinstance(term, exp.Column) or term.table:
            continue

        term_name = term.name
        if term_name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias.get(term_name)))

    return expression
159
160
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a subquery filtered on a ROW_NUMBER() window.

    Intended for dialects that lack SELECT DISTINCT ON but do support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach DISTINCT from the query; its ON columns become the window's partition
    partition_cols = distinct.pop().args["on"].expressions
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # Reuse the query's ORDER BY for the window, or order by the DISTINCT ON columns
    ordering = expression.args.get("order")
    if not ordering:
        ranking.set("order", exp.Order(expressions=[col.copy() for col in partition_cols]))
    else:
        ranking.set("order", ordering.pop())

    ranking = exp.alias_(ranking, row_number_alias)
    expression.select(ranking, copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
200
201
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery `_t`) when the input is a SELECT
        with a QUALIFY clause, otherwise the input expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Returns a column that carries over the identifier's "quoted" flag when the
            # projection's alias (or `this`) is an identifier, and the plain name otherwise
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are captured before any helper projection is added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach QUALIFY from the query and keep its condition for the outer WHERE
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star queries only windows are looked up, plain columns are left alone
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Expand references to projection aliases inside the window, since the
                    # window itself is about to become a projection of the same query
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh "_w" alias and filter on that alias instead
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the direct child of QUALIFY, i.e. it is the whole condition
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that isn't projected must be added to the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
264
265
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the precision parameters from every parameterized data type found in the expression.

    Useful for dialects that only accept such parameters in DDL and reject them elsewhere.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if isinstance(param, exp.DataTypeParam):
                continue
            kept.append(param)

        data_type.set("expressions", kept)

    return expression
277
278
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Drop column qualifiers that refer to unnest table aliases (added by qualify_columns)."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs that act as FROM / JOIN sources in this scope
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
297
298
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST used as the FROM source is replaced by a table wrapping EXPLODE, or POSEXPLODE
    when the UNNEST has an offset. Every joined UNNEST (optionally wrapped in a LATERAL) is
    removed from the joins and re-added as one LATERAL VIEW per pair of unnested expression
    and alias column.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it skips the element following each removal.
        for join in list(expression.args.get("joins") or []):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL joins the alias lives on the Lateral node, not on the UNNEST
                alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE: pairs are formed with zip, so an UNNEST without alias columns yields
                # no LATERAL VIEW at all
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
343
344
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the number the generated position series starts at. It also determines
            how the array sizes are adjusted when they are compared against positions.

    Returns:
        A transform that rewrites the [POS]EXPLODE projections of a SELECT and returns any
        other expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that the generated aliases don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name based on `name` that isn't in `names` and reserves it
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Size expressions of all exploded arrays, used for the series' end further down
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The position series shared by all explodes; its "end" is only set at the bottom
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases are read here: the position first, then the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in a placeholder alias and look the
                        # explode up again within the wrapper
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # NOTE(review): presumably indexing builds a bracket expression for the
                        # array's first element -- confirm against Expression.__getitem__
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        # IF(ARRAY_SIZE(COALESCE(arg, [])) = 0, [bracket], arg)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode becomes: IF(<series pos> = <unnest pos>, <unnest value>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position, projected right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series is attached once, when the first explode is processed
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series position is
                    # past `size` and the unnest position equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for the index offset
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
494
495
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile function in a WITHIN GROUP clause.

    The second argument becomes the percentile's only argument, and the first one becomes the
    ORDER BY term of the WITHIN GROUP.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
509
510
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY <value>)` with an APPROX_QUANTILE call.

    The ordered value becomes the function's input and the percentile's argument its quantile.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile, ordering = expression.this, expression.expression
    if not isinstance(percentile, exp.PERCENTILES) or not isinstance(ordering, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))

    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
523
524
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give the CTEs of a recursive WITH clause explicit column names when they have none.

    The names are the output names of the CTE query's projections (of its left-most branch for
    set operations); unnamed projections get generated `_c_`-prefixed names.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        cte_alias.set("columns", column_names)

    return expression
542
543
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Substitute the 'epoch' operand of a cast to a temporal type with its timestamp literal."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    # The name comparison is case-insensitive
    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
554
555
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT, and an
    `EXISTS (SELECT 1 FROM <joined source> WHERE <on>)` predicate, negated for ANTI joins,
    is added to the WHERE clause instead.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot, since joins are popped off the query inside the loop and
        # removing from a list while iterating it would skip the following join.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
571
572
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with one FULL OUTER join as the union of two copies of it, the first using
    a LEFT join and the second a RIGHT join in its place.

    Queries with zero or several FULL OUTER joins are returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    # The copy is taken before the original is modified
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
597
598
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the top-level expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. Applying this transformation makes the final SQL code valid
    from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        # No top-level WITH yet: the first nested one found is promoted as is
        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
            continue

        if inner_with.recursive:
            top_level_with.set("recursive", True)

        enclosing_cte = inner_with.find_ancestor(exp.CTE)
        inner_with.pop()

        hoisted = top_level_with.expressions
        if enclosing_cte:
            # CTEs nested in another CTE are inserted right before that CTE
            position = hoisted.index(enclosing_cte)
            hoisted[position:position] = inner_with.expressions
        else:
            hoisted = hoisted + inner_with.expressions

        top_level_with.set("expressions", hoisted)

    return expression
636
637
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Turn numeric values that are used as conditions into explicit `<> 0` comparisons."""
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Number literals always qualify
        needs_comparison = node.is_number

        # So does anything typed as numeric or unknown, except for subquery predicates
        if not needs_comparison and not isinstance(node, exp.SubqueryPredicate):
            needs_comparison = node.is_type(
                exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
            )

        # And so do columns that carry no type
        if not needs_comparison:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _ensure_bool)

    return expression
657
658
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip all the qualifiers off every column in the expression, keeping its last part."""
    for column in expression.find_all(exp.Column):
        parts = column.parts

        # Everything but the last part is a qualifier (table, db, catalog)
        for position in range(len(parts) - 1):
            parts[position].pop()

    return expression
666
667
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove the parent node of every unique column constraint in a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
675
676
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn a temporary CREATE TABLE ... AS statement into CREATE TEMPORARY VIEW.

    A temporary CREATE TABLE that has no query is handed to `tmp_storage_provider`; any other
    CREATE statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []

    is_temporary = False
    for prop in property_nodes:
        if isinstance(prop, exp.TemporaryProperty):
            is_temporary = True
            break

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
699
700
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move the columns named by PARTITIONED BY out of the table schema and into the property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When its
    value is a list of column names, the matching column definitions (compared
    case-insensitively) are removed from the schema and become the property's own schema.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this) or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {value.name.upper() for value in prop.this.expressions}

    partition_columns = []
    for column in schema.expressions:
        if column.name.upper() in partition_names:
            partition_columns.append(column)

    remaining = [column for column in schema.expressions if column not in partition_columns]
    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
722
723
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Append the typed columns of a PARTITIONED BY schema to the table schema, leaving only the
    column names in the property.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop:
        return expression

    partition_schema = prop.this
    if not partition_schema or not isinstance(partition_schema, exp.Schema):
        return expression

    # Every partition entry has to be a column definition that has a kind
    for column_def in partition_schema.expressions:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    column_names = exp.Tuple(
        expressions=[
            exp.to_identifier(column_def.this) for column_def in partition_schema.expressions
        ]
    )

    schema = expression.this
    for column_def in partition_schema.expressions:
        schema.append("expressions", column_def)

    prop.set("this", column_names)

    return expression
747
748
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Rewrite the key-value arguments of a struct as aliased values, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)

    return expression
761
762
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: If a marked column is not involved in a binary predicate, is not
            qualified with a table, if marks appear on both sides of a predicate or on columns
            of more than one table, or if no table is left to use in the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are rewritten
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from the WHERE clause; it becomes the join's ON condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables that the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; all marked columns share one
            # table at this point. The FROM source is used if that table isn't among the joins.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM source itself became a LEFT join, so one of the tables that was only joined
        # before is used as the new FROM source instead
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The query's joins are replaced by the newly built LEFT joins
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if no condition is left in it
        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function if the transformed expression has neither a
            "_sql" method nor a `Generator.TRANSFORMS` entry, or if using its `TRANSFORMS` entry
            would re-enter this function (see the comment in `_to_sql`).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A plain loop (instead of special-casing transforms[0]) also handles an empty list
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 59def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 60    if isinstance(expression, exp.Select):
 61        count = 0
 62        recursive_ctes = []
 63
 64        for unnest in expression.find_all(exp.Unnest):
 65            if (
 66                not isinstance(unnest.parent, (exp.From, exp.Join))
 67                or len(unnest.expressions) != 1
 68                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 69            ):
 70                continue
 71
 72            generate_date_array = unnest.expressions[0]
 73            start = generate_date_array.args.get("start")
 74            end = generate_date_array.args.get("end")
 75            step = generate_date_array.args.get("step")
 76
 77            if not start or not end or not isinstance(step, exp.Interval):
 78                continue
 79
 80            alias = unnest.args.get("alias")
 81            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 82
 83            start = exp.cast(start, "date")
 84            date_add = exp.func(
 85                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 86            )
 87            cast_date_add = exp.cast(date_add, "date")
 88
 89            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 90
 91            base_query = exp.select(start.as_(column_name))
 92            recursive_query = (
 93                exp.select(cast_date_add)
 94                .from_(cte_name)
 95                .where(cast_date_add <= exp.cast(end, "date"))
 96            )
 97            cte_query = base_query.union(recursive_query, distinct=False)
 98
 99            generate_dates_query = exp.select(column_name).from_(cte_name)
100            unnest.replace(generate_dates_query.subquery(cte_name))
101
102            recursive_ctes.append(
103                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
104            )
105            count += 1
106
107        if recursive_ctes:
108            with_expression = expression.args.get("with") or exp.With()
109            with_expression.set("recursive", True)
110            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
111            expression.set("with", with_expression)
112
113    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose `this` is a `GenerateSeries` is replaced by an UNNEST of that series. If the
    table had an alias, the UNNEST is aliased as `_u` with the old alias as its column name.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])

    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with their 1-based positions.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Maps each projection alias to the position of that projection in the SELECT list
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for grouped in expression.expressions:
        # Only unqualified columns can be references to a projection alias
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue

        if grouped.name in position_by_alias:
            grouped.replace(exp.Literal.number(position_by_alias[grouped.name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery that is filtered on a ROW_NUMBER() window.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the window's partition; the query's ORDER BY (or, if there is
    none, the DISTINCT ON columns themselves) becomes the window's ordering, and the outer query
    keeps only the rows whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not (on and isinstance(on, exp.Tuple)):
        return expression

    partition_columns = distinct.pop().args["on"].expressions
    outer_selects = expression.selects
    row_number = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    order = expression.args.get("order")
    if order:
        window_order = order.pop()
    else:
        window_order = exp.Order(expressions=[c.copy() for c in partition_columns])
    window.set("order", window_order)

    expression.select(exp.alias_(window, row_number), copy=False)

    outer_query = exp.select(*outer_selects, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT that has a QUALIFY clause as a subquery with an equivalent outer filter.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so they are added to the
    subquery as aliased projections and referenced by alias in the outer filter. Columns that are
    referenced in the QUALIFY clause but not selected are added to the subquery as well, so that
    the outer WHERE clause can refer to them. Finally, references to aliased projections inside
    the QUALIFY clause's window functions are replaced by the aliased expressions, in order to
    avoid creating invalid column references.
    """
    if not isinstance(expression, exp.Select) or not expression.args.get("qualify"):
        return expression

    # Unnamed projections get a generated alias so that the outer query can select them
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        used_names.add(generated)

    def _outer_projection(select: exp.Expression) -> str | exp.Column:
        name = select.alias_or_name
        identifier = select.args.get("alias") or select.this
        if not isinstance(identifier, exp.Identifier):
            return name
        # Carry over the quoting of the original identifier
        return exp.column(name, quoted=identifier.args.get("quoted"))

    outer_selects = exp.select(*[_outer_projection(s) for s in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this

    expression_by_alias = {}
    for projection in expression.selects:
        if isinstance(projection, exp.Alias):
            expression_by_alias[projection.alias] = projection.this

    candidate_types = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for candidate in qualify_filters.find_all(candidate_types):
        if not isinstance(candidate, exp.Window):
            # A plain column: make sure the subquery projects it
            if candidate.name not in expression.named_selects:
                expression.select(candidate.copy(), copy=False)
            continue

        if expression_by_alias:
            for column in candidate.find_all(exp.Column):
                aliased = expression_by_alias.get(column.name)
                if aliased:
                    column.replace(aliased)

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(candidate, window_alias), copy=False)
        window_column = exp.column(window_alias)

        if isinstance(candidate.parent, exp.Qualify):
            # The window is the entire QUALIFY condition
            qualify_filters = window_column
        else:
            candidate.replace(window_column)

    subquery = expression.subquery(alias="_t", copy=False)
    return outer_selects.from_(subquery, copy=False).where(qualify_filters, copy=False)

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the `DataTypeParam` entries from every data
    type found in the expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Columns whose table (or else whose db) matches the alias of an UNNEST that is used as a FROM
    or JOIN source in the SELECT's scope have that qualifier cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST used as the FROM source is replaced by a table wrapping EXPLODE (or POSEXPLODE when
    the UNNEST has an offset). Every join whose source is an UNNEST, optionally wrapped in a
    LATERAL, is removed and one LATERAL VIEW is appended per pair of unnested expression and
    alias column.

    Args:
        expression: the expression that will be transformed. Only `exp.Select` nodes are changed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: matching joins are removed from `joins` inside the loop, and
        # removing from a list while iterating over it skips the element that follows, which
        # would leave every other one of several consecutive UNNEST joins untouched.
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For a LATERAL, the alias lives on the LATERAL node rather than on the UNNEST
                alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # Without an alias there are no columns to pair with, so no lateral is added
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series. It also controls how the
            array sizes are adjusted when they are compared against positions (no adjustment is
            made when it is 1).

    Returns:
        A transform function that rewrites `exp.Select` expressions in place and returns any
        other expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so that generated names don't collide with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Picks a name based on `name` that isn't in `names` yet and registers it there
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of the exploded arrays; used at the end to compute the end of the series
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # A single position series is shared by all explodes; its "end" is set at the bottom
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is an alias node, keeping user-provided aliases
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # exp.alias_ is called without copy=False here, so look the explode up
                        # again inside the new alias node
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # For the OUTER variant, an array of size 0 (a NULL array is coalesced to an
                    # empty one first) is swapped for a one-element array built from
                    # `explode_arg[0]` with the "safe" and "offset" flags set
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate the names that the user didn't provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only produced on rows where the series position
                    # matches the unnest position (the IF has no false branch)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also yields the position, projected right after the value
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is added only once, when the first explode is handled
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` is compared against positions below; it is reduced by one unless
                    # index_offset is 1
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series position is past
                    # `size` while the unnest position equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series ends at the greatest of the collected array sizes, adjusted by
            # (1 - index_offset) unless index_offset is 1
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's first argument moves into the ORDER BY of the new WITHIN GROUP, and its
    second argument takes the first argument's place. Percentiles that are already wrapped in a
    WITHIN GROUP, or that have no second argument, are left untouched.
    """
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WITHIN GROUP that wraps a percentile and an ORDER BY is replaced by an `ApproxQuantile`
    whose input is the first ordered expression and whose quantile is the percentile's argument.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH clause whose alias has no columns yet are changed. For set
    operations, the names come from the left-hand query. Projections without a name get one from
    a `_c_` name sequence.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST and TRY_CAST expressions whose operand is named "epoch" (case-insensitive)
    and whose target type is one of `exp.DataType.TEMPORAL_TYPES`.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the query and replaced
    by a `WHERE [NOT] EXISTS (SELECT 1 FROM <joined table> WHERE <on condition>)` predicate.
    Joins without an ON condition are left as they are.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are rewritten.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches the join from the query, and removing
        # entries from the list being iterated would make the loop skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query that has exactly one FULL OUTER join as a union of two queries.

    The original query becomes the LEFT join branch and loses its LIMIT; a copy taken
    beforehand becomes the RIGHT join branch and is stripped of its WITH clause. Queries
    with zero or several FULL OUTER joins are returned unchanged.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are rewritten.

    Returns:
        The union of the two branches, or the original expression when no rewrite applies.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []
    full_positions = [position for position, join in enumerate(joins) if join.side == "FULL"]
    if len(full_positions) != 1:
        return expression

    position = full_positions[0]

    # The copy must be taken before the original is altered, so it still carries the limit
    right_branch = expression.copy()
    expression.set("limit", None)

    joins[position].set("side", "left")
    right_branch.args["joins"][position].set("side", "right")
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_branch, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs that were nested inside another CTE are inserted right before that CTE; all
    others are appended at the end. The top-level WITH is flagged as recursive as soon as
    a recursive WITH is merged into it.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    root_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            # This is the top-level WITH itself
            continue

        if not root_with:
            # No top-level WITH yet: promote the first nested one as-is
            root_with = nested_with.pop()
            expression.set("with", root_with)
            continue

        if nested_with.recursive:
            root_with.set("recursive", True)

        # The enclosing CTE has to be looked up before the node is detached from the tree
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            position = root_with.expressions.index(enclosing_cte)
            root_with.expressions[position:position] = nested_with.expressions
            root_with.set("expressions", root_with.expressions)
        else:
            root_with.set("expressions", root_with.expressions + nested_with.expressions)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to `sqlglot.optimizer.canonicalize.ensure_bools`
    together with a callback that replaces a node by `<node> <> 0` when the node is a
    number, has an unknown or numeric type (subquery predicates excluded), or is an
    untyped column.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _compare_to_zero(node: exp.Expression) -> None:
        if node.is_number:
            numeric_like = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            numeric_like = True
        else:
            numeric_like = isinstance(node, exp.Column) and not node.type

        if numeric_like:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_bools(node, _compare_to_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the qualifiers from every column in the expression.

    All parts of a column except the last one are removed, i.e. the table, db and
    catalog args.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place.
    """
    for column in expression.find_all(exp.Column):
        # Everything before the last part qualifies the column
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the unique column constraints found in a CREATE statement.

    For every `exp.UniqueColumnConstraint` in the tree, its parent node is popped.

    Args:
        expression: The `exp.Create` expression to transform.

    Returns:
        The same expression, modified in place.
    """
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique_constraint.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map the creation of a temporary table onto a temporary view.

    A `CREATE TABLE` carrying a `TemporaryProperty` and a query (CTAS) is turned into a
    `CREATE TEMPORARY VIEW` over the same target and query. A temporary table without a
    query is passed to `tmp_storage_provider` instead. Everything else is returned as is.

    Args:
        expression: The `exp.Create` expression to transform.
        tmp_storage_provider: Callback used for temporary tables that have no query;
            defaults to the identity function.

    Returns:
        The new CREATE expression, the callback's result, or the original expression.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared_properties = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared_properties)

    if not (expression.kind == "TABLE" and is_temporary):
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    query = expression.expression
    if query:
        return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move the partition columns out of the table schema and into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only TABLE and VIEW creations whose
    target is a schema are considered.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        # Nothing to do when there is no PARTITIONED BY or it already holds a schema
        return expression

    partition_names = {item.name.upper() for item in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY columns into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of column definitions that have a type,
    those definitions are appended to the CREATE target's expressions and the property is
    reduced to a tuple of the column names.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(item, exp.ColumnDef) and item.kind for item in partition_defs):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(item.this) for item in partition_defs]
    )

    schema = expression.this
    for item in partition_defs:
        schema.append("expressions", item)

    prop.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite the key-value arguments of a struct as aliases, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument becomes an alias whose value is the right-hand side and
    whose name is the left-hand side; other arguments are kept as they are.

    Args:
        expression: The expression to transform; only `exp.Struct` nodes are rewritten.

    Returns:
        The same expression, modified in place.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", [_as_alias(arg) for arg in expression.expressions])

    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each binary predicate that contains a marked column is taken out of its parent and
    becomes the ON condition of a LEFT join against the marked table; several predicates
    for the same table are AND-ed together. Scopes without a WHERE clause or without joins
    are skipped.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    def _marked_columns(side):
        # Columns under `side` that still carry the (+) marker
        return [c for c in side.find_all(exp.Column) if c.args.get("join_mark")]

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not (where and joins):
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        replacement_joins: t.Dict[str, exp.Join] = {}
        original_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent before detaching, since popping clears the link
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_marked = _marked_columns(join_predicate.left)
            right_marked = _marked_columns(join_predicate.right)

            assert not (
                left_marked and right_marked
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_tables = set()
            for col in left_marked or right_marked:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_tables.add(table)

            assert (
                len(marked_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # All marked columns share one table at this point, so the last one is representative
            join_target = original_joins.get(col.table, query_from).this
            candidate_join = exp.Join(this=join_target, on=join_predicate, kind="LEFT")

            # Upsert the candidate into the replacement joins
            join_key = candidate_join.alias_or_name
            known_join = replacement_joins.get(join_key)
            if known_join:
                known_join.set(
                    "on", exp.and_(known_join.args.get("on"), candidate_join.args["on"])
                )
            else:
                replacement_joins[join_key] = candidate_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                remaining_operand = (
                    predicate_parent.right
                    if predicate_parent.left is None
                    else predicate_parent.left
                )
                predicate_parent.replace(remaining_operand)

        if query_from.alias_or_name in replacement_joins:
            # The FROM table became a join target, so another table has to take its place
            untouched_joins = original_joins.keys() - replacement_joins.keys()
            assert (
                len(untouched_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(untouched_joins)[0]
            query.set("from", exp.From(this=original_joins[new_from_name].this))

        query.set("joins", list(replacement_joins.values()))

        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.