sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop()) 71 else: 72 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 73 74 window = exp.alias_(window, row_number) 75 expression.select(window, copy=False) 76 77 return ( 78 exp.select(*outer_selects, copy=False) 79 .from_(expression.subquery("_t", copy=False), copy=False) 80 .where(exp.column(row_number).eq(1), copy=False) 81 ) 82 83 return expression 84 85 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 87 """ 88 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 89 90 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 91 https://docs.snowflake.com/en/sql-reference/constructs/qualify 92 93 Some dialects don't support window functions in the WHERE clause, so we need to include them as 94 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 95 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 96 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
97 """ 98 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 99 taken = set(expression.named_selects) 100 for select in expression.selects: 101 if not select.alias_or_name: 102 alias = find_new_name(taken, "_c") 103 select.replace(exp.alias_(select, alias)) 104 taken.add(alias) 105 106 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 107 qualify_filters = expression.args["qualify"].pop().this 108 109 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 110 for expr in qualify_filters.find_all(select_candidates): 111 if isinstance(expr, exp.Window): 112 alias = find_new_name(expression.named_selects, "_w") 113 expression.select(exp.alias_(expr, alias), copy=False) 114 column = exp.column(alias) 115 116 if isinstance(expr.parent, exp.Qualify): 117 qualify_filters = column 118 else: 119 expr.replace(column) 120 elif expr.name not in expression.named_selects: 121 expression.select(expr.copy(), copy=False) 122 123 return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( 124 qualify_filters, copy=False 125 ) 126 127 return expression 128 129 130def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 131 """ 132 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 133 other expressions. This transforms removes the precision from parameterized types in expressions. 
134 """ 135 for node in expression.find_all(exp.DataType): 136 node.set( 137 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 138 ) 139 140 return expression 141 142 143def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 144 """Convert cross join unnest into lateral view explode.""" 145 if isinstance(expression, exp.Select): 146 for join in expression.args.get("joins") or []: 147 unnest = join.this 148 149 if isinstance(unnest, exp.Unnest): 150 alias = unnest.args.get("alias") 151 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 152 153 expression.args["joins"].remove(join) 154 155 for e, column in zip(unnest.expressions, alias.columns if alias else []): 156 expression.append( 157 "laterals", 158 exp.Lateral( 159 this=udtf(this=e), 160 view=True, 161 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 162 ), 163 ) 164 165 return expression 166 167 168def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 169 """Convert explode/posexplode into unnest.""" 170 171 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 172 if isinstance(expression, exp.Select): 173 from sqlglot.optimizer.scope import Scope 174 175 taken_select_names = set(expression.named_selects) 176 taken_source_names = {name for name, _ in Scope(expression).references} 177 178 def new_name(names: t.Set[str], name: str) -> str: 179 name = find_new_name(names, name) 180 names.add(name) 181 return name 182 183 arrays: t.List[exp.Condition] = [] 184 series_alias = new_name(taken_select_names, "pos") 185 series = exp.alias_( 186 exp.Unnest( 187 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 188 ), 189 new_name(taken_source_names, "_u"), 190 table=[series_alias], 191 ) 192 193 # we use list here because expression.selects is mutated inside the loop 194 for select in list(expression.selects): 195 explode = select.find(exp.Explode) 196 
197 if explode: 198 pos_alias = "" 199 explode_alias = "" 200 201 if isinstance(select, exp.Alias): 202 explode_alias = select.args["alias"] 203 alias = select 204 elif isinstance(select, exp.Aliases): 205 pos_alias = select.aliases[0] 206 explode_alias = select.aliases[1] 207 alias = select.replace(exp.alias_(select.this, "", copy=False)) 208 else: 209 alias = select.replace(exp.alias_(select, "")) 210 explode = alias.find(exp.Explode) 211 assert explode 212 213 is_posexplode = isinstance(explode, exp.Posexplode) 214 explode_arg = explode.this 215 216 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 217 if isinstance(explode_arg, exp.Column): 218 taken_select_names.add(explode_arg.output_name) 219 220 unnest_source_alias = new_name(taken_source_names, "_u") 221 222 if not explode_alias: 223 explode_alias = new_name(taken_select_names, "col") 224 225 if is_posexplode: 226 pos_alias = new_name(taken_select_names, "pos") 227 228 if not pos_alias: 229 pos_alias = new_name(taken_select_names, "pos") 230 231 alias.set("alias", exp.to_identifier(explode_alias)) 232 233 series_table_alias = series.args["alias"].this 234 column = exp.If( 235 this=exp.column(series_alias, table=series_table_alias).eq( 236 exp.column(pos_alias, table=unnest_source_alias) 237 ), 238 true=exp.column(explode_alias, table=unnest_source_alias), 239 ) 240 241 explode.replace(column) 242 243 if is_posexplode: 244 expressions = expression.expressions 245 expressions.insert( 246 expressions.index(alias) + 1, 247 exp.If( 248 this=exp.column(series_alias, table=series_table_alias).eq( 249 exp.column(pos_alias, table=unnest_source_alias) 250 ), 251 true=exp.column(pos_alias, table=unnest_source_alias), 252 ).as_(pos_alias), 253 ) 254 expression.set("expressions", expressions) 255 256 if not arrays: 257 if expression.args.get("from"): 258 expression.join(series, copy=False, join_type="CROSS") 259 else: 260 expression.from_(series, copy=False) 261 262 size: exp.Condition = 
exp.ArraySize(this=explode_arg.copy()) 263 arrays.append(size) 264 265 # trino doesn't support left join unnest with on conditions 266 # if it did, this would be much simpler 267 expression.join( 268 exp.alias_( 269 exp.Unnest( 270 expressions=[explode_arg.copy()], 271 offset=exp.to_identifier(pos_alias), 272 ), 273 unnest_source_alias, 274 table=[explode_alias], 275 ), 276 join_type="CROSS", 277 copy=False, 278 ) 279 280 if index_offset != 1: 281 size = size - 1 282 283 expression.where( 284 exp.column(series_alias, table=series_table_alias) 285 .eq(exp.column(pos_alias, table=unnest_source_alias)) 286 .or_( 287 (exp.column(series_alias, table=series_table_alias) > size).and_( 288 exp.column(pos_alias, table=unnest_source_alias).eq(size) 289 ) 290 ), 291 copy=False, 292 ) 293 294 if arrays: 295 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 296 297 if index_offset != 1: 298 end = end - (1 - index_offset) 299 series.expressions[0].set("end", end) 300 301 return expression 302 303 return _explode_to_unnest 304 305 306PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) 307 308 309def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 310 """Transforms percentiles by adding a WITHIN GROUP clause to them.""" 311 if ( 312 isinstance(expression, PERCENTILES) 313 and not isinstance(expression.parent, exp.WithinGroup) 314 and expression.expression 315 ): 316 column = expression.this.pop() 317 expression.set("this", expression.expression.pop()) 318 order = exp.Order(expressions=[exp.Ordered(this=column)]) 319 expression = exp.WithinGroup(this=expression, expression=order) 320 321 return expression 322 323 324def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 325 """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" 326 if ( 327 isinstance(expression, exp.WithinGroup) 328 and isinstance(expression.this, PERCENTILES) 329 and 
isinstance(expression.expression, exp.Order) 330 ): 331 quantile = expression.this.this 332 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 333 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 334 335 return expression 336 337 338def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 339 """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" 340 if isinstance(expression, exp.With) and expression.recursive: 341 next_name = name_sequence("_c_") 342 343 for cte in expression.expressions: 344 if not cte.args["alias"].columns: 345 query = cte.this 346 if isinstance(query, exp.Union): 347 query = query.this 348 349 cte.args["alias"].set( 350 "columns", 351 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 352 ) 353 354 return expression 355 356 357def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 358 """Replace 'epoch' in casts by the equivalent date literal.""" 359 if ( 360 isinstance(expression, (exp.Cast, exp.TryCast)) 361 and expression.name.lower() == "epoch" 362 and expression.to.this in exp.DataType.TEMPORAL_TYPES 363 ): 364 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 365 366 return expression 367 368 369def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 370 """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" 371 if isinstance(expression, exp.Select): 372 for join in expression.args.get("joins") or []: 373 on = join.args.get("on") 374 if on and join.kind in ("SEMI", "ANTI"): 375 subquery = exp.select("1").from_(join.this).where(on) 376 exists = exp.Exists(this=subquery) 377 if join.kind == "ANTI": 378 exists = exists.not_(copy=False) 379 380 join.pop() 381 expression.where(exists, copy=False) 382 383 return expression 384 385 386def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: 387 """ 388 
Converts a query with a FULL OUTER join to a union of identical queries that 389 use LEFT/RIGHT OUTER joins instead. This transformation currently only works 390 for queries that have a single FULL OUTER join. 391 """ 392 if isinstance(expression, exp.Select): 393 full_outer_joins = [ 394 (index, join) 395 for index, join in enumerate(expression.args.get("joins") or []) 396 if join.side == "FULL" 397 ] 398 399 if len(full_outer_joins) == 1: 400 expression_copy = expression.copy() 401 expression.set("limit", None) 402 index, full_outer_join = full_outer_joins[0] 403 full_outer_join.set("side", "left") 404 expression_copy.args["joins"][index].set("side", "right") 405 expression_copy.args.pop("with", None) # remove CTEs from RIGHT side 406 407 return exp.union(expression, expression_copy, copy=False) 408 409 return expression 410 411 412def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: 413 """ 414 Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be 415 defined at the top-level, so for example queries like: 416 417 SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq 418 419 are invalid in those dialects. This transformation can be used to ensure all CTEs are 420 moved to the top level so that the final SQL code is valid from a syntax standpoint. 421 422 TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 
423 """ 424 top_level_with = expression.args.get("with") 425 for node in expression.find_all(exp.With): 426 if node.parent is expression: 427 continue 428 429 inner_with = node.pop() 430 if not top_level_with: 431 top_level_with = inner_with 432 expression.set("with", top_level_with) 433 else: 434 if inner_with.recursive: 435 top_level_with.set("recursive", True) 436 437 top_level_with.expressions.extend(inner_with.expressions) 438 439 return expression 440 441 442def ensure_bools(expression: exp.Expression) -> exp.Expression: 443 """Converts numeric values used in conditions into explicit boolean expressions.""" 444 from sqlglot.optimizer.canonicalize import ensure_bools 445 446 def _ensure_bool(node: exp.Expression) -> None: 447 if ( 448 node.is_number 449 or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) 450 or (isinstance(node, exp.Column) and not node.type) 451 ): 452 node.replace(node.neq(0)) 453 454 for node, *_ in expression.walk(): 455 ensure_bools(node, _ensure_bool) 456 457 return expression 458 459 460def unqualify_columns(expression: exp.Expression) -> exp.Expression: 461 for column in expression.find_all(exp.Column): 462 # We only wanna pop off the table, db, catalog args 463 for part in column.parts[:-1]: 464 part.pop() 465 466 return expression 467 468 469def preprocess( 470 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 471) -> t.Callable[[Generator, exp.Expression], str]: 472 """ 473 Creates a new transform by chaining a sequence of transformations and converts the resulting 474 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 475 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 476 477 Args: 478 transforms: sequence of transform functions. These will be called in order. 479 480 Returns: 481 Function that can be used as a generator transform. 
482 """ 483 484 def _to_sql(self, expression: exp.Expression) -> str: 485 expression_type = type(expression) 486 487 expression = transforms[0](expression) 488 for transform in transforms[1:]: 489 expression = transform(expression) 490 491 _sql_handler = getattr(self, expression.key + "_sql", None) 492 if _sql_handler: 493 return _sql_handler(expression) 494 495 transforms_handler = self.TRANSFORMS.get(type(expression)) 496 if transforms_handler: 497 if expression_type is type(expression): 498 if isinstance(expression, exp.Func): 499 return self.function_fallback_sql(expression) 500 501 # Ensures we don't enter an infinite loop. This can happen when the original expression 502 # has the same type as the final expression and there's no _sql method available for it, 503 # because then it'd re-enter _to_sql. 504 raise ValueError( 505 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 506 ) 507 508 return transforms_handler(self, expression) 509 510 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 511 512 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY terms that name a select alias as that projection's ordinal position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Alias name -> 1-based position of the aliased projection in the SELECT list.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified columns can refer to a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        position = positions.get(term.name)
        if position is not None:
            term.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery filtered on a ROW_NUMBER() window.

    The original query gains a ROW_NUMBER() projection partitioned by the DISTINCT ON
    expressions and is wrapped in a subquery named "_t"; the outer query keeps the rows
    whose row number equals 1. Intended for dialects that lack DISTINCT ON but support
    window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    distinct.pop()
    partition_columns = on.expressions

    # Captured before the ROW_NUMBER projection is added, so the outer query omits it.
    projections = expression.selects
    row_number_name = find_new_name(expression.named_selects, "_row_number")
    row_number_window = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    # The query's ORDER BY, when present, moves into the window; otherwise order by the
    # DISTINCT ON expressions themselves.
    query_order = expression.args.get("order")
    if query_order:
        window_order = query_order.pop()
    else:
        window_order = exp.Order(expressions=[column.copy() for column in partition_columns])
    row_number_window.set("order", window_order)

    expression.select(exp.alias_(row_number_window, row_number_name), copy=False)

    outer_query = exp.select(*projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_name).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery "_t") when a QUALIFY clause
        is present, otherwise the expression itself.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give unnamed projections a fresh "_c" alias so the outer query can select them by name.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer projection list is fixed here, before helper projections are added below.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are hoisted into the subquery.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window under a fresh "_w" alias and refer to it by that alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # A column used by the filter but not selected must be projected too.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the parameters (e.g. the precision) of parameterized data types.

    Some dialects only accept type parameters in DDL and not in other expressions, so every
    `DataTypeParam` child is removed from each `DataType` node found in the tree.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for child in data_type.expressions:
            if not isinstance(child, exp.DataTypeParam):
                kept.append(child)

        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join whose target is an UNNEST is removed from the SELECT and replaced by one
    LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used when the
    UNNEST has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list inside the
        # loop, and removing from a list while iterating it skips the following element.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE: without an alias the zip is empty, so the join is dropped with no replacement.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series; it is also used to
            adjust the series' end and the array-size comparison.

    Returns:
        A transform that rewrites every projection of a SELECT containing EXPLODE/POSEXPLODE
        into a cross join against an UNNEST, aligned with a generated position series.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name derived from `name` that is not yet in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded array; used to bound the series below.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # The projection is re-created by alias_, so the explode node must be
                        # looked up again inside the replacement.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the exploded value only on the row where the series position
                    # matches the unnest offset.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position, projected right after the value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is added once, when the first explode is found.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, plus rows where the series has
                    # run past this array's size and the unnest offset equals that size.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for the index offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile function in a WITHIN GROUP clause.

    The first argument becomes the ORDER BY term of the WITHIN GROUP clause and the second
    argument becomes the percentile's sole argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new `WithinGroup` node wrapping the percentile, or the expression unchanged when it
        is not a percentile, is already inside a WITHIN GROUP, or has no second argument.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile's WITHIN GROUP clause with an APPROX_QUANTILE call.

    The ordered value of the WITHIN GROUP clause becomes the function's input and the
    percentile's argument becomes its quantile.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` replacement, or the expression unchanged when it is not a
        WITHIN GROUP over a percentile with an ORDER BY.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list.

    CTEs that already declare columns are left alone. For the others, the column names are
    the output names of the CTE query's projections (the left-hand side for a union), with
    generated "_c_"-prefixed names for projections that have no name.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand 'epoch' of a cast to a temporal type with an explicit literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; when it is a CAST/TRY_CAST of 'epoch' (case-insensitive) to a
        temporal type, its operand is replaced by the string '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join that has an ON condition is removed from the SELECT and replaced by a
    WHERE condition of the form `EXISTS(SELECT 1 FROM <joined source> WHERE <on>)`, negated
    for ANTI joins.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are popped from the tree inside the loop, which must
        # not disturb the iteration regardless of how removal is implemented.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with one FULL OUTER join as a union of a LEFT and a RIGHT join query.

    Queries with zero or several FULL OUTER joins are returned untouched. On the rewritten
    query the LIMIT is removed from the left-hand side, and the CTEs are removed from the
    right-hand side.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the two rewritten queries, or the expression unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, candidate in enumerate(expression.args.get("joins") or []):
        if candidate.side == "FULL":
            full_joins.append((position, candidate))

    if len(full_joins) != 1:
        return expression

    position, left_join = full_joins[0]

    # The copy is taken before any mutation, so it still carries the original LIMIT.
    right_query = expression.copy()
    expression.set("limit", None)

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. Applying this transformation moves all CTEs to the
    top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The given expression, modified in place.
    """
    outer_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        # The WITH clause that already hangs off the root stays where it is.
        if with_node.parent is expression:
            continue

        detached = with_node.pop()

        if outer_with:
            # Merge into the existing top-level clause; one recursive inner clause is
            # enough to mark the merged clause as recursive.
            if detached.recursive:
                outer_with.set("recursive", True)

            outer_with.expressions.extend(detached.expressions)
        else:
            # No top-level clause yet: the first nested one found is promoted to that role.
            outer_with = detached
            expression.set("with", outer_with)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to `sqlglot.optimizer.canonicalize.ensure_bools`
    together with a callback that replaces a node by `<node> <> 0` when the node is a
    number, has an unknown or numeric type, or is a column without type information.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The given expression, modified in place.
    """
    # Imported locally under an alias so it does not shadow this function's own name.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _compare_to_zero(candidate: exp.Expression) -> None:
        needs_comparison = (
            candidate.is_number
            or candidate.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(candidate, exp.Column) and not candidate.type)
        )
        if needs_comparison:
            candidate.replace(candidate.neq(0))

    for node, *_ in expression.walk():
        _canonicalize_bools(node, _compare_to_zero)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remembered so we can tell below whether the transforms changed the expression's type.
        expression_type = type(expression)

        # Each transform receives the output of the previous one.
        for transform in transforms:
            expression = transform(expression)

        # A dedicated "<key>_sql" method on the generator takes precedence.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.