sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: SQL dialect forwarded to ``simplify`` when conditions are simplified
    Returns:
        sqlglot.Expression: optimized expression (the input, modified in place)
    """
    root = build_scope(expression)

    if root:
        scope_ref_count = root.ref_count()

        # innermost scopes first, so predicates pushed into a subquery can keep
        # travelling down once that subquery's own scope is processed
        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # a right join can only push down to itself and not the source FROM table
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                        selected_sources = {k: (node, source)}
                        break
                pushdown(where.this, selected_sources, scope_ref_count, dialect)

            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.alias_or_name
                if name in scope.selected_sources:
                    pushdown(
                        join.args.get("on"),
                        {name: scope.selected_sources[name]},
                        scope_ref_count,
                        dialect,
                    )

    return expression


def pushdown(condition, sources, scope_ref_count, dialect):
    """
    Simplify ``condition`` in place, split it into predicates and push them down.

    A CNF-like condition is split on AND and handled by ``pushdown_cnf``. Otherwise
    the condition is in DNF, is split on OR and handled by ``pushdown_dnf``.
    Does nothing when ``condition`` is falsy (e.g. a join without ON).
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition, dialect=dialect))
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    predicates = list(
        condition.flatten()
        if isinstance(condition, exp.And if cnf_like else exp.Or)
        else [condition]
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each pushable conjunct is replaced with TRUE where it was and moved into the
    target join's ON clause, or into the target subquery's WHERE (or HAVING when
    it contains an aggregate).
    """
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
            if isinstance(node, exp.Join):
                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    A copy of the whole disjunction is pushed into a table's node, and only when
    every disjunct can be pushed into that node. Pushing a subset of the disjuncts
    would be stricter than the original condition and would drop rows.

    Args:
        predicates: the OR terms of the condition being pushed down.
        scope: mapping of source name to ``(node, source)`` pairs.
        scope_ref_count: reference counts as returned by ``Scope.ref_count()``.
    """
    if not predicates:
        return

    # find all the tables that can be pushdown too
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be push down
    pushdown_tables = set.intersection(
        *(exp.column_table_names(predicate) for predicate in predicates)
    )
    if not pushdown_tables:
        return

    # the candidate nodes of a predicate don't depend on the table being handled
    nodes_per_predicate = [
        nodes_for_predicate(predicate, scope, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # every disjunct must be pushable into this table's node; otherwise skip it
        if not all(table in nodes for nodes in nodes_per_predicate):
            continue

        node = nodes_per_predicate[-1][table]

        # OR the disjuncts back together; copy so the original WHERE stays intact
        condition = None
        for predicate in predicates:
            condition = predicate.copy() if condition is None else exp.or_(condition, predicate)

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Return the nodes ``predicate`` may be pushed into, keyed by source name.

    An empty dict is returned as soon as a referenced source is behind a join with
    a side other than RIGHT, or is a non-Table FROM source whose parent scope's
    query has a recursive WITH. A subquery only qualifies for single-table
    predicates, and only when it has no GROUP BY, no window functions and a
    reference count below 2.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes


def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` so columns named after ``source``'s projections use the
    projected expressions (the aliased expression for aliases) instead.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def
pushdown_predicates(expression, dialect=None):
def pushdown_predicates(expression, dialect=None):
    """
    Push WHERE and JOIN ... ON predicates down into FROM sources and joins.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: SQL dialect handed to ``simplify`` for every pushed condition
    Returns:
        sqlglot.Expression: the input expression, rewritten in place
    """
    root = build_scope(expression)
    if not root:
        return expression

    ref_counts = root.ref_count()

    # walk scopes innermost-first
    for current in reversed(list(root.traverse())):
        query = current.expression

        where = query.args.get("where")
        if where:
            candidates = current.selected_sources
            # next to a RIGHT JOIN a predicate may only go into that join itself,
            # never into the FROM table
            for name, (node, source) in candidates.items():
                enclosing = node.find_ancestor(exp.Join, exp.From)
                if isinstance(enclosing, exp.Join) and enclosing.side == "RIGHT":
                    candidates = {name: (node, source)}
                    break
            pushdown(where.this, candidates, ref_counts, dialect)

        # an ON condition may only be pushed into its own join
        for join in query.args.get("joins") or []:
            name = join.alias_or_name
            if name not in current.selected_sources:
                continue
            own_source = {name: current.selected_sources[name]}
            pushdown(join.args.get("on"), own_source, ref_counts, dialect)

    return expression
Rewrite the sqlglot AST to push down predicates into FROMs and JOINs.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
- dialect: SQL dialect passed to `simplify` when conditions are simplified
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count, dialect):
def pushdown(condition, sources, scope_ref_count, dialect):
    """
    Simplify ``condition`` in place, split it into predicates and push them down.

    CNF-like conditions are split on AND and handed to ``pushdown_cnf``. Anything
    else is split on OR and handed to ``pushdown_dnf``. A falsy condition is ignored.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)

    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if cnf_like else exp.Or

    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    handler = pushdown_cnf if cnf_like else pushdown_dnf
    handler(predicates, sources, scope_ref_count)
def
pushdown_cnf(predicates, scope, scope_ref_count):
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    Push each conjunct of a CNF-like condition down on its own.

    A pushed conjunct is replaced with TRUE where it was. It moves into the first
    target join's ON clause, or into a target subquery's WHERE (HAVING when it
    contains an aggregate).
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope, scope_ref_count)
        for target in targets.values():
            if isinstance(target, exp.Join):
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break

            if not isinstance(target, exp.Select):
                continue

            predicate.replace(exp.true())
            # express the predicate in terms of the subquery's own projections
            inner = replace_aliases(target, predicate)
            attach = target.having if find_in_scope(inner, exp.AggFunc) else target.where
            attach(inner, copy=False)
If the predicates are in a CNF-like form, each conjunct can be pushed down on its own and replaced with TRUE in its parent.
def
pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    A copy of the whole disjunction is pushed into a table's node, and only when
    every disjunct can be pushed into that node. Pushing a subset of the disjuncts
    would be stricter than the original condition and would drop rows.

    Args:
        predicates: the OR terms of the condition being pushed down.
        scope: mapping of source name to ``(node, source)`` pairs.
        scope_ref_count: reference counts as returned by ``Scope.ref_count()``.
    """
    if not predicates:
        return

    # find all the tables that can be pushdown too
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be push down
    pushdown_tables = set.intersection(
        *(exp.column_table_names(predicate) for predicate in predicates)
    )
    if not pushdown_tables:
        return

    # the candidate nodes of a predicate don't depend on the table being handled
    nodes_per_predicate = [
        nodes_for_predicate(predicate, scope, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # every disjunct must be pushable into this table's node; otherwise skip it
        if not all(table in nodes for nodes in nodes_per_predicate):
            continue

        node = nodes_per_predicate[-1][table]

        # OR the disjuncts back together; copy so the original WHERE stays intact
        condition = None
        for predicate in predicates:
            condition = predicate.copy() if condition is None else exp.or_(condition, predicate)

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)
If the predicates are in DNF form, we can only push down conditions on tables that are referenced in every disjunct. Additionally, the original predicates must stay in place, so only copies are pushed down.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each source referenced by ``predicate`` to the node it may be pushed into.

    Returns an empty dict as soon as a referenced source sits behind a join with a
    side other than RIGHT, or is a non-Table FROM source whose parent scope's query
    has a recursive WITH. Subqueries only accept single-table predicates, and only
    when they have no GROUP BY, no window functions and a reference count below 2.
    """
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    referenced = exp.column_table_names(predicate)
    found = {}

    for name in sorted(referenced):
        target, source = sources.get(name) or (None, None)

        # from a WHERE clause, climb to the enclosing JOIN or FROM
        if target and in_where:
            target = target.find_ancestor(exp.Join, exp.From)

        # a FROM over a derived table / CTE pushes into that query instead
        if isinstance(target, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            target = source.expression

        if isinstance(target, exp.Join):
            if target.side and target.side != "RIGHT":
                return {}
            found[name] = target
            continue

        if not isinstance(target, exp.Select) or len(referenced) != 1:
            continue

        windowed = any(proj for proj in target.selects if proj.find(exp.Window))
        # a subquery referenced in several places must not be filtered
        pushable = (
            not target.args.get("group")
            and scope_ref_count[id(source)] < 2
            and not windowed
        )
        if pushable:
            found[name] = target

    return found
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` so columns named after ``source``'s projections use the
    projected expressions (the aliased expression for aliases) instead.
    """
    lookup = {}
    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        lookup[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in lookup:
            return node
        return lookup[node.name].copy()

    return predicate.transform(_substitute)