sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type so we can detect whether the transforms changed the node.
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites UNNEST(GENERATE_DATE_ARRAY(start, end, step)) sources in a SELECT into
    recursive CTEs that enumerate the dates, for dialects without a date-array generator.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only handle a single GENERATE_DATE_ARRAY unnested directly in FROM/JOIN.
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix subsequent CTE names with a counter to keep them unique.
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection name to its 1-based ordinal position.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Keep only the first row per partition in the outer query.
        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapses multiple unnested arrays into a single ARRAYS_ZIP call, when allowed.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Picks the table-generating function that matches the unnest's shape.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        for join in expression.args.get("joins") or []:
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                expression.args["joins"].remove(join)

                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve a fresh name so later calls can't reuse it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression


def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the hoisted CTEs right before the CTE that contained them,
                # so that any references they define still come before their use.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips table/db/catalog qualifiers from all column references."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Removes UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Converts CREATE TEMPORARY TABLE ... AS statements into CREATE TEMPORARY VIEW.
    Non-CTAS temporary tables are delegated to `tmp_storage_provider`.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; all marked columns
            # share the same table, so it identifies the join's right-hand side.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        for any in expression.find_all(exp.Any):
            this = any.this
            if isinstance(this, exp.Query):
                continue

            binop = any.parent
            if isinstance(binop, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
16def preprocess( 17 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 18) -> t.Callable[[Generator, exp.Expression], str]: 19 """ 20 Creates a new transform by chaining a sequence of transformations and converts the resulting 21 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 22 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 23 24 Args: 25 transforms: sequence of transform functions. These will be called in order. 26 27 Returns: 28 Function that can be used as a generator transform. 29 """ 30 31 def _to_sql(self, expression: exp.Expression) -> str: 32 expression_type = type(expression) 33 34 try: 35 expression = transforms[0](expression) 36 for transform in transforms[1:]: 37 expression = transform(expression) 38 except UnsupportedError as unsupported_error: 39 self.unsupported(str(unsupported_error)) 40 41 _sql_handler = getattr(self, expression.key + "_sql", None) 42 if _sql_handler: 43 return _sql_handler(expression) 44 45 transforms_handler = self.TRANSFORMS.get(type(expression)) 46 if transforms_handler: 47 if expression_type is type(expression): 48 if isinstance(expression, exp.Func): 49 return self.function_fallback_sql(expression) 50 51 # Ensures we don't enter an infinite loop. This can happen when the original expression 52 # has the same type as the final expression and there's no _sql method available for it, 53 # because then it'd re-enter _to_sql. 54 raise ValueError( 55 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 56 ) 57 58 return transforms_handler(self, expression) 59 60 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 61 62 return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """Rewrites UNNEST(GENERATE_DATE_ARRAY(...)) table sources into recursive CTEs.

    Each qualifying unnest is replaced with a subquery over a newly created
    recursive CTE (`_generated_dates[_N]`) that produces the date series.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only handle a single GENERATE_DATE_ARRAY used as a table source
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # The rewrite needs explicit bounds and an INTERVAL step
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            # Recursive member advances the series by one step per iteration
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix subsequent CTE names with a counter to keep them unique
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    inner = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(inner, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[inner])
    if expression.alias:
        # Preserve the table alias as a column alias on the unnest
        return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)

    return unnested
Unnests GENERATE_SERIES or SEQUENCE table references.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # 1-based ordinal of each aliased projection, keyed by its alias name
        alias_to_ordinal = {
            projection.alias: ordinal
            for ordinal, projection in enumerate(expression.parent.expressions, start=1)
            if isinstance(projection, exp.Alias)
        }

        for item in expression.expressions:
            is_unqualified_column = isinstance(item, exp.Column) and not item.table
            if is_unqualified_column and item.name in alias_to_ordinal:
                item.replace(exp.Literal.number(alias_to_ordinal.get(item.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT clause; its ON columns become the window's PARTITION BY
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # Reuse the query's ORDER BY to pick which row survives per partition
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Wrap the query and keep only the first row of each partition
        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                # Give unnamed projections a fresh alias so the outer query can reference them
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve identifier quoting when re-projecting in the outer query
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For SELECT *, only window functions need hoisting; otherwise bare columns do too
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline newly-aliased projections referenced inside the window
                    # to avoid invalid column references
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window function under a fresh alias and refer to that
                # alias in the outer filter
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column referenced in QUALIFY but not selected: add it so the
                # outer WHERE can see it
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of unnests used as table sources in this scope
    aliases = {
        node.alias
        for node in find_all_in_scope(expression, exp.Unnest)
        if isinstance(node.parent, (exp.From, exp.Join))
    }
    if aliases:
        for column in expression.find_all(exp.Column):
            if column.table in aliases:
                column.set("table", None)
            elif column.db in aliases:
                column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapses multiple unnest arguments into a single ARRAYS_ZIP call
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # POSEXPLODE when an offset column is requested, INLINE for zipped arrays
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: UNNEST as the FROM source becomes a table-valued UDTF call
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: UNNEST joins become LATERAL VIEW EXPLODE clauses
        for join in expression.args.get("joins") or []:
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                # NOTE(review): removing from the list currently being iterated can skip
                # the following join; presumably only one unnest join is expected — confirm.
                expression.args["joins"].remove(join)

                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest.

    Args:
        index_offset: the dialect's array index base (0 or 1), used to align the
            generated position series with array positions.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name that doesn't clash with existing ones
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            # Position series shared by all exploded arrays in this SELECT
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an exp.Alias wrapper
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER must still yield a row for empty/NULL arrays
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Only emit the exploded value on the row where the series
                    # position matches the unnest offset
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Insert the position projection right after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # Attach the shared series source exactly once
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Align series position with the array position, padding shorter
                    # arrays with their last element's position
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # Cap the series at the longest array's size
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if is_bare_percentile:
        # Detach the ordering column first, then promote the percentile value
        # into the function's `this` slot
        order_column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        ordering = exp.Order(expressions=[exp.Ordered(this=order_column)])
        expression = exp.WithinGroup(this=expression, expression=ordering)

    return expression
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    quantile = expression.this.this
    # The WITHIN GROUP ordering expression becomes the quantile's input value
    input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
    return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        generated_name = name_sequence("_c_")

        for cte in expression.expressions:
            table_alias = cte.args["alias"]
            if table_alias.columns:
                continue

            query = cte.this
            # For set operations (e.g. UNION) the column names come from the left side
            if isinstance(query, exp.SetOperation):
                query = query.this

            table_alias.set(
                "columns",
                [exp.to_identifier(s.alias_or_name or generated_name()) for s in query.selects],
            )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on_condition = join.args.get("on")
            if not on_condition or join.kind not in ("SEMI", "ANTI"):
                continue

            # SEMI -> WHERE EXISTS(...), ANTI -> WHERE NOT EXISTS(...)
            subquery = exp.select("1").from_(join.this).where(on_condition)
            predicate = exp.Exists(this=subquery)
            if join.kind == "ANTI":
                predicate = predicate.not_(copy=False)

            join.pop()
            expression.where(predicate, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy BEFORE mutating: the copy becomes the RIGHT side of the union
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            # Fall back to building equi-join conditions from the USING columns
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # Anti-join so rows already produced by the LEFT side aren't duplicated
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # First nested WITH found becomes the new top-level WITH
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Splice the nested CTEs right before the CTE that contained them,
                # preserving dependency order
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _to_boolean(node: exp.Expression) -> None:
        # A node needs an explicit comparison when it is a numeric literal,
        # an untyped column, or carries a numeric/unknown type.
        if node.is_number:
            node.replace(node.neq(0))
        elif isinstance(node, exp.Column) and not node.type:
            node.replace(node.neq(0))
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _to_boolean)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Converts CTAS statements on temporary tables into CREATE TEMPORARY VIEW statements.

    Args:
        expression: the `exp.Create` expression to transform.
        tmp_storage_provider: hook applied to temporary tables that have no CTAS query.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and is_temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {v.name.upper() for v in prop.this.expressions}
        partition_cols = [col for col in schema.expressions if col.name.upper() in partition_names]
        # Drop the partition columns from the main schema...
        schema.set("expressions", [e for e in schema.expressions if e not in partition_cols])
        # ...and re-home their definitions inside the PARTITIONED BY property
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
        expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    has_typed_partition_schema = (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    )

    if has_typed_partition_schema:
        # Replace the partition schema with a tuple of bare identifiers and move
        # the full column definitions into the table schema instead
        identifiers = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for column_def in prop.this.expressions:
            schema.append("expressions", column_def)
        prop.set("this", identifiers)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = [
            exp.alias_(arg.expression, arg.this) if isinstance(arg, exp.PropertyEQ) else arg
            for arg in expression.expressions
        ]
        expression.set("expressions", converted)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Nothing to do unless the query has both a WHERE clause and joins
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # A marked column must sit inside a binary predicate, which becomes the ON clause
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # NOTE(review): `col` is the variable leaked from the loop above; safe only
            # because exactly one marked table was just asserted — confirm intent.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM table itself became a join target; pick a remaining table for FROM
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause entirely if all its predicates were consumed
        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify
first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        # Renamed from `any` so the builtin any() is not shadowed
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            # Subqueries are not supported by this transformation; skip them
            if isinstance(this, exp.Query):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Substitute the ANY node with lambda argument `x`, then wrap the
                # enclosing comparison in EXISTS(arr, x -> <comparison>)
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
Transform ANY operator to Spark's EXISTS
For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation