sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
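For instance, a custom generator could chain several of the transforms in this module before SELECT nodes are generated (a sketch; the MyGenerator subclass and the chosen transforms are illustrative):

    from sqlglot import exp
    from sqlglot.generator import Generator
    from sqlglot.transforms import eliminate_distinct_on, eliminate_qualify, preprocess

    class MyGenerator(Generator):  # hypothetical subclass, for illustration only
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            # Run both transforms on every SELECT before it is converted to SQL.
            exp.Select: preprocess([eliminate_distinct_on, eliminate_qualify]),
        }

The built-in dialects wire these transforms the same way, through their generators' TRANSFORMS mappings.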
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression
Unnests GENERATE_SERIES or SEQUENCE table references.
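A minimal sketch of applying this transform directly (the SQL string is illustrative, and the output comment is approximate, assuming the Postgres reader parses GENERATE_SERIES as a table function):

    import sqlglot
    from sqlglot.transforms import unnest_generate_series

    sql = "SELECT * FROM GENERATE_SERIES(1, 5) AS t"  # illustrative
    print(sqlglot.parse_one(sql, read="postgres").transform(unnest_generate_series).sql())
    # Roughly: SELECT * FROM UNNEST(GENERATE_SERIES(1, 5)) AS _u(t)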
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
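A rough usage sketch (the SQL string and names are illustrative; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import eliminate_distinct_on

    sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b DESC"  # illustrative
    print(sqlglot.parse_one(sql, read="postgres").transform(eliminate_distinct_on).sql())
    # Roughly: SELECT a, b FROM (
    #   SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, b DESC) AS _row_number FROM t
    # ) AS _t WHERE _row_number = 1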
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
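A rough usage sketch (illustrative SQL and column names; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import eliminate_qualify

    sql = "SELECT a, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t QUALIFY rn = 1"
    print(sqlglot.parse_one(sql, read="snowflake").transform(eliminate_qualify).sql())
    # Roughly: SELECT a, rn FROM (
    #   SELECT a, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t
    # ) AS _t WHERE rn = 1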
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
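For example (illustrative SQL; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import remove_precision_parameterized_types

    expr = sqlglot.parse_one("SELECT CAST(x AS DECIMAL(10, 2)) FROM t")  # illustrative
    print(expr.transform(remove_precision_parameterized_types).sql())
    # Roughly: SELECT CAST(x AS DECIMAL) FROM t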
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        for join in expression.args.get("joins") or []:
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
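A rough sketch of applying this transform and rendering the result for a Hive-style dialect (illustrative SQL; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import unnest_to_explode

    sql = "SELECT t.col FROM tbl CROSS JOIN UNNEST(tbl.arr) AS t(col)"  # illustrative
    print(sqlglot.parse_one(sql, read="presto").transform(unnest_to_explode).sql(dialect="hive"))
    # Roughly: SELECT t.col FROM tbl LATERAL VIEW EXPLODE(tbl.arr) t AS col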
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
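Note that this is a transform factory: calling explode_to_unnest(index_offset) returns the actual transform. A rough usage sketch (illustrative SQL; the exact aliases and filter in the output depend on the generated names):

    import sqlglot
    from sqlglot.transforms import explode_to_unnest

    transform = explode_to_unnest(index_offset=1)  # 1-based positions, e.g. for Presto/Trino
    sql = "SELECT EXPLODE(arr) FROM tbl"  # illustrative
    print(sqlglot.parse_one(sql, read="spark").transform(transform).sql(dialect="presto"))
    # The EXPLODE projection becomes a CROSS JOIN UNNEST(...) WITH ORDINALITY joined against a
    # generated position sequence, with a WHERE filter so each array element appears exactly once.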
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression
Transforms percentiles by adding a WITHIN GROUP clause to them.
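A minimal sketch that builds a percentile node directly and applies the transform (the node construction mirrors the args the function reads: this is the input column, expression is the quantile; the output comment is approximate):

    from sqlglot import exp
    from sqlglot.transforms import add_within_group_for_percentiles

    node = exp.PercentileCont(this=exp.column("x"), expression=exp.Literal.number(0.5))
    print(add_within_group_for_percentiles(node).sql())
    # Roughly: PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x)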
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
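A rough usage sketch (illustrative SQL; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import remove_within_group_for_percentiles

    sql = "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t"  # illustrative
    print(sqlglot.parse_one(sql, read="postgres").transform(remove_within_group_for_percentiles).sql())
    # Roughly: SELECT APPROX_QUANTILE(x, 0.5) FROM t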
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
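A rough usage sketch (illustrative SQL; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import add_recursive_cte_column_names

    sql = "WITH RECURSIVE t AS (SELECT 1 AS c UNION ALL SELECT c + 1 FROM t WHERE c < 3) SELECT * FROM t"
    print(sqlglot.parse_one(sql).transform(add_recursive_cte_column_names).sql())
    # Roughly: WITH RECURSIVE t(c) AS (SELECT 1 AS c UNION ALL SELECT c + 1 FROM t WHERE c < 3) SELECT * FROM t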
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts with the equivalent timestamp literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts with the equivalent timestamp literal.
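For example (the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import epoch_cast_to_ts

    print(sqlglot.parse_one("SELECT CAST('epoch' AS TIMESTAMP) AS ts").transform(epoch_cast_to_ts).sql())
    # Roughly: SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS ts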
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
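A rough usage sketch, assuming the Spark reader for the SEMI join syntax (illustrative SQL; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import eliminate_semi_and_anti_joins

    sql = "SELECT * FROM a LEFT SEMI JOIN b ON a.id = b.id"  # illustrative
    print(sqlglot.parse_one(sql, read="spark").transform(eliminate_semi_and_anti_joins).sql())
    # Roughly: SELECT * FROM a WHERE EXISTS(SELECT 1 FROM b WHERE a.id = b.id)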
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
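A rough usage sketch (illustrative SQL; the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import eliminate_full_outer_join

    sql = "SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id"  # illustrative
    print(sqlglot.parse_one(sql).transform(eliminate_full_outer_join).sql())
    # Roughly: SELECT * FROM a LEFT OUTER JOIN b ON a.id = b.id
    #          UNION
    #          SELECT * FROM a RIGHT OUTER JOIN b ON a.id = b.id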
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
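A rough sketch using the query from the docstring above (the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import move_ctes_to_top_level

    sql = "SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq"
    print(move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())
    # Roughly: WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT * FROM t) AS subq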
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
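A rough usage sketch (illustrative SQL; the exact rewritten condition may vary):

    import sqlglot
    from sqlglot.transforms import ensure_bools

    print(ensure_bools(sqlglot.parse_one("SELECT a FROM t WHERE 1", read="mysql")).sql())
    # The numeric condition should become an explicit comparison, e.g. ... WHERE 1 <> 0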
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
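A rough usage sketch, going from Hive-style DDL to the Spark layout (illustrative DDL; the output comment is approximate). Note that this helper asserts on exp.Create, so it is called directly on the parsed statement rather than via Expression.transform:

    import sqlglot
    from sqlglot.transforms import move_partitioned_by_to_schema_columns

    ddl = sqlglot.parse_one("CREATE TABLE t (a INT) PARTITIONED BY (b STRING)", read="hive")
    print(move_partitioned_by_to_schema_columns(ddl).sql(dialect="spark"))
    # Roughly: CREATE TABLE t (a INT, b STRING) PARTITIONED BY (b)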
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
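A minimal sketch that builds the key/value struct node directly (the PropertyEQ construction mirrors the args the function reads; the output comment is approximate):

    from sqlglot import exp
    from sqlglot.transforms import struct_kv_to_alias

    node = exp.Struct(
        expressions=[exp.PropertyEQ(this=exp.to_identifier("y"), expression=exp.Literal.number(1))]
    )
    print(struct_kv_to_alias(node).sql())
    # Roughly: STRUCT(1 AS y)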
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify first.
For example,
    SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
    SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
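A rough usage sketch mirroring the docstring example, assuming the Oracle reader for the (+) syntax (the output comment is approximate):

    import sqlglot
    from sqlglot.transforms import eliminate_join_marks

    sql = "SELECT * FROM a, b WHERE a.id = b.id(+)"
    print(eliminate_join_marks(sqlglot.parse_one(sql, read="oracle")).sql(dialect="oracle"))
    # Roughly: SELECT * FROM a LEFT JOIN b ON a.id = b.id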