sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a GROUP BY so that the
    resulting LEFT JOIN is not many-to-many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One alias sequence (_u_0, _u_1, ...) is shared by every rewrite in this call.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Only selects nested inside another select can be unnested.
        if not parent:
            continue
        # Correlated subqueries reference columns from an outer scope.
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Unnest an uncorrelated subquery into a join on its parent select, in place.

    Scalar subqueries (whose nearest enclosing condition is not IN or ANY) are
    turned into a CROSS JOIN, and the subquery expression is replaced by a
    reference to its single projection. That reference is wrapped in MAX when
    it sits in the parent's HAVING clause, or outside the parent's
    WHERE/HAVING/JOIN clauses while the parent groups or aggregates.

    IN and ANY subqueries (ANY is handled via its nearest enclosing `=`
    comparison) are turned into a LEFT JOIN against the subquery grouped by its
    projection, and the predicate is replaced by a NOT NULL check on the
    joined column.

    The tree is left unchanged when the shape is unsupported, for example more
    than one projection, an enclosing condition that belongs to another
    select, a parent without FROM, or LIMIT/OFFSET in an IN/ANY subquery.

    Args:
        select: the subquery's SELECT expression.
        parent_select: the SELECT expression that contains the subquery.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE(review): the alias is consumed even when the guard below bails out,
    # which leaves gaps in the generated alias sequence.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it would otherwise appear un-aggregated
        # in a grouping/aggregating context of the parent.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # NOTE(review): presumably relies on the projection being aliased by an
    # earlier optimizer step; an empty alias would render as "" here.
    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    # Grouping by the projected value keeps the LEFT JOIN from duplicating rows.
    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery into a LEFT JOIN on its parent select, in place.

    Every external (outer-scope) column must sit in a binary predicate of the
    subquery's WHERE clause. That predicate's other operand becomes a join key,
    and at least one of these predicates must be an equality. Equality keys
    are grouped on so that the join cannot become many-to-many. The subquery's
    projection is aggregated (ARRAY_AGG, or MAX when the subquery is itself a
    projection of the parent) when it is neither an aggregate nor a key.
    Finally, the enclosing predicate (EXISTS, IN, ANY, ALL, or a plain scalar
    use) is rewritten in terms of the joined columns. For ANY/ALL, the original
    comparison operator (=, <>, <, <=, >, >=, ...) is preserved.

    The tree is left unchanged when the shape is unsupported:
    - no WHERE clause;
    - any OR, LIMIT or OFFSET;
    - an external column outside the WHERE clause or outside a binary predicate;
    - no equality predicate;
    - an ANY/ALL that is not an operand of a binary comparison.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the SELECT expression that contains the subquery.
        external_columns: the subquery's columns that reference outer scopes.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the operand that does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the GROUP BY
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # ANY/ALL only make sense as an operand of a comparison. Remember that
    # comparison's type so the rewrite keeps its operator, and bail out before
    # mutating anything if the enclosing node is not a binary operator.
    op_type = None
    if isinstance(parent_predicate, (exp.Any, exp.All)):
        if not isinstance(parent_predicate.parent, exp.Binary):
            return
        op_type = type(parent_predicate.parent)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    def compare_element(comparison_type, operand):
        # Build the lambda body for ARRAY_ANY/ARRAY_ALL. It is equivalent to
        # `operand <op> _x`, where the original SQL was
        # `operand <op> ANY/ALL (subquery)`. Order-sensitive comparisons are
        # mirrored so that `_x` stays on the left, which keeps the `=` case
        # rendered as `_x = operand`.
        mirrored = {
            exp.EQ: exp.EQ,
            exp.NEQ: exp.NEQ,
            exp.GT: exp.LT,
            exp.GTE: exp.LTE,
            exp.LT: exp.GT,
            exp.LTE: exp.GTE,
        }
        element = exp.column("_x")
        if comparison_type in mirrored:
            return mirrored[comparison_type](this=element, expression=operand.copy())
        return comparison_type(this=operand.copy(), expression=element)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent,
            f"ARRAY_ALL({alias}, _x -> {compare_element(op_type, other)})",
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            comparison = op_type(this=other.copy(), expression=alias.copy())
            parent_predicate = _replace(parent_predicate.parent, f"{comparison}")
        else:
            # Replace the whole comparison, not just the ANY operand; otherwise
            # the boolean ARRAY_ANY result would be compared against `other`.
            parent_predicate = _replace(
                parent_predicate.parent,
                f"ARRAY_ANY({alias}, _x -> {compare_element(op_type, other)})",
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    # The correlated predicates move into the join condition; neutralize them in
    # the subquery's WHERE clause and point the keys at the joined columns.
    for key, column, predicate in keys:
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Replace `expression` in its tree with `condition` and return the replacement result.

    `condition` may be SQL text (parsed with `exp.condition`) or an expression.
    Callers in this module rebind the return value as the newly inserted node.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """Return the operand that a subquery-bearing predicate compares against.

    For IN, this is the tested expression. For ANY/ALL, it is the other operand
    of the enclosing comparison. For a binary predicate, it is the left side
    unless that side is a subquery/ANY/EXISTS/ALL, in which case it is the
    right side. Returns None for any other node.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
parent_select) or ( 61 (not clause or clause_parent_select is not parent_select) 62 and ( 63 parent_select.args.get("group") 64 or any(projection.find(exp.AggFunc) for projection in parent_select.selects) 65 ) 66 ): 67 column = exp.Max(this=column) 68 elif not isinstance(select.parent, exp.Subquery): 69 return 70 71 _replace(select.parent, column) 72 parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) 73 return 74 75 if select.find(exp.Limit, exp.Offset): 76 return 77 78 if isinstance(predicate, exp.Any): 79 predicate = predicate.find_ancestor(exp.EQ) 80 81 if not predicate or parent_select is not predicate.parent_select: 82 return 83 84 column = _other_operand(predicate) 85 value = select.selects[0] 86 87 on = exp.condition(f'{column} = "{alias}"."{value.alias}"') 88 _replace(predicate, f"NOT {on.right} IS NULL") 89 90 parent_select.join( 91 select.group_by(value.this, copy=False), 92 on=on, 93 join_type="LEFT", 94 join_alias=alias, 95 copy=False, 96 ) 97 98 99def decorrelate(select, parent_select, external_columns, next_alias_name): 100 where = select.args.get("where") 101 102 if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): 103 return 104 105 table_alias = next_alias_name() 106 keys = [] 107 108 # for all external columns in the where statement, find the relevant predicate 109 # keys to convert it into a join 110 for column in external_columns: 111 if column.find_ancestor(exp.Where) is not where: 112 return 113 114 predicate = column.find_ancestor(exp.Predicate) 115 116 if not predicate or predicate.find_ancestor(exp.Where) is not where: 117 return 118 119 if isinstance(predicate, exp.Binary): 120 key = ( 121 predicate.right 122 if any(node is column for node, *_ in predicate.left.walk()) 123 else predicate.left 124 ) 125 else: 126 return 127 128 keys.append((key, column, predicate)) 129 130 if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): 131 return 132 133 is_subquery_projection = any( 134 
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery) 135 ) 136 137 value = select.selects[0] 138 key_aliases = {} 139 group_by = [] 140 141 for key, _, predicate in keys: 142 # if we filter on the value of the subquery, it needs to be unique 143 if key == value.this: 144 key_aliases[key] = value.alias 145 group_by.append(key) 146 else: 147 if key not in key_aliases: 148 key_aliases[key] = next_alias_name() 149 # all predicates that are equalities must also be in the unique 150 # so that we don't do a many to many join 151 if isinstance(predicate, exp.EQ) and key not in group_by: 152 group_by.append(key) 153 154 parent_predicate = select.find_ancestor(exp.Predicate) 155 156 # if the value of the subquery is not an agg or a key, we need to collect it into an array 157 # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. 158 agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg 159 if not value.find(exp.AggFunc) and value.this not in group_by: 160 select.select( 161 exp.alias_(agg_func(this=value.this), value.alias, quoted=False), 162 append=False, 163 copy=False, 164 ) 165 166 # exists queries should not have any selects as it only checks if there are any rows 167 # all selects will be added by the optimizer and only used for join keys 168 if isinstance(parent_predicate, exp.Exists): 169 select.args["expressions"] = [] 170 171 for key, alias in key_aliases.items(): 172 if key in group_by: 173 # add all keys to the projections of the subquery 174 # so that we can use it as a join key 175 if isinstance(parent_predicate, exp.Exists) or key != value.this: 176 select.select(f"{key} AS {alias}", copy=False) 177 else: 178 select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False) 179 180 alias = exp.column(value.alias, table_alias) 181 other = _other_operand(parent_predicate) 182 183 if isinstance(parent_predicate, exp.Exists): 184 alias = 
exp.column(list(key_aliases.values())[0], table_alias) 185 parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") 186 elif isinstance(parent_predicate, exp.All): 187 parent_predicate = _replace( 188 parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" 189 ) 190 elif isinstance(parent_predicate, exp.Any): 191 if value.this in group_by: 192 parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") 193 else: 194 parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})") 195 elif isinstance(parent_predicate, exp.In): 196 if value.this in group_by: 197 parent_predicate = _replace(parent_predicate, f"{other} = {alias}") 198 else: 199 parent_predicate = _replace( 200 parent_predicate, 201 f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", 202 ) 203 else: 204 if is_subquery_projection: 205 alias = exp.alias_(alias, select.parent.alias) 206 207 # COUNT always returns 0 on empty datasets, so we need take that into consideration here 208 # by transforming all counts into 0 and using that as the coalesced value 209 if value.find(exp.Count): 210 211 def remove_aggs(node): 212 if isinstance(node, exp.Count): 213 return exp.Literal.number(0) 214 elif isinstance(node, exp.AggFunc): 215 return exp.null() 216 return node 217 218 alias = exp.Coalesce( 219 this=alias, 220 expressions=[value.this.transform(remove_aggs)], 221 ) 222 223 select.parent.replace(alias) 224 225 for key, column, predicate in keys: 226 predicate.replace(exp.true()) 227 nested = exp.column(key_aliases[key], table_alias) 228 229 if is_subquery_projection: 230 key.replace(nested) 231 continue 232 233 if key in group_by: 234 key.replace(nested) 235 elif isinstance(predicate, exp.EQ): 236 parent_predicate = _replace( 237 parent_predicate, 238 f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))", 239 ) 240 else: 241 key.replace(exp.to_identifier("_x")) 242 parent_predicate = _replace( 243 parent_predicate, 244 
f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))', 245 ) 246 247 parent_select.join( 248 select.group_by(*group_by, copy=False), 249 on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)], 250 join_type="LEFT", 251 join_alias=table_alias, 252 copy=False, 253 ) 254 255 256def _replace(expression, condition): 257 return expression.replace(exp.condition(condition)) 258 259 260def _other_operand(expression): 261 if isinstance(expression, exp.In): 262 return expression.this 263 264 if isinstance(expression, (exp.Any, exp.All)): 265 return _other_operand(expression.parent) 266 267 if isinstance(expression, exp.Binary): 268 return ( 269 expression.right 270 if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All)) 271 else expression.left 272 ) 273 274 return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Turn subquery predicates in a sqlglot AST into joins, mutating and returning it.

    Uncorrelated scalar subqueries become cross joins. Correlated or vectorized
    subqueries are grouped by their join keys so that the LEFT JOIN replacing
    them cannot be many-to-many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): the tree to rewrite
    Returns:
        sqlglot.Expression: the same tree, rewritten
    """
    # Shared generator of join aliases (_u_0, _u_1, ...) for the whole rewrite.
    alias_generator = name_sequence("_u_")

    for current_scope in traverse_scope(expression):
        subquery = current_scope.expression
        outer_select = subquery.parent_select
        # Top-level selects have no parent to be unnested into.
        if outer_select:
            if current_scope.external_columns:
                # References to outer columns: correlated subquery.
                decorrelate(
                    subquery,
                    outer_select,
                    current_scope.external_columns,
                    alias_generator,
                )
            elif current_scope.scope_type == ScopeType.SUBQUERY:
                unnest(subquery, outer_select, alias_generator)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into a GROUP BY so that the resulting LEFT JOIN is not many-to-many.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Unnest an uncorrelated subquery into a join on its parent select, in place.

    Scalar subqueries (whose nearest enclosing condition is not IN or ANY) are
    turned into a CROSS JOIN, and the subquery expression is replaced by a
    reference to its single projection. That reference is wrapped in MAX when
    it sits in the parent's HAVING clause, or outside the parent's
    WHERE/HAVING/JOIN clauses while the parent groups or aggregates.

    IN and ANY subqueries (ANY is handled via its nearest enclosing `=`
    comparison) are turned into a LEFT JOIN against the subquery grouped by its
    projection, and the predicate is replaced by a NOT NULL check on the
    joined column.

    The tree is left unchanged when the shape is unsupported, for example more
    than one projection, an enclosing condition that belongs to another
    select, a parent without FROM, or LIMIT/OFFSET in an IN/ANY subquery.

    Args:
        select: the subquery's SELECT expression.
        parent_select: the SELECT expression that contains the subquery.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE(review): the alias is consumed even when the guard below bails out,
    # which leaves gaps in the generated alias sequence.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it would otherwise appear un-aggregated
        # in a grouping/aggregating context of the parent.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # NOTE(review): presumably relies on the projection being aliased by an
    # earlier optimizer step; an empty alias would render as "" here.
    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    # Grouping by the projected value keeps the LEFT JOIN from duplicating rows.
    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery into a LEFT JOIN on its parent select, in place.

    Every external (outer-scope) column must sit in a binary predicate of the
    subquery's WHERE clause. That predicate's other operand becomes a join key,
    and at least one of these predicates must be an equality. Equality keys
    are grouped on so that the join cannot become many-to-many. The subquery's
    projection is aggregated (ARRAY_AGG, or MAX when the subquery is itself a
    projection of the parent) when it is neither an aggregate nor a key.
    Finally, the enclosing predicate (EXISTS, IN, ANY, ALL, or a plain scalar
    use) is rewritten in terms of the joined columns. For ANY/ALL, the original
    comparison operator (=, <>, <, <=, >, >=, ...) is preserved.

    The tree is left unchanged when the shape is unsupported:
    - no WHERE clause;
    - any OR, LIMIT or OFFSET;
    - an external column outside the WHERE clause or outside a binary predicate;
    - no equality predicate;
    - an ANY/ALL that is not an operand of a binary comparison.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the SELECT expression that contains the subquery.
        external_columns: the subquery's columns that reference outer scopes.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the operand that does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the GROUP BY
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # ANY/ALL only make sense as an operand of a comparison. Remember that
    # comparison's type so the rewrite keeps its operator, and bail out before
    # mutating anything if the enclosing node is not a binary operator.
    op_type = None
    if isinstance(parent_predicate, (exp.Any, exp.All)):
        if not isinstance(parent_predicate.parent, exp.Binary):
            return
        op_type = type(parent_predicate.parent)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    def compare_element(comparison_type, operand):
        # Build the lambda body for ARRAY_ANY/ARRAY_ALL. It is equivalent to
        # `operand <op> _x`, where the original SQL was
        # `operand <op> ANY/ALL (subquery)`. Order-sensitive comparisons are
        # mirrored so that `_x` stays on the left, which keeps the `=` case
        # rendered as `_x = operand`.
        mirrored = {
            exp.EQ: exp.EQ,
            exp.NEQ: exp.NEQ,
            exp.GT: exp.LT,
            exp.GTE: exp.LTE,
            exp.LT: exp.GT,
            exp.LTE: exp.GTE,
        }
        element = exp.column("_x")
        if comparison_type in mirrored:
            return mirrored[comparison_type](this=element, expression=operand.copy())
        return comparison_type(this=operand.copy(), expression=element)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent,
            f"ARRAY_ALL({alias}, _x -> {compare_element(op_type, other)})",
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            comparison = op_type(this=other.copy(), expression=alias.copy())
            parent_predicate = _replace(parent_predicate.parent, f"{comparison}")
        else:
            # Replace the whole comparison, not just the ANY operand; otherwise
            # the boolean ARRAY_ANY result would be compared against `other`.
            parent_predicate = _replace(
                parent_predicate.parent,
                f"ARRAY_ANY({alias}, _x -> {compare_element(op_type, other)})",
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    # The correlated predicates move into the join condition; neutralize them in
    # the subquery's WHERE clause and point the keys at the joined columns.
    for key, column, predicate in keys:
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )