Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def preprocess(
 13    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 14) -> t.Callable[[Generator, exp.Expression], str]:
 15    """
 16    Creates a new transform by chaining a sequence of transformations and converts the resulting
 17    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 18    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 19
 20    Args:
 21        transforms: sequence of transform functions. These will be called in order.
 22
 23    Returns:
 24        Function that can be used as a generator transform.
 25    """
 26
 27    def _to_sql(self, expression: exp.Expression) -> str:
 28        expression_type = type(expression)
 29
 30        expression = transforms[0](expression)
 31        for transform in transforms[1:]:
 32            expression = transform(expression)
 33
 34        _sql_handler = getattr(self, expression.key + "_sql", None)
 35        if _sql_handler:
 36            return _sql_handler(expression)
 37
 38        transforms_handler = self.TRANSFORMS.get(type(expression))
 39        if transforms_handler:
 40            if expression_type is type(expression):
 41                if isinstance(expression, exp.Func):
 42                    return self.function_fallback_sql(expression)
 43
 44                # Ensures we don't enter an infinite loop. This can happen when the original expression
 45                # has the same type as the final expression and there's no _sql method available for it,
 46                # because then it'd re-enter _to_sql.
 47                raise ValueError(
 48                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 49                )
 50
 51            return transforms_handler(self, expression)
 52
 53        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 54
 55    return _to_sql
 56
 57
 58def unalias_group(expression: exp.Expression) -> exp.Expression:
 59    """
 60    Replace references to select aliases in GROUP BY clauses.
 61
 62    Example:
 63        >>> import sqlglot
 64        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 65        'SELECT a AS b FROM x GROUP BY 1'
 66
 67    Args:
 68        expression: the expression that will be transformed.
 69
 70    Returns:
 71        The transformed expression.
 72    """
 73    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 74        aliased_selects = {
 75            e.alias: i
 76            for i, e in enumerate(expression.parent.expressions, start=1)
 77            if isinstance(e, exp.Alias)
 78        }
 79
 80        for group_by in expression.expressions:
 81            if (
 82                isinstance(group_by, exp.Column)
 83                and not group_by.table
 84                and group_by.name in aliased_selects
 85            ):
 86                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 87
 88    return expression
 89
 90
 91def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 92    """
 93    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 94
 95    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 96
 97    Args:
 98        expression: the expression that will be transformed.
 99
100    Returns:
101        The transformed expression.
102    """
103    if (
104        isinstance(expression, exp.Select)
105        and expression.args.get("distinct")
106        and expression.args["distinct"].args.get("on")
107        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
108    ):
109        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
110        outer_selects = expression.selects
111        row_number = find_new_name(expression.named_selects, "_row_number")
112        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
113        order = expression.args.get("order")
114
115        if order:
116            window.set("order", order.pop())
117        else:
118            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
119
120        window = exp.alias_(window, row_number)
121        expression.select(window, copy=False)
122
123        return (
124            exp.select(*outer_selects, copy=False)
125            .from_(expression.subquery("_t", copy=False), copy=False)
126            .where(exp.column(row_number).eq(1), copy=False)
127        )
128
129    return expression
130
131
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (wrapped as subquery "_t") when the input is a
        SELECT with a QUALIFY clause; otherwise the input expression, unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer projection for `select`: a column carrying the identifier's
            # "quoted" flag when an identifier is available, else the plain alias/name string.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projections are appended below,
        # so those helpers are not exposed by the outer query.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY node and keep its condition; it becomes the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star queries only windows are collected; otherwise plain columns are too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the subquery under a fresh "_w" alias ...
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # ... and reference that alias from the filter. If the window was the whole
                # QUALIFY condition, the filter itself becomes the alias column.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that is not projected must be added to the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
194
195
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip `DataTypeParam` arguments from every data type found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions, so the parameters are dropped here.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)
        data_type.set("expressions", kept)

    return expression
207
208
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Drop column qualifiers that refer to unnest table aliases (added by the optimizer's qualify_columns step)."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs that sit directly in a FROM or JOIN of this scope.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
227
228
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed from the "joins" list and
    replaced by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is
    used when the UNNEST has an "offset" argument, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place when it is a SELECT).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing from a list while iterating it skips the element after each removal.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no alias, zip() gets an empty column list,
                # so the join is dropped without any lateral being added -- confirm intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
252
253
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: number used as the start of the generated position series. When it is
            not 1, each array size used in the filter is decremented by 1 and the series end
            is decremented by `1 - index_offset`.

    Returns:
        A transform function. It rewrites SELECT expressions in place and returns them; any
        other expression is returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name derived from `name` that is not yet in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded array; used to compute the series end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(start=index_offset)) source; its "end" argument is only
            # filled in at the bottom, once all array sizes are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) and pick up any
                    # user-provided names for the exploded value / position.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # exp.alias_ is called without copy=False here, so the explode is
                        # looked up again inside the new alias node.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Swap the argument for IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0,
                        # ARRAY(<bracket>), arg).
                        # NOTE(review): presumably this keeps one row for empty/NULL arrays
                        # via a "safe" first-element access -- confirm.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever of the value / position names the user didn't supply.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call is replaced by
                    # IF(<series>.<pos> = <unnest>.<pos>, <unnest>.<value>).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: insert it as an extra
                        # projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached only once, on the first explode found:
                    # cross-joined if the query already has a FROM, used as the FROM otherwise.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `arrays` keeps the undecremented size; only the filter below uses the
                    # adjusted value.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches the unnest position, or
                    # where the series has run past this array and the unnest is at `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series runs up to the largest array size, adjusted for the offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
403
404
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Wrap two-argument percentile functions in a WITHIN GROUP (ORDER BY ...) clause."""
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # The first argument becomes the ORDER BY key; the second one takes its place
    # as the function's only argument.
    ordered_input = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_input)])
    return exp.WithinGroup(this=expression, expression=ordering)
418
419
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Replace a percentile's WITHIN GROUP clause with an equivalent ApproxQuantile call."""
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    # The percentile's argument is the quantile; the first ORDER BY term is the input value.
    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx)
432
433
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Fill in missing column lists of recursive CTEs using their projections' output names."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # Supplies "_c_"-prefixed names for projections that have no alias or name.
    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        if cte.args["alias"].columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            # For set operations the names come from the left-hand query.
            body = body.this

        column_names = []
        for projection in body.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        cte.args["alias"].set("columns", column_names)

    return expression
451
452
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Swap the 'epoch' operand of a temporal cast for the literal '1970-01-01 00:00:00'."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
463
464
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Each SEMI/ANTI join that has an ON condition is removed from the SELECT and replaced by a
    `[NOT] EXISTS (SELECT 1 FROM <join target> WHERE <on>)` predicate added to the WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place when it is a SELECT).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() detaches the join from the query, and if that
        # shrinks the live "joins" list the join following a removed one would be skipped.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
480
481
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with a FULL OUTER join as the union of two copies of it, one using a
    LEFT join and the other a RIGHT join. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    # The copy is taken before the original is modified.
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
506
507
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the top-level expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    hoisted = expression.args.get("with")

    for nested in expression.find_all(exp.With):
        if nested.parent is expression:
            continue

        if not hoisted:
            # No top-level WITH yet: the first nested one is promoted as-is.
            hoisted = nested.pop()
            expression.set("with", hoisted)
            continue

        if nested.recursive:
            hoisted.set("recursive", True)

        # Looked up before detaching, while the ancestor chain is still intact.
        enclosing_cte = nested.find_ancestor(exp.CTE)
        nested.pop()

        if not enclosing_cte:
            hoisted.set("expressions", hoisted.expressions + nested.expressions)
            continue

        # CTEs defined inside another CTE are placed right before that CTE.
        position = hoisted.expressions.index(enclosing_cte)
        hoisted.expressions[position:position] = nested.expressions
        hoisted.set("expressions", hoisted.expressions)

    return expression
545
546
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Turn numeric values used as conditions into explicit `<> 0` comparisons."""
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _as_bool(node: exp.Expression) -> None:
        # Numbers, nodes typed as unknown/numeric (subquery predicates excepted) and
        # untyped columns are compared against 0.
        numeric_like = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_bools(node, _as_bool)

    return expression
566
567
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Remove every part but the last from each column found in the expression."""
    for column in expression.find_all(exp.Column):
        # Everything before the last part is a qualifier (table, db, catalog).
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
575
576
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove the parent node of every `UniqueColumnConstraint` found in a CREATE expression."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
584
585
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE ... AS statements to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            without a source query; its result is returned. Defaults to the identity function.

    Returns:
        A new CREATE TEMPORARY VIEW expression for temporary CTAS statements, the provider's
        result for other temporary tables, and the input expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []

    is_temporary = False
    for prop in property_list:
        if isinstance(prop, exp.TemporaryProperty):
            is_temporary = True
            break

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
608
609
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move the columns named by PARTITIONED BY out of the table schema and into the property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    schema_present = isinstance(schema, exp.Schema)
    partitionable = expression.kind in {"TABLE", "VIEW"}
    if not (schema_present and partitionable):
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    # Column names are matched case-insensitively.
    partition_names = set()
    for value in prop.this.expressions:
        partition_names.add(value.name.upper())

    partition_cols = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in schema.expressions if col not in partition_cols]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
    expression.set("this", schema)

    return expression
631
632
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Append typed PARTITIONED BY column definitions to the table schema, leaving only their
    names in the property.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop:
        return expression

    if not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    # Only applies when every partition entry is a column definition with a type.
    for entry in prop.this.expressions:
        if not (isinstance(entry, exp.ColumnDef) and entry.kind):
            return expression

    column_names = exp.Tuple(
        expressions=[exp.to_identifier(entry.this) for entry in prop.this.expressions]
    )

    table_schema = expression.this
    for entry in prop.this.expressions:
        table_schema.append("expressions", entry)

    prop.set("this", column_names)

    return expression
656
657
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Rewrite key-value struct arguments as aliases, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression
670
671
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, is unqualified,
            if a predicate has marks on both sides or on columns of more than one table, or
            if no table is left to serve as the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Scopes lacking either a WHERE clause or joins are left untouched.
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from the WHERE tree; it becomes (part of) a join's ON.
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables of the marked columns.
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; per the assertion, all
            # marked columns belong to the same table. Falls back to the FROM source when
            # that table is not among the existing joins.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                # Several marked predicates for the same table are AND-ed into one ON clause.
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # If the FROM source itself became a LEFT JOIN target, pick one of the old joins that
        # was not converted to act as the new FROM clause.
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The joins list is replaced wholesale by the newly built LEFT joins.
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause when all of its predicates were moved into joins.
        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. The sequence
            may be empty, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform. When called, it raises a `ValueError`
        if the transformed expression cannot be generated: either no handler exists for its type,
        or the only available handler would re-enter this same function.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type so that re-entrancy can be detected further down.
        expression_type = type(expression)

        # A plain loop (rather than special-casing transforms[0]) also supports an empty sequence.
        for transform in transforms:
            expression = transform(expression)

        # A dedicated "<key>_sql" generator method takes precedence over TRANSFORMS entries.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the 1-based position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map each select alias to the position of its projection (positions start at 1).
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified columns can be references to a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        name = term.name
        if name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias[name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 92def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 93    """
 94    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 95
 96    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 97
 98    Args:
 99        expression: the expression that will be transformed.
100
101    Returns:
102        The transformed expression.
103    """
104    if (
105        isinstance(expression, exp.Select)
106        and expression.args.get("distinct")
107        and expression.args["distinct"].args.get("on")
108        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
109    ):
110        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
111        outer_selects = expression.selects
112        row_number = find_new_name(expression.named_selects, "_row_number")
113        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
114        order = expression.args.get("order")
115
116        if order:
117            window.set("order", order.pop())
118        else:
119            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
120
121        window = exp.alias_(window, row_number)
122        expression.select(window, copy=False)
123
124        return (
125            exp.select(*outer_selects, copy=False)
126            .from_(expression.subquery("_t", copy=False), copy=False)
127            .where(exp.column(row_number).eq(1), copy=False)
128        )
129
130    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (wrapped as subquery "_t") when the input is a
        SELECT with a QUALIFY clause, otherwise the input expression itself.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Returns the projection's output name; when that name comes from an Identifier,
            # a Column is built instead so that the identifier's "quoted" flag is carried over.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are captured before any helper projection is added below.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY clause and keep its condition, which becomes the outer WHERE filter.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star queries only windows are looked up in the filter; plain columns are skipped.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Expand references to projection aliases found inside the window.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh "_w" alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition, so the filter is just the alias.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by the filter but not projected: add it to the inner query.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform strips the `DataTypeParam` arguments from every data type
    found in the expression.
    """
    for data_type in expression.find_all(exp.DataType):
        non_params = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                non_params.append(arg)

        data_type.set("expressions", non_params)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Only UNNESTs that sit directly under a FROM or a JOIN are considered.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one LATERAL VIEW
    per (unnested expression, alias column) pair. POSEXPLODE is used instead of EXPLODE when the
    UNNEST has an "offset" argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing items from a list that is being iterated makes the iterator skip the element
        # that follows each removal (e.g. the second of two consecutive UNNEST joins).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without an alias the zip below is empty, so the join is removed
                # and no lateral view is added in its place -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the number used as the start of the generated position series. It also
            controls how the series end and the array sizes are adjusted (see the `!= 1` checks).

    Returns:
        A transform function that rewrites SELECT expressions in place and returns them; any
        other expression is returned untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name based on `name` that is not in `names`, and reserves it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Collects one ARRAY_SIZE expression per explode that gets rewritten.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(<index_offset>, ...)) source; its "end" is only known after
            # all projections have been processed, so it's set at the bottom of this function.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is an Alias node, remembering user-provided names.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # exp.alias_ is called without copy=False here, so the explode node is
                        # looked up again in the replacement.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg),
                        # with the bracket flagged as "safe" and "offset".
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever of the value / position aliases the query didn't provide.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes: IF(<series pos> = <unnest pos>, <unnest value>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Insert the position projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, when the first explode is encountered.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Note that `arrays` keeps the unadjusted size; only the filter below uses
                    # the decremented one.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches the unnest position, or where
                    # the series has gone past this array's size and the unnest position equals
                    # that size.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest of the collected array sizes, shifted by
                # (1 - index_offset) unless index_offset is 1.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wraps percentile expressions that have a second argument in a WITHIN GROUP clause.

    The percentile's first argument is moved into the WITHIN GROUP ordering and its second
    argument takes the first one's place. Percentiles already under a WITHIN GROUP are left as-is.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    second_arg = expression.expression.pop()
    expression.set("this", second_arg)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replaces a WITHIN GROUP clause that wraps a percentile with an `ApproxQuantile` expression.

    The quantile is taken from the percentile's argument and the input value from the first
    `Ordered` node found in the WITHIN GROUP clause.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    ordering = expression.expression
    if not isinstance(percentile, exp.PERCENTILES) or not isinstance(ordering, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Fills in the column list of CTEs in a recursive WITH clause, using their projections' names.

    CTEs that already declare columns are left untouched. Projections without an output name
    receive a generated "_c_"-prefixed name.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For set operations, the names are taken from the left-hand side query.
        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Swap the 'epoch' operand of a cast to a temporal type for the literal '1970-01-01 00:00:00'."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
    expression.this.replace(epoch_literal)
    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced by a
    `[NOT] EXISTS (SELECT 1 FROM <join target> WHERE <on>)` predicate added to its WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: each converted join is popped inside the loop,
        # and if that removal mutates the list being iterated, the join right after it would be
        # skipped (e.g. the second of two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a query that has exactly one FULL OUTER join as a union of two copies of it: one
    using a LEFT join and the other a RIGHT join. Queries with zero or several FULL OUTER joins
    are returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    # The copy is taken first, so it still carries the limit that is cleared on the original.
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression whose nested WITH clauses will be hoisted.

    Returns:
        The same expression, mutated so that its own WITH clause holds all the CTEs.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The expression's own WITH clause is already where it needs to be.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one found is moved up and becomes it.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # Once any hoisted WITH is recursive, the merged top-level one is marked as such.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Must be looked up before popping, since popping detaches the node from its parents.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The WITH was nested inside a CTE: its CTEs are inserted right before that CTE.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Otherwise the hoisted CTEs are appended after the existing ones.
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Rewrites numeric values that are used as conditions into explicit `<> 0` comparisons."""
    # Aliased because the imported helper has the same name as this function.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        # Number literals, and non-subquery-predicate nodes typed as UNKNOWN or numeric.
        is_numeric_like = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )

        # Columns without any type information are converted as well.
        if is_numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Removes every part of each column in the expression except for the last one."""
    for column in expression.find_all(exp.Column):
        # Everything before the last part is a qualifier (the table, db, catalog args).
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Removes the parent node of every `UniqueColumnConstraint` found in a CREATE expression."""
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique_constraint.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Maps the creation of a temporary table onto a CREATE TEMPORARY VIEW statement.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when a temporary table is created
            without a source query (i.e. it's not a CTAS). Defaults to the identity function.

    Returns:
        A new CREATE TEMPORARY VIEW expression for CTAS statements with temporary tables, the
        result of `tmp_storage_provider` for other temporary tables, otherwise the input.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    props = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in props)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the matching column
    definitions are moved out of the create statement's schema and into the property.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    # Column names are matched case-insensitively.
    partition_names = {name.name.upper() for name in prop.this.expressions}

    partition_columns = []
    for column in schema.expressions:
        if column.name.upper() in partition_names:
            partition_columns.append(column)

    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]
    schema.set("expressions", remaining_columns)

    partition_schema = exp.Schema(expressions=partition_columns)
    prop.replace(exp.PartitionedByProperty(this=partition_schema))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed column definitions out of PARTITIONED BY and into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made up solely of column definitions
    that have a type, those definitions are appended to the CREATE statement's schema and
    the property is left with a tuple of the column identifiers.

    Args:
        expression: The CREATE expression to transform; it is modified in place.

    Returns:
        The same expression, rewritten when the conditions above hold.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by or not partitioned_by.this:
        return expression

    if not isinstance(partitioned_by.this, exp.Schema):
        return expression

    # Every entry must be a column definition with a type; otherwise leave things alone.
    for column_def in partitioned_by.this.expressions:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    # Build the identifier tuple first, before the definitions are moved.
    identifiers = [
        exp.to_identifier(column_def.this) for column_def in partitioned_by.this.expressions
    ]
    partition_columns = exp.Tuple(expressions=identifiers)

    target_schema = expression.this
    for column_def in partitioned_by.this.expressions:
        target_schema.append("expressions", column_def)

    partitioned_by.set("this", partition_columns)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite key-value struct arguments as aliased values, e.g. STRUCT(1 AS y).

    Args:
        expression: Any expression; only `exp.Struct` nodes are changed (in place).

    Returns:
        The same expression, with each `exp.PropertyEQ` argument of a struct replaced
        by an alias of its value named after its key.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    aliased_args = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        aliased_args.append(arg)

    expression.set("expressions", aliased_args)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each scope that has both a WHERE clause and joins is processed independently. Every
    binary predicate containing a marked column is detached from its current position and
    becomes (part of) the ON condition of a LEFT JOIN on the marked column's table.
    The AST is modified in place.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: If a marked column is not inside a binary predicate, if both sides
            of a predicate carry marks, if a marked column is unqualified, if the marked
            columns of one predicate belong to more than one table, or if no table is
            left to serve as the new FROM clause.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Nothing to rewrite unless the query has both a WHERE clause and joins.
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Nearest enclosing predicate (or the SELECT itself, if none is found first).
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent before detaching: once popped, the predicate is no
            # longer attached to it.
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables the marked columns belong to.
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last column visited by the loop above; the assertion guarantees
            # all marked columns share its table. If that table is not one of the existing
            # joins, the FROM clause's source is used instead.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                # Several predicates on the same table are AND-ed into one ON condition.
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                # Collapse the parent into whichever operand is still present.
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # If the FROM source itself became a LEFT JOIN target, another table has to take
        # its place in the FROM clause.
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            # `only_old_joins` is a set, so when it has several members the pick is
            # not tied to the order the joins were written in.
            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): the joins list is overwritten with only the newly built joins, so
        # original joins that no marked column referred to are not carried over unless
        # they were promoted to FROM above -- confirm this is intended.
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if its condition was moved into a join entirely.
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.