sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # 1-based position of every aliased projection, keyed by alias (a later duplicate wins)
    positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for grouped in expression.expressions:
        # Only unqualified column references can refer to a projection alias
        if isinstance(grouped, exp.Column) and not grouped.table:
            position = positions.get(grouped.name)
            if position is not None:
                grouped.replace(exp.Literal.number(position))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the window's PARTITION BY, and the query's ORDER BY (or, if there
    is none, the DISTINCT ON columns themselves) becomes the window's ORDER BY. The outer query keeps
    the first row of each partition and selects the original projections by their output names.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # The outer query only sees the subquery's output columns, so it must reference each
        # projection by name. Re-using the projections themselves would break aliased expressions
        # (e.g. `b AS c`) and qualified columns (e.g. `x.a`), whose sources are out of scope there.
        outer_selects: t.List[exp.Expression] = []
        taken = {row_number}
        for select in list(expression.selects):
            if select.is_star:
                # The star's columns aren't known here, so select everything from the subquery
                outer_selects = [exp.Star()]
                break

            if isinstance(select, exp.Column) and select.name not in taken:
                # A plain column is exposed by the subquery under its own (possibly quoted) name
                name = select.this.copy()
            else:
                if not isinstance(select, exp.Alias):
                    alias = find_new_name(taken, select.output_name or "_c")
                    select = select.replace(exp.alias_(select, alias))
                alias_node = select.args.get("alias")
                name = alias_node.copy() if isinstance(alias_node, exp.Identifier) else select.alias

            taken.add(select.alias_or_name)
            outer_selects.append(exp.column(name))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY decides which row of each DISTINCT ON group is kept
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Window functions referenced by QUALIFY are moved into the subquery's projections under fresh
    `_w`-prefixed aliases, because some dialects don't allow them in WHERE; the outer query then
    filters on those aliases. Columns referenced by QUALIFY but not selected are projected too, so
    the outer WHERE can still see them. Unnamed projections get `_c`-prefixed aliases so the outer
    query can select them by name.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    taken = set(expression.named_selects)
    for projection in expression.selects:
        if not projection.alias_or_name:
            name = find_new_name(taken, "_c")
            projection.replace(exp.alias_(projection, name))
            taken.add(name)

    # Captured before extra projections are added, so only the original columns are exposed
    outer_selects = exp.select(*[projection.alias_or_name for projection in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this

    # With a star projection every column is already available to the outer query
    candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)

    for node in qualify_filters.find_all(candidates):
        if not isinstance(node, exp.Window):
            if node.name not in expression.named_selects:
                expression.select(node.copy(), copy=False)
            continue

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(node, window_alias), copy=False)
        reference = exp.column(window_alias)

        if isinstance(node.parent, exp.Qualify):
            # The whole filter is the window itself, so it becomes just the alias reference
            qualify_filters = reference
        else:
            node.replace(reference)

    return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
        qualify_filters, copy=False
    )


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode (used in presto -> hive).

    Each joined UNNEST is removed from the joins and replaced by one `LATERAL VIEW` per unnested
    array, using POSEXPLODE when the UNNEST has an offset and EXPLODE otherwise.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from `expression.args["joins"]` in
        # place below, and removing from a list while iterating it would skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # Pair each unnested array with its column alias: one LATERAL VIEW per pair.
                # NOTE(review): without column aliases zip() yields nothing, so the UNNEST is
                # dropped with no replacement; confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: start value of the generated position series (presumably the index base of
            UNNEST offsets in the target dialect).

    Returns:
        A transform that rewrites SELECT expressions containing [POS]EXPLODE projections.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve a fresh name derived from `name` in `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ copied the projection, so look the EXPLODE up again in the copy
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate names only when the user didn't provide them, so aliases from
                    # `POSEXPLODE(x) AS (pos, col)` are preserved.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also exposes the position, projected right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches this array's position, or where
                    # the series has run past this array's last position (pinned to that position)
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's `this` argument moves into an `ORDER BY` inside WITHIN GROUP, and its
    `expression` argument becomes the function's `this`.
    """
    if not (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_value)]),
    )


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    The percentile becomes an `ApproxQuantile` whose input is the ORDER BY expression and whose
    quantile is the percentile's argument. For a descending ORDER BY the quantile is mirrored to
    `1 - quantile`, because `ApproxQuantile` carries no ordering and the q-th percentile of a
    descending ordering is the (1 - q)-th percentile of the ascending one.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        ordered = expression.find(exp.Ordered)
        if ordered is None:
            # No ORDER BY item to take the input value from; leave the expression untouched
            return expression

        quantile = expression.this.this
        if ordered.args.get("desc"):
            # Parenthesize non-literal quantiles so `1 - (a + b)` keeps its meaning
            operand = quantile if isinstance(quantile, exp.Literal) else exp.Paren(this=quantile)
            quantile = exp.Sub(this=exp.Literal.number(1), expression=operand)

        return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Projections without a name get one from a single `name_sequence("_c_")` generator shared by
    all CTEs of the WITH clause. CTEs that already declare columns are left untouched.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For a UNION, the column names come from its left-hand (anchor) query
        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        table_alias.set(
            "columns",
            [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
        )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal.

    Only the string literal `'epoch'` cast to a temporal type is rewritten; a column that happens
    to be named `epoch` is a value reference and is left alone.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and isinstance(expression.this, exp.Literal)
        and expression.this.args.get("is_string")
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    A SEMI join `JOIN y ON cond` becomes `WHERE EXISTS(SELECT 1 FROM y WHERE cond)` and an ANTI
    join becomes the corresponding `NOT EXISTS`. Joins without an ON condition are left as-is.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot, since matching joins are popped from the query below
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            # `FULL JOIN` and `FULL OUTER JOIN` are synonyms: the OUTER keyword is optional
            if join.side == "FULL" and join.kind in ("", "OUTER")
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            # The CTEs stay on the left branch, which is rendered first and so ahead of the whole
            # UNION; repeating them on the right branch would define them twice.
            expression_copy.args.pop("with", None)

            # NOTE(review): the UNION's duplicate elimination removes the matched rows produced
            # by both branches, but it also collapses rows that are genuine duplicates in the
            # original FULL JOIN result.
            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is expression:
            # This is already the top-level WITH clause
            continue

        nested_with = with_node.pop()

        if top_level_with:
            # The merged clause is RECURSIVE if any of its constituents was
            if nested_with.recursive:
                top_level_with.set("recursive", True)

            top_level_with.expressions.extend(nested_with.expressions)
        else:
            # No top-level WITH yet: promote the nested one as-is
            top_level_with = nested_with
            expression.set("with", top_level_with)

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Named `transform` rather than `t` so the module's `typing as t` alias isn't shadowed
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # 1-based position of every aliased projection, keyed by alias (a later duplicate wins)
    positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for grouped in expression.expressions:
        # Only unqualified column references can refer to a projection alias
        if isinstance(grouped, exp.Column) and not grouped.table:
            position = positions.get(grouped.name)
            if position is not None:
                grouped.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the window's PARTITION BY, and the query's ORDER BY (or, if there
    is none, the DISTINCT ON columns themselves) becomes the window's ORDER BY. The outer query keeps
    the first row of each partition and selects the original projections by their output names.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # The outer query only sees the subquery's output columns, so it must reference each
        # projection by name. Re-using the projections themselves would break aliased expressions
        # (e.g. `b AS c`) and qualified columns (e.g. `x.a`), whose sources are out of scope there.
        outer_selects: t.List[exp.Expression] = []
        taken = {row_number}
        for select in list(expression.selects):
            if select.is_star:
                # The star's columns aren't known here, so select everything from the subquery
                outer_selects = [exp.Star()]
                break

            if isinstance(select, exp.Column) and select.name not in taken:
                # A plain column is exposed by the subquery under its own (possibly quoted) name
                name = select.this.copy()
            else:
                if not isinstance(select, exp.Alias):
                    alias = find_new_name(taken, select.output_name or "_c")
                    select = select.replace(exp.alias_(select, alias))
                alias_node = select.args.get("alias")
                name = alias_node.copy() if isinstance(alias_node, exp.Identifier) else select.alias

            taken.add(select.alias_or_name)
            outer_selects.append(exp.column(name))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY decides which row of each DISTINCT ON group is kept
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Window functions referenced by QUALIFY are moved into the subquery's projections under fresh
    `_w`-prefixed aliases, because some dialects don't allow them in WHERE; the outer query then
    filters on those aliases. Columns referenced by QUALIFY but not selected are projected too, so
    the outer WHERE can still see them. Unnamed projections get `_c`-prefixed aliases so the outer
    query can select them by name.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    taken = set(expression.named_selects)
    for projection in expression.selects:
        if not projection.alias_or_name:
            name = find_new_name(taken, "_c")
            projection.replace(exp.alias_(projection, name))
            taken.add(name)

    # Captured before extra projections are added, so only the original columns are exposed
    outer_selects = exp.select(*[projection.alias_or_name for projection in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this

    # With a star projection every column is already available to the outer query
    candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)

    for node in qualify_filters.find_all(candidates):
        if not isinstance(node, exp.Window):
            if node.name not in expression.named_selects:
                expression.select(node.copy(), copy=False)
            continue

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(node, window_alias), copy=False)
        reference = exp.column(window_alias)

        if isinstance(node.parent, exp.Qualify):
            # The whole filter is the window itself, so it becomes just the alias reference
            qualify_filters = reference
        else:
            node.replace(reference)

    return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
        qualify_filters, copy=False
    )
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode (used in presto -> hive).

    Each joined UNNEST is removed from the joins and replaced by one `LATERAL VIEW` per unnested
    array, using POSEXPLODE when the UNNEST has an offset and EXPLODE otherwise.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from `expression.args["joins"]` in
        # place below, and removing from a list while iterating it would skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # Pair each unnested array with its column alias: one LATERAL VIEW per pair.
                # NOTE(review): without column aliases zip() yields nothing, so the UNNEST is
                # dropped with no replacement; confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: start value of the generated position series (presumably the index base of
            UNNEST offsets in the target dialect).

    Returns:
        A transform that rewrites SELECT expressions containing [POS]EXPLODE projections.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve a fresh name derived from `name` in `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ copied the projection, so look the EXPLODE up again in the copy
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate names only when the user didn't provide them, so aliases from
                    # `POSEXPLODE(x) AS (pos, col)` are preserved.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also exposes the position, projected right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches this array's position, or where
                    # the series has run past this array's last position (pinned to that position)
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest (used in hive -> presto).
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's `this` argument moves into an `ORDER BY` inside WITHIN GROUP, and its
    `expression` argument becomes the function's `this`.
    """
    if not (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_value)]),
    )
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    The percentile becomes an `ApproxQuantile` whose input is the ORDER BY expression and whose
    quantile is the percentile's argument. For a descending ORDER BY the quantile is mirrored to
    `1 - quantile`, because `ApproxQuantile` carries no ordering and the q-th percentile of a
    descending ordering is the (1 - q)-th percentile of the ascending one.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        ordered = expression.find(exp.Ordered)
        if ordered is None:
            # No ORDER BY item to take the input value from; leave the expression untouched
            return expression

        quantile = expression.this.this
        if ordered.args.get("desc"):
            # Parenthesize non-literal quantiles so `1 - (a + b)` keeps its meaning
            operand = quantile if isinstance(quantile, exp.Literal) else exp.Paren(this=quantile)
            quantile = exp.Sub(this=exp.Literal.number(1), expression=operand)

        return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))

    return expression
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Projections without a name get one from a single `name_sequence("_c_")` generator shared by
    all CTEs of the WITH clause. CTEs that already declare columns are left untouched.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For a UNION, the column names come from its left-hand (anchor) query
        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        table_alias.set(
            "columns",
            [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
        )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal.

    Only the string literal `'epoch'` cast to a temporal type is rewritten; a column that happens
    to be named `epoch` is a value reference and is left alone.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and isinstance(expression.this, exp.Literal)
        and expression.this.args.get("is_string")
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    A SEMI join `JOIN y ON cond` becomes `WHERE EXISTS(SELECT 1 FROM y WHERE cond)` and an ANTI
    join becomes the corresponding `NOT EXISTS`. Joins without an ON condition are left as-is.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot, since matching joins are popped from the query below
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            # `FULL JOIN` and `FULL OUTER JOIN` are synonyms: the OUTER keyword is optional
            if join.side == "FULL" and join.kind in ("", "OUTER")
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            # The CTEs stay on the left branch, which is rendered first and so ahead of the whole
            # UNION; repeating them on the right branch would define them twice.
            expression_copy.args.pop("with", None)

            # NOTE(review): the UNION's duplicate elimination removes the matched rows produced
            # by both branches, but it also collapses rows that are genuine duplicates in the
            # original FULL JOIN result.
            return exp.union(expression, expression_copy, copy=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is expression:
            # This is already the top-level WITH clause
            continue

        nested_with = with_node.pop()

        if top_level_with:
            # The merged clause is RECURSIVE if any of its constituents was
            if nested_with.recursive:
                top_level_with.set("recursive", True)

            top_level_with.expressions.extend(nested_with.expressions)
        else:
            # No top-level WITH yet: promote the nested one as-is
            top_level_with = nested_with
            expression.set("with", top_level_with)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Named `transform` rather than `t` so the module's `typing as t` alias isn't shadowed
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.