Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery that is used inside a condition into a join on `parent_select`.

    A subquery that is not under IN / ANY is treated as a scalar: the subquery reference is
    replaced by a column of the joined alias and the subquery is CROSS JOINed.
    A subquery under IN / ANY is grouped by its projection and LEFT JOINed, and the predicate
    is replaced by a `NOT <joined column> IS NULL` check.

    The tree is left untouched when the subquery has more than one projection, when no
    enclosing condition owned by `parent_select` is found, when `parent_select` has no FROM,
    or (IN / ANY only) when the subquery contains LIMIT or OFFSET.

    Args:
        select: the subquery's select expression; joined into the parent without copying.
        parent_select: the select enclosing the subquery; receives the join.
        next_alias_name: zero-argument callable returning the alias for the joined subquery.
    """
    # Only single-projection subqueries are handled.
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn before the guard below, so a name is consumed from the
    # sequence even when this function returns without rewriting anything.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # Wrap the column in MAX when the predicate sits in this select's HAVING, or when it
        # is not inside a HAVING / WHERE / JOIN of this select and the parent select has a
        # GROUP BY or an aggregate in its projections (presumably so the reference stays
        # valid in an aggregated context).
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)

        # Swap the subquery reference for the joined column, then attach the subquery itself.
        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    # IN / ANY subqueries with LIMIT or OFFSET are not rewritten.
    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, work on the enclosing equality instead; it must belong to the parent select.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    # The operand compared against the subquery (left side of IN, or the non-subquery side of =).
    column = _other_operand(predicate)
    value = select.selects[0]

    # The join condition is built from SQL text before the predicate is replaced; its right
    # side (the joined column) is reused for the NULL check that replaces the predicate.
    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    # Group the subquery by its projected expression before left joining it.
    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
 95
 96
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped subquery that is LEFT JOINed to `parent_select`.

    The binary predicates in the subquery's WHERE that reference outer columns are replaced
    by TRUE; the equality predicates among them become the join condition. The expression
    that consumed the subquery (EXISTS, ALL, ANY, IN, or any other use) is rewritten to read
    from the joined alias.

    The tree is left untouched (although an alias name is still consumed) when:
      - the subquery has no WHERE, its WHERE contains an OR, or it contains LIMIT / OFFSET;
      - an external column is not inside a binary predicate of that WHERE;
      - none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery's select expression; modified and joined without copying.
        parent_select: the select enclosing the subquery; receives the join.
        external_columns: the outer-scope columns referenced by the subquery (the caller
            passes `scope.external_columns`).
        next_alias_name: zero-argument callable returning fresh alias names.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # Alias under which the rewritten subquery is joined.
    table_alias = next_alias_name()
    # (key, column, predicate) triples: `column` is the external column, `predicate` the
    # binary predicate containing it, `key` the operand on the other side of that predicate.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        # The external column must be located in this subquery's own WHERE.
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # Pick the side of the predicate that does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required; equalities are what the join condition is built from.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # Maps each key expression to the name it is projected under in the subquery.
    key_aliases = {}
    # Key expressions the subquery will be grouped by.
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # The predicate in the parent query that consumes the subquery (None if there is none).
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False replaces the existing projections with the aggregated value.
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Reference to the subquery's value through the join alias, and the operand that the
    # parent predicate compares against the subquery.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NULL check on the first projected join key.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # The node replaced is the parent of ALL (presumably the comparison wrapping it).
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            # The value is a grouping key, so a plain equality against it is enough.
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Any other use: the subquery reference is swapped for the joined column.
        if is_subquery_projection:
            # Keep the alias the subquery had as a projection.
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT -> 0, any other aggregate -> NULL, everything else unchanged.
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlating predicate in the subquery's WHERE becomes TRUE; the original
        # predicate object is reused below (equalities only) as part of the join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # Point the key side of the predicate at the joined subquery's column.
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # Equality on a key that was aggregated: require the aggregated array to contain
            # the outer column.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # Non-equality: evaluate the predicate per array element, with the key replaced
            # by the lambda variable `_x`.
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # Group the subquery on the collected keys and left join it on the equality predicates.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
252
253
def _replace(expression, condition):
    """Replace `expression` with `condition` built via `exp.condition`; return what `replace` returns."""
    replacement = exp.condition(condition)
    return expression.replace(replacement)
256
257
def _other_operand(expression):
    """
    Return the operand that a subquery-style predicate is compared against.

    - IN: the expression's `this`.
    - ANY / ALL: resolved recursively on the node's parent.
    - Binary: the right operand when the left one is a Subquery / Any / Exists / All,
      otherwise the left operand.
    - Anything else: None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if not isinstance(expression, exp.Binary):
        return None

    subquery_like = (exp.Subquery, exp.Any, exp.Exists, exp.All)
    if isinstance(expression.left, subquery_like):
        return expression.right
    return expression.left
def unnest_subqueries(expression):
 7def unnest_subqueries(expression):
 8    """
 9    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
10
11    Convert scalar subqueries into cross joins.
12    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
13
14    Example:
15        >>> import sqlglot
16        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
17        >>> unnest_subqueries(expression).sql()
18        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
19
20    Args:
21        expression (sqlglot.Expression): expression to unnest
22    Returns:
23        sqlglot.Expression: unnested expression
24    """
25    next_alias_name = name_sequence("_u_")
26
27    for scope in traverse_scope(expression):
28        select = scope.expression
29        parent = select.parent_select
30        if not parent:
31            continue
32        if scope.external_columns:
33            decorrelate(select, parent, scope.external_columns, next_alias_name)
34        elif scope.scope_type == ScopeType.SUBQUERY:
35            unnest(select, parent, next_alias_name)
36
37    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Scalar subqueries are converted into cross joins. Correlated or vectorized subqueries are converted into grouped subqueries that are left joined, so that the join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:
  • sqlglot.Expression: unnested expression

def unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery that is used inside a condition into a join on `parent_select`.

    A subquery that is not under IN / ANY is treated as a scalar: the subquery reference is
    replaced by a column of the joined alias and the subquery is CROSS JOINed.
    A subquery under IN / ANY is grouped by its projection and LEFT JOINed, and the predicate
    is replaced by a `NOT <joined column> IS NULL` check.

    The tree is left untouched when the subquery has more than one projection, when no
    enclosing condition owned by `parent_select` is found, when `parent_select` has no FROM,
    or (IN / ANY only) when the subquery contains LIMIT or OFFSET.

    Args:
        select: the subquery's select expression; joined into the parent without copying.
        parent_select: the select enclosing the subquery; receives the join.
        next_alias_name: zero-argument callable returning the alias for the joined subquery.
    """
    # Only single-projection subqueries are handled.
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn before the guard below, so a name is consumed from the
    # sequence even when this function returns without rewriting anything.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # Wrap the column in MAX when the predicate sits in this select's HAVING, or when it
        # is not inside a HAVING / WHERE / JOIN of this select and the parent select has a
        # GROUP BY or an aggregate in its projections (presumably so the reference stays
        # valid in an aggregated context).
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)

        # Swap the subquery reference for the joined column, then attach the subquery itself.
        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    # IN / ANY subqueries with LIMIT or OFFSET are not rewritten.
    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, work on the enclosing equality instead; it must belong to the parent select.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    # The operand compared against the subquery (left side of IN, or the non-subquery side of =).
    column = _other_operand(predicate)
    value = select.selects[0]

    # The join condition is built from SQL text before the predicate is replaced; its right
    # side (the joined column) is reused for the NULL check that replaces the predicate.
    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    # Group the subquery by its projected expression before left joining it.
    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped subquery that is LEFT JOINed to `parent_select`.

    The binary predicates in the subquery's WHERE that reference outer columns are replaced
    by TRUE; the equality predicates among them become the join condition. The expression
    that consumed the subquery (EXISTS, ALL, ANY, IN, or any other use) is rewritten to read
    from the joined alias.

    The tree is left untouched (although an alias name is still consumed) when:
      - the subquery has no WHERE, its WHERE contains an OR, or it contains LIMIT / OFFSET;
      - an external column is not inside a binary predicate of that WHERE;
      - none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery's select expression; modified and joined without copying.
        parent_select: the select enclosing the subquery; receives the join.
        external_columns: the outer-scope columns referenced by the subquery (the caller
            passes `scope.external_columns`).
        next_alias_name: zero-argument callable returning fresh alias names.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # Alias under which the rewritten subquery is joined.
    table_alias = next_alias_name()
    # (key, column, predicate) triples: `column` is the external column, `predicate` the
    # binary predicate containing it, `key` the operand on the other side of that predicate.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        # The external column must be located in this subquery's own WHERE.
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # Pick the side of the predicate that does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required; equalities are what the join condition is built from.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # Maps each key expression to the name it is projected under in the subquery.
    key_aliases = {}
    # Key expressions the subquery will be grouped by.
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # The predicate in the parent query that consumes the subquery (None if there is none).
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False replaces the existing projections with the aggregated value.
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Reference to the subquery's value through the join alias, and the operand that the
    # parent predicate compares against the subquery.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NULL check on the first projected join key.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # The node replaced is the parent of ALL (presumably the comparison wrapping it).
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            # The value is a grouping key, so a plain equality against it is enough.
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Any other use: the subquery reference is swapped for the joined column.
        if is_subquery_projection:
            # Keep the alias the subquery had as a projection.
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT -> 0, any other aggregate -> NULL, everything else unchanged.
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlating predicate in the subquery's WHERE becomes TRUE; the original
        # predicate object is reused below (equalities only) as part of the join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # Point the key side of the predicate at the joined subquery's column.
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # Equality on a key that was aggregated: require the aggregated array to contain
            # the outer column.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # Non-equality: evaluate the predicate per array element, with the key replaced
            # by the lambda variable `_x`.
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # Group the subquery on the collected keys and left join it on the equality predicates.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
252    )