sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.errors import OptimizeError

# Sentinel value meaning that an outer query selects ALL columns of a scope
SELECT_ALL = object()


def default_selection(is_agg: bool) -> exp.Alias:
    """
    Return the projection to use when a selection list would otherwise be empty.

    Args:
        is_agg: whether the original selection list contained an aggregate function.
    Returns:
        `MAX(1) AS _` when `is_agg` is true, otherwise `1 AS _`.
    """
    # NOTE(review): MAX(1) is presumably used so an aggregate query keeps returning
    # a single row instead of one row per input row.
    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema (normalized with `ensure_schema`), used to resolve the
            source table of columns when an unused `*` projection is expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the input tree is modified in place)
    """
    schema = ensure_schema(schema)
    # Number of column aliases declared on a source's alias (node.alias_column_names),
    # keyed by source; that many leading projections of the source must be kept.
    source_column_alias_count = {}
    # Map of Scope to the set of column names selected from it by outer queries.
    # A set containing SELECT_ALL means every projection must be kept.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT or UNION DISTINCT: dropping a
        # column can change which rows are considered duplicates.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            left, right = scope.union_scopes
            if len(left.expression.selects) != len(right.expression.selects):
                scope_sql = scope.expression.sql()
                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")

            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                if scope.expression.args.get("by_name"):
                    # BY NAME matches columns by name, so both sides keep the same names.
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Columns are matched by position: keep the right-hand columns at the
                    # positions of the kept left-hand columns. Build a set, like every
                    # other entry in referenced_columns, so lookups are O(1) and the
                    # names are deduplicated.
                    referenced_columns[right] = {
                        right.expression.selects[i].alias_or_name
                        for i, select in enumerate(left.expression.selects)
                        if SELECT_ALL in parent_selections
                        or select.alias_or_name in parent_selections
                    }
            # If only the left side has a star, nothing is recorded for the right side,
            # so it later falls back to {SELECT_ALL} and keeps all its columns.

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # NOTE(review): with a PIVOT every source column is kept, presumably
                    # because the pivot may use columns not listed in scope.columns.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression


def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Drop the projections of a SELECT scope that no outer query references.

    Args:
        scope: the SELECT scope to prune; its expression is modified in place.
        parent_selections: names referenced by outer queries, or a collection
            containing SELECT_ALL when every projection must be kept.
        schema: schema used to resolve the source table when an unused `*` is
            replaced by explicit columns.
        alias_count: the first `alias_count` projections are always kept, because
            the enclosing source declares that many positional column aliases.
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # The `*` was dropped: project every referenced name that is not already
        # selected explicitly, qualified with the table the resolver finds for it.
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(parent_selections):
            if name not in names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        scope.clear_cache()
SELECT_ALL =
<object object>
def
pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema (normalized with `ensure_schema`), passed on when
            pruning unused selections
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression
    """
    schema = ensure_schema(schema)
    # Number of column aliases declared on a source's alias (node.alias_column_names),
    # keyed by source.
    source_column_alias_count = {}
    # Map of Scope to the set of column names selected from it by outer queries.
    # A set containing SELECT_ALL means every projection must be kept.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT or UNION DISTINCT: dropping a
        # column can change which rows are considered duplicates.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            left, right = scope.union_scopes
            if len(left.expression.selects) != len(right.expression.selects):
                scope_sql = scope.expression.sql()
                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")

            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                if scope.expression.args.get("by_name"):
                    # BY NAME matches columns by name, so both sides keep the same names.
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Columns are matched by position: keep the right-hand columns at the
                    # positions of the kept left-hand columns. Build a set, like every
                    # other entry in referenced_columns, so lookups are O(1) and the
                    # names are deduplicated.
                    referenced_columns[right] = {
                        right.expression.selects[i].alias_or_name
                        for i, select in enumerate(left.expression.selects)
                        if SELECT_ALL in parent_selections
                        or select.alias_or_name in parent_selections
                    }
            # If only the left side has a star, nothing is recorded for the right side,
            # so it later falls back to {SELECT_ALL} and keeps all its columns.

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # NOTE(review): with a PIVOT every source column is kept, presumably
                    # because the pivot may use columns not listed in scope.columns.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
Rewrite a sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
- expression (sqlglot.Expression): expression to optimize
- schema: optional schema (normalized with `ensure_schema`), used when pruning unused selections
- remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expression: optimized expression