sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into left joins against a grouped
    subquery, so that the join is not many-to-many.

    The tree is rewritten in place and the same `expression` object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One shared sequence so every generated alias (_u_0, _u_1, ...) is unique
    # across all subqueries rewritten in this call.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Scopes that are not nested inside another select have nothing to join into.
        if not parent:
            continue
        if scope.external_columns:
            # The subquery references columns of an outer query: correlated.
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery used inside a condition into a join on `parent_select`.

    Nothing is rewritten when the subquery has more than one projection, when it is
    not inside an `exp.Condition` that belongs to `parent_select`, or when
    `parent_select` has no FROM clause.

    * For predicates other than IN / ANY the subquery is treated as a scalar: it is
      CROSS JOINed and the subquery reference is replaced by a column of the join.
    * For IN / ANY the subquery is grouped by its projected expression and LEFT
      JOINed; the predicate is replaced by `NOT <join column> IS NULL`. Subqueries
      containing LIMIT or OFFSET are left untouched.

    Args:
        select: the subquery's select expression.
        parent_select: the select that contains the subquery.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn before the guard below, so a name is consumed from
    # the sequence even when this function bails out.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it sits in this select's HAVING, or when it
        # sits outside this select's WHERE/HAVING/JOIN while the select aggregates
        # (GROUP BY present or an aggregate in the projections) -- presumably so the
        # reference stays valid in an aggregated query.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # `x = ANY (subquery)`: the predicate to rewrite is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # Use alias_or_name (as the scalar branch above does) so an un-aliased column
    # projection still yields a usable join key instead of an empty identifier.
    on = exp.condition(f'{column} = "{alias}"."{value.alias_or_name}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Turn a correlated subquery into a LEFT JOIN against a grouped version of it.

    The function returns without rewriting anything when the subquery has no WHERE,
    its WHERE contains an OR, it contains LIMIT or OFFSET, an external column is not
    inside a binary predicate of the subquery's own WHERE, or none of the correlating
    predicates is an equality.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select that contains the subquery.
        external_columns: columns inside `select` that refer to an outer query.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # (key, column, predicate): `column` is the external column, `key` the operand
    # on the other side of the binary `predicate` it appears in.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Column of the joined subquery that carries the subquery's value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Neutralise the correlating predicate inside the subquery; the detached
        # predicate object is reused below as a join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Swap `expression` in its tree for `exp.condition(condition)` and return the result of `replace`."""
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery-style predicate is compared against.

    For IN this is `expression.this`; for ANY / ALL the lookup is delegated to the
    parent node; for a binary expression it is the side that is not a
    Subquery / Any / Exists / All. Anything else yields None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Rewrite a sqlglot AST so that some subquery predicates become joins.

    Scalar subqueries are turned into cross joins; correlated or vectorized
    subqueries are turned into left joins against a grouped subquery so the
    join is not many-to-many. The same `expression` object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # A single generator keeps the _u_N aliases unique across every rewrite.
    fresh_alias = name_sequence("_u_")

    for scope in traverse_scope(expression):
        inner = scope.expression
        outer = inner.parent_select
        if outer:
            if scope.external_columns:
                # References to outer columns make the subquery correlated.
                decorrelate(inner, outer, scope.external_columns, fresh_alias)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(inner, outer, fresh_alias)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into left joins against a grouped subquery, so that the join is not many-to-many.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery used inside a condition into a join on `parent_select`.

    Nothing is rewritten when the subquery has more than one projection, when it is
    not inside an `exp.Condition` that belongs to `parent_select`, or when
    `parent_select` has no FROM clause.

    * For predicates other than IN / ANY the subquery is treated as a scalar: it is
      CROSS JOINed and the subquery reference is replaced by a column of the join.
    * For IN / ANY the subquery is grouped by its projected expression and LEFT
      JOINed; the predicate is replaced by `NOT <join column> IS NULL`. Subqueries
      containing LIMIT or OFFSET are left untouched.

    Args:
        select: the subquery's select expression.
        parent_select: the select that contains the subquery.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn before the guard below, so a name is consumed from
    # the sequence even when this function bails out.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it sits in this select's HAVING, or when it
        # sits outside this select's WHERE/HAVING/JOIN while the select aggregates
        # (GROUP BY present or an aggregate in the projections) -- presumably so the
        # reference stays valid in an aggregated query.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # `x = ANY (subquery)`: the predicate to rewrite is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # Use alias_or_name (as the scalar branch above does) so an un-aliased column
    # projection still yields a usable join key instead of an empty identifier.
    on = exp.condition(f'{column} = "{alias}"."{value.alias_or_name}"')
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a LEFT JOIN against a grouped copy of itself.

    No rewrite happens when the subquery lacks a WHERE, its WHERE contains an OR,
    it contains LIMIT or OFFSET, an external column is not part of a binary
    predicate in the subquery's own WHERE, or no correlating predicate is an
    equality.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select that contains the subquery.
        external_columns: columns inside `select` that refer to an outer query.
        next_alias_name: callable returning a fresh alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # Triples of (key, external column, predicate); `key` is the operand on the
    # opposite side of the predicate from the external column.
    correlations = []

    # Every external column must sit in a binary predicate of this WHERE,
    # otherwise the subquery cannot be expressed as a join.
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if not isinstance(predicate, exp.Binary):
            return

        column_on_left = any(node is column for node, *_ in predicate.left.walk())
        key = predicate.right if column_on_left else predicate.left
        correlations.append((key, column, predicate))

    has_equality = any(isinstance(entry[2], exp.EQ) for entry in correlations)
    if not has_equality:
        return

    # Is the subquery itself one of the parent's projections?
    is_subquery_projection = any(
        isinstance(projection, exp.Subquery) and projection is select.parent
        for projection in parent_select.selects
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in correlations:
        if key == value.this:
            # Filtering on the subquery's own value: it has to be unique.
            key_aliases[key] = value.alias
            group_by.append(key)
            continue

        if key not in key_aliases:
            key_aliases[key] = next_alias_name()
        # Equality keys are grouped on, which avoids a many-to-many join.
        if isinstance(predicate, exp.EQ) and key not in group_by:
            group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)
    under_exists = isinstance(parent_predicate, exp.Exists)

    # A value that is neither an aggregate nor a grouping key is collected into
    # an array so it can be grouped; subquery projections use MAX instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        collected = exp.alias_(agg_func(this=value.this), value.alias, quoted=False)
        select.select(collected, append=False, copy=False)

    # EXISTS only tests for rows, so drop its projections; the join keys added
    # below are the only projections it needs.
    if under_exists:
        select.args["expressions"] = []

    for key, key_alias in key_aliases.items():
        if key not in group_by:
            aggregated_key = exp.alias_(agg_func(this=key.copy()), key_alias, quoted=False)
            select.select(aggregated_key, copy=False)
        elif under_exists or key != value.this:
            # Project the grouping key so it can be used as a join key.
            select.select(f"{key} AS {key_alias}", copy=False)

    value_column = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if under_exists:
        value_column = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {value_column} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({value_column}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {value_column}")
        else:
            parent_predicate = _replace(
                parent_predicate, f"ARRAY_ANY({value_column}, _x -> _x = {other})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {value_column}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({value_column}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            value_column = exp.alias_(value_column, select.parent.alias)

        # COUNT yields 0 over an empty set, so coalesce to the value expression
        # with every COUNT turned into 0 (and other aggregates into NULL).
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                if isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            value_column = exp.Coalesce(
                this=value_column,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(value_column)

    for key, column, predicate in correlations:
        # Neutralise the correlating predicate inside the subquery; the detached
        # predicate object is reused below as a join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection or key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    join_conditions = [entry[2] for entry in correlations if isinstance(entry[2], exp.EQ)]
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=join_conditions,
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )