sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop().copy()) 71 else: 72 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 73 74 window = exp.alias_(window, row_number) 75 expression.select(window, copy=False) 76 77 return ( 78 exp.select(*outer_selects) 79 .from_(expression.subquery("_t")) 80 .where(exp.column(row_number).eq(1)) 81 ) 82 83 return expression 84 85 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 87 """ 88 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 89 90 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 91 https://docs.snowflake.com/en/sql-reference/constructs/qualify 92 93 Some dialects don't support window functions in the WHERE clause, so we need to include them as 94 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 95 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 96 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
97 """ 98 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 99 taken = set(expression.named_selects) 100 for select in expression.selects: 101 if not select.alias_or_name: 102 alias = find_new_name(taken, "_c") 103 select.replace(exp.alias_(select, alias)) 104 taken.add(alias) 105 106 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 107 qualify_filters = expression.args["qualify"].pop().this 108 109 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 110 for expr in qualify_filters.find_all(select_candidates): 111 if isinstance(expr, exp.Window): 112 alias = find_new_name(expression.named_selects, "_w") 113 expression.select(exp.alias_(expr, alias), copy=False) 114 column = exp.column(alias) 115 116 if isinstance(expr.parent, exp.Qualify): 117 qualify_filters = column 118 else: 119 expr.replace(column) 120 elif expr.name not in expression.named_selects: 121 expression.select(expr.copy(), copy=False) 122 123 return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) 124 125 return expression 126 127 128def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 129 """ 130 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 131 other expressions. This transforms removes the precision from parameterized types in expressions. 
132 """ 133 for node in expression.find_all(exp.DataType): 134 node.set( 135 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 136 ) 137 138 return expression 139 140 141def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 142 """Convert cross join unnest into lateral view explode (used in presto -> hive).""" 143 if isinstance(expression, exp.Select): 144 for join in expression.args.get("joins") or []: 145 unnest = join.this 146 147 if isinstance(unnest, exp.Unnest): 148 alias = unnest.args.get("alias") 149 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 150 151 expression.args["joins"].remove(join) 152 153 for e, column in zip(unnest.expressions, alias.columns if alias else []): 154 expression.append( 155 "laterals", 156 exp.Lateral( 157 this=udtf(this=e), 158 view=True, 159 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 160 ), 161 ) 162 163 return expression 164 165 166def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 167 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 168 """Convert explode/posexplode into unnest (used in hive -> presto).""" 169 if isinstance(expression, exp.Select): 170 from sqlglot.optimizer.scope import Scope 171 172 taken_select_names = set(expression.named_selects) 173 taken_source_names = {name for name, _ in Scope(expression).references} 174 175 def new_name(names: t.Set[str], name: str) -> str: 176 name = find_new_name(names, name) 177 names.add(name) 178 return name 179 180 arrays: t.List[exp.Condition] = [] 181 series_alias = new_name(taken_select_names, "pos") 182 series = exp.alias_( 183 exp.Unnest( 184 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 185 ), 186 new_name(taken_source_names, "_u"), 187 table=[series_alias], 188 ) 189 190 # we use list here because expression.selects is mutated inside the loop 191 for select in list(expression.selects): 
192 to_replace = select 193 pos_alias = "" 194 explode_alias = "" 195 196 if isinstance(select, exp.Alias): 197 explode_alias = select.alias 198 select = select.this 199 elif isinstance(select, exp.Aliases): 200 pos_alias = select.aliases[0].name 201 explode_alias = select.aliases[1].name 202 select = select.this 203 204 if isinstance(select, (exp.Explode, exp.Posexplode)): 205 is_posexplode = isinstance(select, exp.Posexplode) 206 explode_arg = select.this 207 208 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 209 if isinstance(explode_arg, exp.Column): 210 taken_select_names.add(explode_arg.output_name) 211 212 unnest_source_alias = new_name(taken_source_names, "_u") 213 214 if not explode_alias: 215 explode_alias = new_name(taken_select_names, "col") 216 217 if is_posexplode: 218 pos_alias = new_name(taken_select_names, "pos") 219 220 if not pos_alias: 221 pos_alias = new_name(taken_select_names, "pos") 222 223 column = exp.If( 224 this=exp.column(series_alias).eq(exp.column(pos_alias)), 225 true=exp.column(explode_alias), 226 ).as_(explode_alias) 227 228 if is_posexplode: 229 expressions = expression.expressions 230 index = expressions.index(to_replace) 231 expressions.pop(index) 232 expressions.insert(index, column) 233 expressions.insert( 234 index + 1, 235 exp.If( 236 this=exp.column(series_alias).eq(exp.column(pos_alias)), 237 true=exp.column(pos_alias), 238 ).as_(pos_alias), 239 ) 240 expression.set("expressions", expressions) 241 else: 242 to_replace.replace(column) 243 244 if not arrays: 245 if expression.args.get("from"): 246 expression.join(series, copy=False) 247 else: 248 expression.from_(series, copy=False) 249 250 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 251 arrays.append(size) 252 253 # trino doesn't support left join unnest with on conditions 254 # if it did, this would be much simpler 255 expression.join( 256 exp.alias_( 257 exp.Unnest( 258 expressions=[explode_arg.copy()], 259 
offset=exp.to_identifier(pos_alias), 260 ), 261 unnest_source_alias, 262 table=[explode_alias], 263 ), 264 join_type="CROSS", 265 copy=False, 266 ) 267 268 if index_offset != 1: 269 size = size - 1 270 271 expression.where( 272 exp.column(series_alias) 273 .eq(exp.column(pos_alias)) 274 .or_( 275 (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) 276 ), 277 copy=False, 278 ) 279 280 if arrays: 281 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 282 283 if index_offset != 1: 284 end = end - (1 - index_offset) 285 series.expressions[0].set("end", end) 286 287 return expression 288 289 return _explode_to_unnest 290 291 292PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) 293 294 295def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 296 if ( 297 isinstance(expression, PERCENTILES) 298 and not isinstance(expression.parent, exp.WithinGroup) 299 and expression.expression 300 ): 301 column = expression.this.pop() 302 expression.set("this", expression.expression.pop()) 303 order = exp.Order(expressions=[exp.Ordered(this=column)]) 304 expression = exp.WithinGroup(this=expression, expression=order) 305 306 return expression 307 308 309def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 310 if ( 311 isinstance(expression, exp.WithinGroup) 312 and isinstance(expression.this, PERCENTILES) 313 and isinstance(expression.expression, exp.Order) 314 ): 315 quantile = expression.this.this 316 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 317 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 318 319 return expression 320 321 322def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 323 if isinstance(expression, exp.With) and expression.recursive: 324 next_name = name_sequence("_c_") 325 326 for cte in expression.expressions: 327 if not cte.args["alias"].columns: 328 query = cte.this 329 
if isinstance(query, exp.Union): 330 query = query.this 331 332 cte.args["alias"].set( 333 "columns", 334 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 335 ) 336 337 return expression 338 339 340def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 341 if ( 342 isinstance(expression, (exp.Cast, exp.TryCast)) 343 and expression.name.lower() == "epoch" 344 and expression.to.this in exp.DataType.TEMPORAL_TYPES 345 ): 346 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 347 348 return expression 349 350 351def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: 352 if isinstance(expression, exp.Timestamp) and not expression.expression: 353 return exp.cast( 354 expression.this, 355 to=exp.DataType.Type.TIMESTAMP, 356 ) 357 return expression 358 359 360def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 361 if isinstance(expression, exp.Select): 362 for join in expression.args.get("joins") or []: 363 on = join.args.get("on") 364 if on and join.kind in ("SEMI", "ANTI"): 365 subquery = exp.select("1").from_(join.this).where(on) 366 exists = exp.Exists(this=subquery) 367 if join.kind == "ANTI": 368 exists = exists.not_(copy=False) 369 370 join.pop() 371 expression.where(exists, copy=False) 372 373 return expression 374 375 376def preprocess( 377 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 378) -> t.Callable[[Generator, exp.Expression], str]: 379 """ 380 Creates a new transform by chaining a sequence of transformations and converts the resulting 381 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 382 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 383 384 Args: 385 transforms: sequence of transform functions. These will be called in order. 386 387 Returns: 388 Function that can be used as a generator transform. 
389 """ 390 391 def _to_sql(self, expression: exp.Expression) -> str: 392 expression_type = type(expression) 393 394 expression = transforms[0](expression.copy()) 395 for t in transforms[1:]: 396 expression = t(expression) 397 398 _sql_handler = getattr(self, expression.key + "_sql", None) 399 if _sql_handler: 400 return _sql_handler(expression) 401 402 transforms_handler = self.TRANSFORMS.get(type(expression)) 403 if transforms_handler: 404 if expression_type is type(expression): 405 if isinstance(expression, exp.Func): 406 return self.function_fallback_sql(expression) 407 408 # Ensures we don't enter an infinite loop. This can happen when the original expression 409 # has the same type as the final expression and there's no _sql method available for it, 410 # because then it'd re-enter _to_sql. 411 raise ValueError( 412 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 413 ) 414 415 return transforms_handler(self, expression) 416 417 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 418 419 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # Position (1-based, counting every projection) of each aliased projection;
    # when an alias is repeated, the last projection carrying it wins.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for key in expression.expressions:
        # Only unqualified columns are treated as references to a projection alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        position = positions.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    Projections of the inner query are aliased where needed, and the outer query refers to them
    by name, because only the subquery's output columns are in scope outside of it.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        # The query's ORDER BY (if any) decides which row of each DISTINCT ON group is kept.
        if order:
            window.set("order", order.pop().copy())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        # Reference the inner projections by name: repeating them verbatim in the outer query
        # (e.g. `x.a` or `a + 1 AS b`) would refer to columns that are not in scope outside `_t`.
        outer_selects: t.List[exp.Expression] = []
        taken = {row_number}
        for select in list(expression.selects):
            if select.is_star:
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken, select.output_name or "_col")
                quoted = (
                    select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                )
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken.add(select.alias)
            outer_selects.append(exp.column(select.args["alias"].copy()))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery("_t"))
            .where(exp.column(row_number).eq(1))
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Unnamed projections get a generated alias so the outer query can reference them.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> t.Union[str, exp.Column]:
            # Reuse the projection's identifier so its quoting survives; re-parsing the bare
            # name would break for names that need quotes (spaces, keywords, case).
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(identifier.copy())
            return select.alias_or_name

        outer_selects = exp.select(*[_select_alias_or_name(s) for s in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # The whole QUALIFY condition is the window itself: filter on its alias.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types in
    expressions.
    """

    def _is_kept(arg: exp.Expression) -> bool:
        # Every DataTypeParam is dropped; anything else in `expressions` is kept as-is.
        return not isinstance(arg, exp.DataTypeParam)

    for data_type in expression.find_all(exp.DataType):
        kept = list(filter(_is_kept, data_type.expressions))
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the list below, and removing
        # items from the list being iterated would skip the join right after each removed one.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without a table alias no LATERAL VIEW is emitted, so the UNNEST
                # is dropped entirely; extra expressions without a matching column are dropped too.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
# explode_to_unnest(index_offset) returns `_explode_to_unnest`, a transform that only rewrites
# SELECT expressions (hive -> presto). For each EXPLODE/POSEXPLODE projection it:
#   - replaces the projection with IF(<series pos> = <pos>, <col>) AS <col> (plus a matching
#     position projection for POSEXPLODE),
#   - CROSS JOINs UNNEST(<arg>) with an offset alias, and
#   - adds a WHERE filter keeping rows whose series position equals the element position, or,
#     past the end of that array, only its last element.
# A single UNNEST(GENERATE_SERIES(index_offset, ...)) is joined (or used as FROM when the query
# has none); its end is GREATEST of every exploded ARRAY_SIZE, minus (1 - index_offset) when
# index_offset != 1. Generated names avoid existing projection and source names, and the order in
# which they are allocated determines the generated aliases.
167def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 168 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 169 """Convert explode/posexplode into unnest (used in hive -> presto).""" 170 if isinstance(expression, exp.Select): 171 from sqlglot.optimizer.scope import Scope 172 173 taken_select_names = set(expression.named_selects) 174 taken_source_names = {name for name, _ in Scope(expression).references} 175 176 def new_name(names: t.Set[str], name: str) -> str: 177 name = find_new_name(names, name) 178 names.add(name) 179 return name 180 181 arrays: t.List[exp.Condition] = [] 182 series_alias = new_name(taken_select_names, "pos") 183 series = exp.alias_( 184 exp.Unnest( 185 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 186 ), 187 new_name(taken_source_names, "_u"), 188 table=[series_alias], 189 ) 190 191 # we use list here because expression.selects is mutated inside the loop 192 for select in list(expression.selects): 193 to_replace = select 194 pos_alias = "" 195 explode_alias = "" 196 197 if isinstance(select, exp.Alias): 198 explode_alias = select.alias 199 select = select.this 200 elif isinstance(select, exp.Aliases): 201 pos_alias = select.aliases[0].name 202 explode_alias = select.aliases[1].name 203 select = select.this 204 205 if isinstance(select, (exp.Explode, exp.Posexplode)): 206 is_posexplode = isinstance(select, exp.Posexplode) 207 explode_arg = select.this 208 209 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 210 if isinstance(explode_arg, exp.Column): 211 taken_select_names.add(explode_arg.output_name) 212 213 unnest_source_alias = new_name(taken_source_names, "_u") 214 215 if not explode_alias: 216 explode_alias = new_name(taken_select_names, "col") 217 218 if is_posexplode: 219 pos_alias = new_name(taken_select_names, "pos") 220 221 if not pos_alias: 222 pos_alias = new_name(taken_select_names, "pos") 223 224 column = exp.If(
225 this=exp.column(series_alias).eq(exp.column(pos_alias)), 226 true=exp.column(explode_alias), 227 ).as_(explode_alias) 228 229 if is_posexplode: 230 expressions = expression.expressions 231 index = expressions.index(to_replace) 232 expressions.pop(index) 233 expressions.insert(index, column) 234 expressions.insert( 235 index + 1, 236 exp.If( 237 this=exp.column(series_alias).eq(exp.column(pos_alias)), 238 true=exp.column(pos_alias), 239 ).as_(pos_alias), 240 ) 241 expression.set("expressions", expressions) 242 else: 243 to_replace.replace(column) 244 245 if not arrays: 246 if expression.args.get("from"): 247 expression.join(series, copy=False) 248 else: 249 expression.from_(series, copy=False) 250 251 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 252 arrays.append(size) 253 254 # trino doesn't support left join unnest with on conditions 255 # if it did, this would be much simpler 256 expression.join( 257 exp.alias_( 258 exp.Unnest( 259 expressions=[explode_arg.copy()], 260 offset=exp.to_identifier(pos_alias), 261 ), 262 unnest_source_alias, 263 table=[explode_alias], 264 ), 265 join_type="CROSS", 266 copy=False, 267 ) 268 269 if index_offset != 1: 270 size = size - 1 271 272 expression.where( 273 exp.column(series_alias) 274 .eq(exp.column(pos_alias)) 275 .or_( 276 (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) 277 ), 278 copy=False, 279 ) 280 281 if arrays: 282 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 283 284 if index_offset != 1: 285 end = end - (1 - index_offset) 286 series.expressions[0].set("end", end) 287 288 return expression 289 290 return _explode_to_unnest
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a two-argument percentile call into `f(quantile) WITHIN GROUP (ORDER BY column)`.

    The call's first argument (`this`) becomes the WITHIN GROUP ordering and its second argument
    (`expression`) becomes the function's only argument. Calls already wrapped in WITHIN GROUP,
    or without a second argument, are returned unchanged.
    """
    is_bare_percentile = (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `PERCENTILE_CONT/DISC(quantile) WITHIN GROUP (ORDER BY x)` into an
    `exp.ApproxQuantile` node over `x` with the same quantile.

    `x` is taken from the first `Ordered` node found inside the WITHIN GROUP expression.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, PERCENTILES) or not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give recursive CTEs that have no column list explicit column names.

    Names come from the projections of the CTE body (the left-hand side when the body is a
    UNION); unnamed projections get sequential `_c_N` names shared across the WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        source_query = cte.this
        if isinstance(source_query, exp.Union):
            source_query = source_query.this

        names = [projection.alias_or_name or next_name() for projection in source_query.selects]
        table_alias.set("columns", [exp.to_identifier(name) for name in names])

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the `'epoch'` operand (case-insensitive) of a CAST / TRY_CAST to a temporal type
    with the literal `'1970-01-01 00:00:00'`; any other expression is returned unchanged.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression
    if expression.name.lower() != "epoch":
        return expression
    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    epoch_start = exp.Literal.string("1970-01-01 00:00:00")
    expression.this.replace(epoch_start)
    return expression
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SEMI / ANTI joins that have an ON condition into `EXISTS` / `NOT EXISTS` filters.

    Each such join is removed and `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on>)`
    is ANDed into the query's WHERE clause. Joins without an ON condition are left untouched.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        kind = join.kind
        if not condition or kind not in ("SEMI", "ANTI"):
            continue

        predicate = exp.Exists(this=exp.select("1").from_(join.this).where(condition))
        if kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (used when the transforms changed the
    expression's type and no "_sql" method exists for the new type).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the (copied) expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the transforms never mutate the caller's tree.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (used when the transforms changed the expression's type and no "_sql" method exists for the new type).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.