sqlglot.transforms
"""Expression-tree transforms used to bridge SQL dialect differences.

Each transform takes an `exp.Expression` and returns an (often mutated)
expression; `preprocess` chains transforms into a generator hook.
"""

from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each select alias to its 1-based projection position.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach DISTINCT; its ON columns become the window's PARTITION BY.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # Move the query's ORDER BY inside the window.
            window.set("order", order.pop())
        else:
            # Without an explicit ORDER BY, order by the DISTINCT ON columns.
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Keep only the first row of each partition in the outer query.
        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                # Unnamed projections get a synthetic alias so the outer query can reference them.
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window function and reference it by alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Referenced-but-unselected columns must be projected to be visible outside.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):
        # NOTE(review): removing from `joins` below while iterating the same live
        # list skips the element after each removal — confirm whether multiple
        # UNNEST joins in one SELECT are expected here.
        for join in expression.args.get("joins") or []:
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve a fresh name so later calls can't collide with it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run to the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        # Move the column argument into ORDER BY; the quantile becomes the argument.
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.Union):
                    # The anchor part of the union determines the column names.
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        # NOTE(review): `join.pop()` removes the join from the live list being
        # iterated, which skips the following join — confirm whether multiple
        # SEMI/ANTI joins in one SELECT are expected here.
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

    SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        if node.parent is expression:
            continue

        inner_with = node.pop()
        if not top_level_with:
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            top_level_with.expressions.extend(inner_with.expressions)

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Rewrite a numeric (or untyped-column) condition as `node <> 0`.
        if (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node, *_ in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips table/db/catalog qualifiers from all column references."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Removes UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Rewrites CREATE TEMPORARY TABLE ... AS as CREATE TEMPORARY VIEW."""
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY items that refer to select aliases as ordinal positions.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    parent = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(parent, exp.Select)):
        return expression

    # Map each select alias to its 1-based projection position.
    alias_positions: t.Dict[str, int] = {}
    for position, projection in enumerate(parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for item in expression.expressions:
        # Only unqualified column references can point at a select alias.
        if isinstance(item, exp.Column) and not item.table:
            ordinal = alias_positions.get(item.name)
            if ordinal is not None:
                item.replace(exp.Literal.number(ordinal))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT clause; its ON columns become the window's PARTITION BY.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # Reuse the query's ORDER BY inside the window (removing it from the query).
            window.set("order", order.pop())
        else:
            # Without an explicit ORDER BY, order by the DISTINCT ON columns themselves.
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Outer query keeps only the first row of each partition.
        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                # Unnamed projections get a synthetic alias so the outer query can reference them.
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window function in the subquery; refer to it by alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # Bare `QUALIFY <window>`: filter directly on the aliased projection.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Columns referenced only in QUALIFY must be projected to be visible outside.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop precision parameters from parameterized types in expressions.

    Some dialects only allow the precision of a parameterized type to appear in
    DDL; everywhere else the bare type name must be used.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression: each UNNEST join is removed and replaced by
        an equivalent LATERAL VIEW [POS]EXPLODE entry.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: the join is removed from the live `joins`
        # list below, and mutating a list while iterating it skips the element
        # that follows each removal (so a second consecutive UNNEST join would
        # silently be left untransformed).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                # An offset produces (position, value) pairs -> POSEXPLODE.
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest.

    Args:
        index_offset: the dialect's array index base (0 or 1), used to align the
            generated position series with UNNEST offsets.

    Returns:
        A transform function suitable for `Expression.transform`.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve the generated name so later calls can't collide with it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its `end` is set after all arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Wrap the bare projection in an (empty) alias so it can be renamed below.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        # EXPLODE_OUTER: substitute a one-element array when the input is empty/NULL.
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # Emit the value only on the row where series pos matches the unnest offset.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # Attach the series source once, on the first explode encountered.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Match positions, or clamp shorter arrays to their last element.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run to the length of the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Wrap percentile functions in a WITHIN GROUP (ORDER BY ...) clause."""
    if not isinstance(expression, PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # PERCENTILE_*(<column>, <quantile>) -> move the column into ORDER BY and
    # keep the quantile as the sole function argument.
    order_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=order_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Rewrites WITHIN GROUP percentile syntax into an equivalent APPROX_QUANTILE call."""
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    quantile = expression.this.this
    # The ordered expression inside WITHIN GROUP holds the percentile's input column.
    input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
    return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            # For a recursive CTE, the anchor member determines the column names.
            query = query.this

        alias.set(
            "columns",
            [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
        )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )

    if is_epoch_cast:
        # 'epoch' is shorthand for the start of the Unix epoch.
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        # Rewrite the join as a correlated (NOT) EXISTS predicate in WHERE.
        exists = exp.Exists(this=exp.select("1").from_(join.this).where(on))
        if join.kind == "ANTI":
            exists = exists.not_(copy=False)

        join.pop()
        expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_outer_joins = [
        (i, join)
        for i, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]

    if len(full_outer_joins) != 1:
        return expression

    # Copy before stripping LIMIT so the right-hand branch keeps it, as before.
    right_branch = expression.copy()
    expression.set("limit", None)
    index, full_outer_join = full_outer_joins[0]
    full_outer_join.set("side", "left")
    right_branch.args["joins"][index].set("side", "right")
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_branch, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        # Skip the WITH that's already attached directly to `expression`.
        if node.parent is expression:
            continue

        # Detach the nested WITH from its parent before hoisting it.
        inner_with = node.pop()
        if not top_level_with:
            # No top-level WITH yet: the first nested one becomes it.
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            # RECURSIVE is sticky: if any merged WITH is recursive, the
            # combined top-level WITH must be marked recursive too.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            top_level_with.expressions.extend(inner_with.expressions)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _coerce_to_bool(node: exp.Expression) -> None:
        # A bare numeric literal, numeric/unknown-typed value, or untyped
        # column in a boolean context becomes `node <> 0`.
        untyped_column = isinstance(node, exp.Column) and not node.type
        if (
            node.is_number
            or untyped_column
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        ):
            node.replace(node.neq(0))

    for node, *_ in expression.walk():
        ensure_bools(node, _coerce_to_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Converts `CREATE TEMPORARY TABLE ... AS ...` into `CREATE TEMPORARY VIEW ...`."""
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if expression.expression:
        # CTAS with temp tables map to CREATE TEMPORARY VIEW
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    # No SELECT body: delegate storage handling to the caller-supplied hook.
    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    if not (isinstance(expression.this, exp.Schema) and expression.kind in {"TABLE", "VIEW"}):
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this) or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {v.name.upper() for v in prop.this.expressions}
    # Split the schema's column defs into partition columns and the rest.
    partition_cols = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining_cols = [col for col in schema.expressions if col not in partition_cols]

    schema.set("expressions", remaining_cols)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    has_typed_partition_cols = bool(
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
    )
    if not has_typed_partition_cols:
        return expression

    # Keep only the identifiers in PARTITIONED BY; the full column defs
    # (name + type) move into the table schema below.
    new_partitions = exp.Tuple(
        expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
    )

    schema = expression.this
    for column_def in prop.this.expressions:
        schema.append("expressions", column_def)

    prop.set("this", new_partitions)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        # Apply the transforms in order; the first call is unconditional.
        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        # Prefer a dedicated `<key>_sql` generator method for the result.
        sql_method = getattr(self, expression.key + "_sql", None)
        if sql_method:
            return sql_method(expression)

        handler = self.TRANSFORMS.get(type(expression))
        if not handler:
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

        if original_type is not type(expression):
            return handler(self, expression)

        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        # Ensures we don't enter an infinite loop. This can happen when the original expression
        # has the same type as the final expression and there's no _sql method available for it,
        # because then it'd re-enter _to_sql.
        raise ValueError(
            f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.