sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop().copy()) 71 72 window = exp.alias_(window, row_number) 73 expression.select(window, copy=False) 74 75 return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') 76 77 return expression 78 79 80def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 81 """ 82 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 83 84 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 85 https://docs.snowflake.com/en/sql-reference/constructs/qualify 86 87 Some dialects don't support window functions in the WHERE clause, so we need to include them as 88 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 89 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 90 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
91 """ 92 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 93 taken = set(expression.named_selects) 94 for select in expression.selects: 95 if not select.alias_or_name: 96 alias = find_new_name(taken, "_c") 97 select.replace(exp.alias_(select, alias)) 98 taken.add(alias) 99 100 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 101 qualify_filters = expression.args["qualify"].pop().this 102 103 for expr in qualify_filters.find_all((exp.Window, exp.Column)): 104 if isinstance(expr, exp.Window): 105 alias = find_new_name(expression.named_selects, "_w") 106 expression.select(exp.alias_(expr, alias), copy=False) 107 column = exp.column(alias) 108 109 if isinstance(expr.parent, exp.Qualify): 110 qualify_filters = column 111 else: 112 expr.replace(column) 113 elif expr.name not in expression.named_selects: 114 expression.select(expr.copy(), copy=False) 115 116 return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) 117 118 return expression 119 120 121def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 122 """ 123 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 124 other expressions. This transforms removes the precision from parameterized types in expressions. 
125 """ 126 for node in expression.find_all(exp.DataType): 127 node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)]) 128 129 return expression 130 131 132def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 133 """Convert cross join unnest into lateral view explode (used in presto -> hive).""" 134 if isinstance(expression, exp.Select): 135 for join in expression.args.get("joins") or []: 136 unnest = join.this 137 138 if isinstance(unnest, exp.Unnest): 139 alias = unnest.args.get("alias") 140 udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode 141 142 expression.args["joins"].remove(join) 143 144 for e, column in zip(unnest.expressions, alias.columns if alias else []): 145 expression.append( 146 "laterals", 147 exp.Lateral( 148 this=udtf(this=e), 149 view=True, 150 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 151 ), 152 ) 153 154 return expression 155 156 157def explode_to_unnest(expression: exp.Expression) -> exp.Expression: 158 """Convert explode/posexplode into unnest (used in hive -> presto).""" 159 if isinstance(expression, exp.Select): 160 from sqlglot.optimizer.scope import Scope 161 162 taken_select_names = set(expression.named_selects) 163 taken_source_names = {name for name, _ in Scope(expression).references} 164 165 for select in expression.selects: 166 to_replace = select 167 168 pos_alias = "" 169 explode_alias = "" 170 171 if isinstance(select, exp.Alias): 172 explode_alias = select.alias 173 select = select.this 174 elif isinstance(select, exp.Aliases): 175 pos_alias = select.aliases[0].name 176 explode_alias = select.aliases[1].name 177 select = select.this 178 179 if isinstance(select, (exp.Explode, exp.Posexplode)): 180 is_posexplode = isinstance(select, exp.Posexplode) 181 182 explode_arg = select.this 183 unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) 184 185 # This ensures that we won't use [POS]EXPLODE's argument as a 
new selection 186 if isinstance(explode_arg, exp.Column): 187 taken_select_names.add(explode_arg.output_name) 188 189 unnest_source_alias = find_new_name(taken_source_names, "_u") 190 taken_source_names.add(unnest_source_alias) 191 192 if not explode_alias: 193 explode_alias = find_new_name(taken_select_names, "col") 194 taken_select_names.add(explode_alias) 195 196 if is_posexplode: 197 pos_alias = find_new_name(taken_select_names, "pos") 198 taken_select_names.add(pos_alias) 199 200 if is_posexplode: 201 column_names = [explode_alias, pos_alias] 202 to_replace.pop() 203 expression.select(pos_alias, explode_alias, copy=False) 204 else: 205 column_names = [explode_alias] 206 to_replace.replace(exp.column(explode_alias)) 207 208 unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) 209 210 if not expression.args.get("from"): 211 expression.from_(unnest, copy=False) 212 else: 213 expression.join(unnest, join_type="CROSS", copy=False) 214 215 return expression 216 217 218def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: 219 """Remove table refs from columns in when statements.""" 220 if isinstance(expression, exp.Merge): 221 alias = expression.this.args.get("alias") 222 targets = {expression.this.this} 223 if alias: 224 targets.add(alias.this) 225 226 for when in expression.expressions: 227 when.transform( 228 lambda node: exp.column(node.name) 229 if isinstance(node, exp.Column) and node.args.get("table") in targets 230 else node, 231 copy=False, 232 ) 233 234 return expression 235 236 237def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 238 if ( 239 isinstance(expression, exp.WithinGroup) 240 and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc)) 241 and isinstance(expression.expression, exp.Order) 242 ): 243 quantile = expression.this.this 244 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 245 return 
expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 246 247 return expression 248 249 250def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 251 if isinstance(expression, exp.With) and expression.recursive: 252 next_name = name_sequence("_c_") 253 254 for cte in expression.expressions: 255 if not cte.args["alias"].columns: 256 query = cte.this 257 if isinstance(query, exp.Union): 258 query = query.this 259 260 cte.args["alias"].set( 261 "columns", 262 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 263 ) 264 265 return expression 266 267 268def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 269 if ( 270 isinstance(expression, (exp.Cast, exp.TryCast)) 271 and expression.name.lower() == "epoch" 272 and expression.to.this in exp.DataType.TEMPORAL_TYPES 273 ): 274 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 275 276 return expression 277 278 279def preprocess( 280 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 281) -> t.Callable[[Generator, exp.Expression], str]: 282 """ 283 Creates a new transform by chaining a sequence of transformations and converts the resulting 284 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 285 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 286 287 Args: 288 transforms: sequence of transform functions. These will be called in order. 289 290 Returns: 291 Function that can be used as a generator transform. 
292 """ 293 294 def _to_sql(self, expression: exp.Expression) -> str: 295 expression_type = type(expression) 296 297 expression = transforms[0](expression.copy()) 298 for t in transforms[1:]: 299 expression = t(expression) 300 301 _sql_handler = getattr(self, expression.key + "_sql", None) 302 if _sql_handler: 303 return _sql_handler(expression) 304 305 transforms_handler = self.TRANSFORMS.get(type(expression)) 306 if transforms_handler: 307 # Ensures we don't enter an infinite loop. This can happen when the original expression 308 # has the same type as the final expression and there's no _sql method available for it, 309 # because then it'd re-enter _to_sql. 310 if expression_type is type(expression): 311 raise ValueError( 312 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 313 ) 314 315 return transforms_handler(self, expression) 316 317 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 318 319 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    An unqualified GROUP BY column whose name matches the alias of a projection in the
    parent SELECT is swapped for that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Alias name -> 1-based position of the aliased projection.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for term in expression.expressions:
        if not isinstance(term, exp.Column) or term.table:
            continue

        position = positions.get(term.name)
        if position is not None:
            term.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window, the query's
    ORDER BY (if any) is moved into that window, and the outer query keeps the rows whose
    row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not (on and isinstance(on, exp.Tuple)):
        return expression

    # Detach DISTINCT from the query; its ON columns partition the window.
    partition_cols = distinct.pop().args["on"].expressions
    # Read the projections and pick the alias before the window projection is added.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order = expression.args.get("order")
    if order:
        ranking.set("order", order.pop().copy())

    expression.select(exp.alias_(ranking, row_number_alias), copy=False)

    outer_query = exp.select(*projections)
    return outer_query.from_(expression.subquery()).where(f'"{row_number_alias}" = 1')
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Give every projection a name so that the outer query can select it by name.
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        used_names.add(generated)

    outer_query = exp.select(*[projection.alias_or_name for projection in expression.selects])
    condition = expression.args["qualify"].pop().this

    for node in condition.find_all((exp.Window, exp.Column)):
        if not isinstance(node, exp.Window):
            # Plain column: project it from the subquery if it isn't already selected.
            if node.name not in expression.named_selects:
                expression.select(node.copy(), copy=False)
            continue

        # Window: project it under a fresh alias and filter on that alias instead.
        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(node, window_alias), copy=False)
        reference = exp.column(window_alias)

        if isinstance(node.parent, exp.Qualify):
            condition = reference
        else:
            node.replace(reference)

    return outer_query.from_(expression.subquery(alias="_t")).where(condition)
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Only the parameters of a type that are themselves data types are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [param for param in data_type.expressions if isinstance(param, exp.DataType)]
        data_type.set("expressions", nested_types)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one lateral
    view per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST
    has `ordinality` set, EXPLODE otherwise.

    NOTE(review): an UNNEST without an alias is removed without adding any lateral view,
    because there are no alias columns to pair with -- confirm this is intended.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing from a list while iterating it skips the element after each removal.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is replaced by column references, and an
    aliased UNNEST of the same argument is added as the FROM source (when the query has none)
    or as a CROSS JOIN. Missing column aliases are generated from "col" / "pos" and the UNNEST
    source alias from "_u".

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import Scope

        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # Any multi-alias projection reaches this branch, not only explodes, so the
                # names are read defensively instead of indexing [0] and [1] directly.
                alias_names = [alias.name for alias in select.aliases]
                if len(alias_names) > 1:
                    pos_alias, explode_alias = alias_names[:2]
                elif alias_names:
                    explode_alias = alias_names[0]
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                # A POSEXPLODE that carries only one alias still needs a position column name.
                if is_posexplode and not pos_alias:
                    pos_alias = find_new_name(taken_select_names, "pos")
                    taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Columns inside a MERGE's WHEN clauses whose table is the merge target, or the target's
    alias, are replaced by unqualified columns of the same name.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    targets = {target.this}
    alias = target.args.get("alias")
    if alias:
        targets.add(alias.this)

    def _strip_target(node):
        # Only columns qualified with the merge target (or its alias) are rewritten.
        if isinstance(node, exp.Column) and node.args.get("table") in targets:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        when.transform(_strip_target, copy=False)

    return expression
Remove table refs from columns in when statements.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a WITHIN GROUP percentile with an approximate quantile.

    Applies to `WithinGroup` nodes that wrap a `PercentileCont` / `PercentileDisc` and are
    ordered by an `Order` node. The node is replaced by `ApproxQuantile`, taking the
    percentile's argument as the quantile and the first `Ordered` term's operand as the input.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement node when the pattern matches, otherwise the expression unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Add explicit column names to the CTEs of a recursive WITH clause.

    For every CTE whose alias has no column list, the columns are set from the projections of
    the CTE's query. A projection without a name gets one generated by `name_sequence("_c_")`.
    When the query is a set operation, the projections of its left-most operand are used.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # Descend through every Union level, not just the first.
                # NOTE(review): assumes chained set operations nest on the left -- confirm.
                while isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of an `epoch` cast with a timestamp string literal.

    Applies to CAST / TRY_CAST nodes whose operand is named `epoch` (case-insensitive) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`. The operand is replaced by
    the string literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is only copied.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has neither
            a "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms are applied to a copy of the incoming expression.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.