Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
 39def unnest(select, parent_select, next_alias_name):
 40    if len(select.selects) > 1:
 41        return
 42
 43    predicate = select.find_ancestor(exp.Condition)
 44    alias = next_alias_name()
 45
 46    if (
 47        not predicate
 48        or parent_select is not predicate.parent_select
 49        or not parent_select.args.get("from")
 50    ):
 51        return
 52
 53    # This subquery returns a scalar and can just be converted to a cross join
 54    if not isinstance(predicate, (exp.In, exp.Any)):
 55        column = exp.column(select.selects[0].alias_or_name, alias)
 56
 57        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
 58        clause_parent_select = clause.parent_select if clause else None
 59
 60        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
 61            (not clause or clause_parent_select is not parent_select)
 62            and (
 63                parent_select.args.get("group")
 64                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
 65            )
 66        ):
 67            column = exp.Max(this=column)
 68        elif not isinstance(select.parent, exp.Subquery):
 69            return
 70
 71        _replace(select.parent, column)
 72        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
 73        return
 74
 75    if select.find(exp.Limit, exp.Offset):
 76        return
 77
 78    if isinstance(predicate, exp.Any):
 79        predicate = predicate.find_ancestor(exp.EQ)
 80
 81        if not predicate or parent_select is not predicate.parent_select:
 82            return
 83
 84    column = _other_operand(predicate)
 85    value = select.selects[0]
 86
 87    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
 88    _replace(predicate, f"NOT {on.right} IS NULL")
 89
 90    parent_select.join(
 91        select.group_by(value.this, copy=False),
 92        on=on,
 93        join_type="LEFT",
 94        join_alias=alias,
 95        copy=False,
 96    )
 97
 98
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a LEFT JOIN against a grouped derived table.

    The rewrite only happens when all of these hold:
    * the subquery has a WHERE without OR, and no LIMIT/OFFSET;
    * every outer-column reference sits in a binary predicate inside that WHERE;
    * at least one of those predicates is an equality.
    Otherwise the function returns without changing anything.

    On success:
    * each correlated predicate is replaced with TRUE inside the subquery;
    * the subquery is grouped by its join keys and LEFT JOINed to ``parent_select``
      under a fresh alias;
    * the expression that consumed the subquery (EXISTS, IN, ANY, ALL, or a scalar
      use) is rewritten to read the joined columns.
    The AST is mutated in place.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the SELECT enclosing ``select``; the join is added to it.
        external_columns: the scope's external columns, i.e. column nodes inside
            ``select`` that refer to an enclosing query.
        next_alias_name: zero-argument callable producing fresh aliases.
    """
    where = select.args.get("where")

    # Only a WHERE without OR, in a subquery without LIMIT/OFFSET, is handled.
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # (key, outer column, predicate) triples. `key` is the predicate operand that
    # does not contain the outer column, i.e. the subquery-side join key.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is whichever side of the comparison does not contain the outer column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is needed: only EQ predicates form the ON clause below
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery appears directly as one of parent_select's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # key expression -> name of the subquery output column that exposes it
    key_aliases = {}
    # keys the subquery will be grouped by, in first-seen order
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate that uses the subquery (EXISTS, IN, ANY, ALL, comparison), if any
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped are aggregated with agg_func (MAX or ARRAY_AGG)
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # the subquery's value column, referenced through the join alias
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the expression that consumed the subquery so it reads the joined columns.
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a non-null test on the first joined key column
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # NOTE(review): the lambda always compares with `=`, whatever operator the
        # enclosing comparison (parent_predicate.parent) uses; confirm this is intended.
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        # NOTE(review): as with ALL, both rewrites below use `=`, regardless of
        # the operator of the enclosing comparison; confirm this is intended.
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            # the value was collected into an array above; test membership in it
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # scalar use: replace the subquery node with the joined value column,
        # keeping the subquery's alias when it is a projection
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    # Detach each correlated predicate from the subquery (leaving TRUE) and point its
    # key at the joined table. The detached EQ predicates become the ON clause below.
    for key, column, predicate in keys:
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            # NOTE(review): non-EQ predicates are not part of the ON clause and are not
            # re-added anywhere on this path, so their filtering appears to be lost.
            # Confirm this is intended.
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): EQ keys are always appended to group_by in the loop above,
            # so this branch appears unreachable; confirm.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # non-equality: evaluate the predicate against each element of the key's
            # aggregated array, with the key rewritten to the lambda parameter _x
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
254
255
def _replace(expression, condition):
    """
    Put ``condition`` in place of ``expression`` within the tree holding it.

    ``condition`` may be SQL text or an expression node; callers in this module
    pass both. It is normalized through ``exp.condition`` before the swap. The
    value of ``expression.replace`` is returned; callers presumably treat it as
    the newly inserted node.
    """
    new_node = exp.condition(condition)
    return expression.replace(new_node)
258
259
def _other_operand(expression):
    """
    Return the operand that sits opposite the subquery in ``expression``.

    * IN: its ``this`` operand.
    * ANY / ALL: resolved against the enclosing node instead (repeatedly, if
      several ANY/ALL nodes are nested).
    * Binary node: the right operand when the left one is a Subquery, ANY,
      EXISTS or ALL; otherwise the left operand.
    * Anything else, including ``None``: ``None``.
    """
    node = expression
    while isinstance(node, (exp.Any, exp.All)):
        node = node.parent

    if isinstance(node, exp.In):
        return node.this

    if not isinstance(node, exp.Binary):
        return None

    subquery_like = (exp.Subquery, exp.Any, exp.Exists, exp.All)
    return node.right if isinstance(node.left, subquery_like) else node.left
def unnest_subqueries(expression):
 7def unnest_subqueries(expression):
 8    """
 9    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
10
11    Convert scalar subqueries into cross joins.
12    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
13
14    Example:
15        >>> import sqlglot
16        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
17        >>> unnest_subqueries(expression).sql()
18        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
19
20    Args:
21        expression (sqlglot.Expression): expression to unnest
22    Returns:
23        sqlglot.Expression: unnested expression
24    """
25    next_alias_name = name_sequence("_u_")
26
27    for scope in traverse_scope(expression):
28        select = scope.expression
29        parent = select.parent_select
30        if not parent:
31            continue
32        if scope.external_columns:
33            decorrelate(select, parent, scope.external_columns, next_alias_name)
34        elif scope.scope_type == ScopeType.SUBQUERY:
35            unnest(select, parent, next_alias_name)
36
37    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into a GROUP BY so that the resulting left join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:
  • sqlglot.Expression: unnested expression

def unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated subquery used inside ``parent_select`` as a join.

    * When the nearest enclosing condition is not an IN or ANY, the subquery is
      treated as a scalar: it is CROSS JOINed to ``parent_select`` and its use site
      is replaced by the joined column (wrapped in MAX when the use site is
      aggregated).
    * ``x IN (subquery)`` and ``x = ANY (subquery)`` become a LEFT JOIN against the
      subquery grouped by its projection, and the predicate is replaced by
      ``NOT <alias>.<column> IS NULL``.

    The AST is modified in place. Nothing is changed (and no alias is consumed)
    when the subquery does not fit one of these shapes.

    Args:
        select: the subquery's SELECT expression; only single-projection
            subqueries are rewritten.
        parent_select: the SELECT enclosing ``select``; it must also be the SELECT
            that owns the predicate using the subquery.
        next_alias_name: zero-argument callable producing fresh join aliases.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # The joined value must be aggregated when it is used in this query's HAVING,
        # or when it is not under one of this query's WHERE/HAVING/JOIN clauses and
        # this query has a GROUP BY or aggregate projections.
        needs_aggregation = (
            isinstance(clause, exp.Having) and clause_parent_select is parent_select
        ) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        )

        if not needs_aggregation and not isinstance(select.parent, exp.Subquery):
            return

        # Allocate the alias only once the rewrite is certain, so bail-outs do not
        # leave gaps in the generated alias sequence.
        alias = next_alias_name()
        column = exp.column(select.selects[0].alias_or_name, alias)
        if needs_aggregation:
            column = exp.Max(this=column)

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # `x = ANY (subquery)`: the rewrite applies to the owning equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    alias = next_alias_name()
    column = _other_operand(predicate)
    value = select.selects[0]

    # Build the join condition and the null test as AST nodes instead of formatting
    # and re-parsing SQL text. The alias is then referenced with exp.column, the same
    # way the scalar branch above and `decorrelate` reference theirs, rather than
    # being forced into double-quoted (case-sensitive) identifiers. alias_or_name
    # also covers an unaliased projection, whose `alias` is empty.
    join_key = exp.column(value.alias_or_name, alias)
    on = exp.EQ(this=column.copy(), expression=join_key)
    _replace(predicate, exp.Not(this=exp.Is(this=join_key.copy(), expression=exp.null())))

    # Group by the projected value so each outer row matches at most one joined row.
    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a LEFT JOIN against a grouped derived table.

    The rewrite only happens when all of these hold:
    * the subquery has a WHERE without OR, and no LIMIT/OFFSET;
    * every outer-column reference sits in a binary predicate inside that WHERE;
    * at least one of those predicates is an equality.
    Otherwise the function returns without changing anything.

    On success:
    * each correlated predicate is replaced with TRUE inside the subquery;
    * the subquery is grouped by its join keys and LEFT JOINed to ``parent_select``
      under a fresh alias;
    * the expression that consumed the subquery (EXISTS, IN, ANY, ALL, or a scalar
      use) is rewritten to read the joined columns.
    The AST is mutated in place.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the SELECT enclosing ``select``; the join is added to it.
        external_columns: the scope's external columns, i.e. column nodes inside
            ``select`` that refer to an enclosing query.
        next_alias_name: zero-argument callable producing fresh aliases.
    """
    where = select.args.get("where")

    # Only a WHERE without OR, in a subquery without LIMIT/OFFSET, is handled.
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # (key, outer column, predicate) triples. `key` is the predicate operand that
    # does not contain the outer column, i.e. the subquery-side join key.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is whichever side of the comparison does not contain the outer column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is needed: only EQ predicates form the ON clause below
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery appears directly as one of parent_select's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # key expression -> name of the subquery output column that exposes it
    key_aliases = {}
    # keys the subquery will be grouped by, in first-seen order
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate that uses the subquery (EXISTS, IN, ANY, ALL, comparison), if any
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped are aggregated with agg_func (MAX or ARRAY_AGG)
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # the subquery's value column, referenced through the join alias
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the expression that consumed the subquery so it reads the joined columns.
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a non-null test on the first joined key column
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # NOTE(review): the lambda always compares with `=`, whatever operator the
        # enclosing comparison (parent_predicate.parent) uses; confirm this is intended.
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        # NOTE(review): as with ALL, both rewrites below use `=`, regardless of
        # the operator of the enclosing comparison; confirm this is intended.
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            # the value was collected into an array above; test membership in it
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # scalar use: replace the subquery node with the joined value column,
        # keeping the subquery's alias when it is a projection
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    # Detach each correlated predicate from the subquery (leaving TRUE) and point its
    # key at the joined table. The detached EQ predicates become the ON clause below.
    for key, column, predicate in keys:
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            # NOTE(review): non-EQ predicates are not part of the ON clause and are not
            # re-added anywhere on this path, so their filtering appears to be lost.
            # Confirm this is intended.
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): EQ keys are always appended to group_by in the loop above,
            # so this branch appears unreachable; confirm.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # non-equality: evaluate the predicate against each element of the key's
            # aggregated array, with the key rewritten to the lambda parameter _x
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )