sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Replaces UNNEST(GENERATE_DATE_ARRAY(start, end, step)) sources in FROM/JOIN clauses with
    subqueries over recursive CTEs that enumerate the dates from `start` to `end` by `step`.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # Only rewrite when all three arguments are present and the step is an interval
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix the CTE name after the first one so multiple rewrites don't collide
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                select = select.replace(exp.alias_(select, alias))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)

                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transforms removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression


def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips table/db/catalog qualifiers from all column references in the expression."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Removes UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Converts a CTAS over a temporary table into a CREATE TEMPORARY VIEW statement; other
    temporary CREATE TABLE statements are delegated to `tmp_storage_provider`.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        for any in expression.find_all(exp.Any):
            this = any.this
            if isinstance(this, exp.Query):
                continue

            binop = any.parent
            if isinstance(binop, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Chain a sequence of AST transformations and render the final expression as SQL.

    The returned callable applies each function in `transforms` (in order), then generates
    SQL via the "<key>_sql" method matching the resulting expression or, failing that, the
    appropriate `Generator.TRANSFORMS` entry.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        try:
            expression = transforms[0](expression)
            for apply_transform in transforms[1:]:
                expression = apply_transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        sql_handler = getattr(self, expression.key + "_sql", None)
        if sql_handler:
            return sql_handler(expression)

        fallback_handler = self.TRANSFORMS.get(type(expression))
        if fallback_handler:
            if original_type is not type(expression):
                return fallback_handler(self, expression)

            if isinstance(expression, exp.Func):
                return self.function_fallback_sql(expression)

            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            raise ValueError(
                f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
            )

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

Args:
    transforms: sequence of transform functions. These will be called in order.

Returns:
    Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """Rewrites UNNEST(GENERATE_DATE_ARRAY(...)) table sources as recursive CTEs.

    Each qualifying unnest in the FROM/JOIN clause is replaced by a subquery selecting
    from a recursive CTE that emits one row per date, starting at `start` and repeatedly
    adding `step` while the value stays <= `end`. Useful for dialects that lack a
    GENERATE_DATE_ARRAY function but support recursive CTEs.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite table-valued unnests whose single argument is GENERATE_DATE_ARRAY
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # The rewrite needs concrete bounds and an INTERVAL step
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix subsequent CTE names with a counter to keep them unique per query
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member produces `start`; recursive member keeps adding `step`
            # until the next value would pass `end`
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    inner = expression.this

    if not (isinstance(expression, exp.Table) and isinstance(inner, exp.GenerateSeries)):
        return expression

    unnest = exp.Unnest(expressions=[inner])
    if not expression.alias:
        return unnest

    # Preserve the original table alias as a column alias on the unnest
    return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
Unnests GENERATE_SERIES or SEQUENCE table references.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    parent = expression.parent

    if isinstance(expression, exp.Group) and isinstance(parent, exp.Select):
        # Map each aliased projection's output name to its 1-based ordinal position
        alias_positions = {}
        for position, projection in enumerate(parent.expressions, start=1):
            if isinstance(projection, exp.Alias):
                alias_positions[projection.alias] = position

        for group_by in expression.expressions:
            # Only bare (table-less) columns can refer to select aliases
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in alias_positions
            ):
                group_by.replace(exp.Literal.number(alias_positions[group_by.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # The DISTINCT ON columns become the ROW_NUMBER window's PARTITION BY clause
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            # Without an explicit ORDER BY, order each partition by its own columns
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                select = select.replace(exp.alias_(select, alias))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # Keep only the first row of each partition in the outer query
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                # Give unnamed projections a synthetic alias so the outer query can reference them
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer projection, preserving the original identifier's quoting
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window expression,
                    # to avoid creating invalid column references
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window function in the subquery and refer to it by alias outside
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window IS the whole QUALIFY condition
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Referenced-but-unselected column: project it so the outer WHERE can see it
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        non_param_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", non_param_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect the aliases of table-valued unnests in this scope
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if unnest_aliases:
        for column in expression.find_all(exp.Column):
            if column.table in unnest_aliases:
                column.set("table", None)
            elif column.db in unnest_aliases:
                column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression to transform.
        unnest_using_arrays_zip: whether multiple unnest arguments may be collapsed
            into a single ARRAYS_ZIP call; if False, such queries are unsupported.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapses multiple unnest arguments into one ARRAYS_ZIP call (mutates `u`)
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Choose the table-generating function: POSEXPLODE when an offset column is
        # requested, INLINE for zipped multi-expression input, EXPLODE otherwise
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: UNNEST directly in the FROM clause — rewrite it as a UDTF table in place
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: UNNEST appearing as a join (possibly wrapped in LATERAL) — move it
        # out of the joins and into LATERAL VIEW clauses
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs,
                # so we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there
                # are 0 or > 2 aliases, because Spark LATERAL VIEW EXPLODE requires a single
                # alias for array/struct and two for Map type columns, unlike unnest in
                # trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest.

    Args:
        index_offset: the starting index of the target dialect's arrays
            (e.g. 1 for Presto/Trino-style 1-based arrays).

    Returns:
        A transform that performs the conversion on SELECT expressions.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a fresh name and records it as taken
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Position source: UNNEST(GENERATE_SERIES(...)); its "end" is filled in at
            # the bottom, once the sizes of all exploded arrays are known
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an exp.Alias wrapping the explode
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE with two aliases: (pos, value)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER must yield a row even for empty/NULL arrays:
                        # substitute a one-element array in that case
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the value only on the row where the series position matches
                    # the unnest's offset column
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Also project the position column right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the series source exactly once, on the first explode found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where positions line up, or pad shorter arrays by pinning
                    # their last element once the series position passes their size
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run to the size of the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    is_percentile = isinstance(expression, exp.PERCENTILES)
    already_within_group = isinstance(expression.parent, exp.WithinGroup)

    if is_percentile and not already_within_group and expression.expression:
        # The first argument becomes the WITHIN GROUP ordering column, while the
        # second argument (the quantile) becomes the function's only input
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    # Fold the ordering column and the quantile into a single APPROX_QUANTILE call
    quantile = expression.this.this
    input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
    return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            alias = cte.args["alias"]
            if alias.columns:
                continue

            query = cte.this
            # For set operations, the first operand determines the column names
            if isinstance(query, exp.SetOperation):
                query = query.this

            columns = [
                exp.to_identifier(select.alias_or_name or next_name())
                for select in query.selects
            ]
            alias.set("columns", columns)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )

    if is_epoch_cast:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if not on or join.kind not in ("SEMI", "ANTI"):
                continue

            # SEMI keeps rows with a match; ANTI keeps rows without one
            exists = exp.Exists(this=exp.select("1").from_(join.this).where(on))
            if join.kind == "ANTI":
                exists = exists.not_(copy=False)

            join.pop()
            expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Join condition: the explicit ON clause, or one derived from USING columns
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            # LEFT branch keeps all left-side rows; the RIGHT branch excludes rows that
            # already matched (NOT EXISTS anti-join) so the union produces no duplicates
            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the first nested one wholesale
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the moved CTEs before the CTE that contained them, so that
                # references within that CTE still resolve
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _to_predicate(node: exp.Expression) -> None:
        # Numbers, numeric/unknown-typed non-subquery expressions, and untyped
        # columns are coerced to booleans by comparing against zero
        is_numeric_typed = not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        )
        is_untyped_column = isinstance(node, exp.Column) and not node.type

        if node.is_number or is_numeric_typed or is_untyped_column:
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _to_predicate)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Rewrites CREATE TEMPORARY TABLE statements, mapping CTAS to CREATE TEMPORARY VIEW
    and delegating plain temp tables to `tmp_storage_provider`."""
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    prop_list = props.expressions if props else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        # Match schema columns to the PARTITIONED BY names case-insensitively
        partition_names = {v.name.upper() for v in prop.this.expressions}
        partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
        remaining = [e for e in schema.expressions if e not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in partition_defs):
        return expression

    # Keep only the partition column names in the property, and move their full
    # definitions into the table schema
    prop_this = exp.Tuple(expressions=[exp.to_identifier(e.this) for e in partition_defs])
    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)
    prop.set("this", prop_this)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                arg = exp.alias_(arg.expression, arg.this)
            converted.append(arg)

        expression.set("expressions", converted)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks only appear in WHERE predicates, and only matter if joins exist
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            # Detach the predicate from the WHERE clause; it becomes a join's ON condition
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; its table identifies
            # the join to convert (falling back to the FROM source)
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                # Multiple marked predicates on the same table: AND the ON conditions
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM source itself became a joined table; pick a remaining
            # (non-converted) join source as the new FROM clause
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            # All predicates were moved into joins; drop the now-empty WHERE
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        # NOTE: loop variable renamed to avoid shadowing the builtin `any`
        for any_expr in expression.find_all(exp.Any):
            operand = any_expr.this
            if isinstance(operand, exp.Query):
                # Subquery operands are left untouched; only arrays are rewritten
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Turn the surrounding comparison into the EXISTS lambda's body
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=operand.unnest(), expression=lambda_expr))

    return expression
Transform ANY operator to Spark's EXISTS
For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation