sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name 12 13logger = logging.getLogger("sqlglot") 14 15 16class ScopeType(Enum): 17 ROOT = auto() 18 SUBQUERY = auto() 19 DERIVED_TABLE = auto() 20 CTE = auto() 21 UNION = auto() 22 UDTF = auto() 23 24 25class Scope: 26 """ 27 Selection scope. 28 29 Attributes: 30 expression (exp.Select|exp.Union): Root expression of this scope 31 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 32 a Table expression or another Scope instance. For example: 33 SELECT * FROM x {"x": Table(this="x")} 34 SELECT * FROM x AS y {"y": Table(this="x")} 35 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 36 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 37 For example: 38 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 39 The LATERAL VIEW EXPLODE gets x as a source. 40 cte_sources (dict[str, Scope]): Sources from CTES 41 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 42 defines a column list of it's alias of this scope, this is that list of columns. 43 For example: 44 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 45 The inner query would have `["col1", "col2"]` for its `outer_column_list` 46 parent (Scope): Parent scope 47 scope_type (ScopeType): Type of this scope, relative to it's parent 48 subquery_scopes (list[Scope]): List of all child scopes for subqueries 49 cte_scopes (list[Scope]): List of all child scopes for CTEs 50 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 51 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 52 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 53 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 54 a list of the left and right child scopes. 55 """ 56 57 def __init__( 58 self, 59 expression, 60 sources=None, 61 outer_column_list=None, 62 parent=None, 63 scope_type=ScopeType.ROOT, 64 lateral_sources=None, 65 cte_sources=None, 66 ): 67 self.expression = expression 68 self.sources = sources or {} 69 self.lateral_sources = lateral_sources or {} 70 self.cte_sources = cte_sources or {} 71 self.sources.update(self.lateral_sources) 72 self.sources.update(self.cte_sources) 73 self.outer_column_list = outer_column_list or [] 74 self.parent = parent 75 self.scope_type = scope_type 76 self.subquery_scopes = [] 77 self.derived_table_scopes = [] 78 self.table_scopes = [] 79 self.cte_scopes = [] 80 self.union_scopes = [] 81 self.udtf_scopes = [] 82 self.clear_cache() 83 84 def clear_cache(self): 85 self._collected = False 86 self._raw_columns = None 87 self._derived_tables = None 88 self._udtfs = None 89 self._tables = None 90 self._ctes = None 91 self._subqueries = None 92 self._selected_sources = None 93 self._columns = None 94 self._external_columns = None 95 self._join_hints = None 96 self._pivots = None 97 self._references = None 98 99 def branch( 100 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 101 ): 102 """Branch from the current 
scope to a new, inner scope""" 103 return Scope( 104 expression=expression.unnest(), 105 sources=sources.copy() if sources else None, 106 parent=self, 107 scope_type=scope_type, 108 cte_sources={**self.cte_sources, **(cte_sources or {})}, 109 lateral_sources=lateral_sources.copy() if lateral_sources else None, 110 **kwargs, 111 ) 112 113 def _collect(self): 114 self._tables = [] 115 self._ctes = [] 116 self._subqueries = [] 117 self._derived_tables = [] 118 self._udtfs = [] 119 self._raw_columns = [] 120 self._join_hints = [] 121 122 for node, parent, _ in self.walk(bfs=False): 123 if node is self.expression: 124 continue 125 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 126 self._raw_columns.append(node) 127 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 128 self._tables.append(node) 129 elif isinstance(node, exp.JoinHint): 130 self._join_hints.append(node) 131 elif isinstance(node, exp.UDTF): 132 self._udtfs.append(node) 133 elif isinstance(node, exp.CTE): 134 self._ctes.append(node) 135 elif ( 136 isinstance(node, exp.Subquery) 137 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 138 and _is_derived_table(node) 139 ): 140 self._derived_tables.append(node) 141 elif isinstance(node, exp.Subqueryable): 142 self._subqueries.append(node) 143 144 self._collected = True 145 146 def _ensure_collected(self): 147 if not self._collected: 148 self._collect() 149 150 def walk(self, bfs=True, prune=None): 151 return walk_in_scope(self.expression, bfs=bfs, prune=None) 152 153 def find(self, *expression_types, bfs=True): 154 return find_in_scope(self.expression, expression_types, bfs=bfs) 155 156 def find_all(self, *expression_types, bfs=True): 157 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 158 159 def replace(self, old, new): 160 """ 161 Replace `old` with `new`. 162 163 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 
164 165 Args: 166 old (exp.Expression): old node 167 new (exp.Expression): new node 168 """ 169 old.replace(new) 170 self.clear_cache() 171 172 @property 173 def tables(self): 174 """ 175 List of tables in this scope. 176 177 Returns: 178 list[exp.Table]: tables 179 """ 180 self._ensure_collected() 181 return self._tables 182 183 @property 184 def ctes(self): 185 """ 186 List of CTEs in this scope. 187 188 Returns: 189 list[exp.CTE]: ctes 190 """ 191 self._ensure_collected() 192 return self._ctes 193 194 @property 195 def derived_tables(self): 196 """ 197 List of derived tables in this scope. 198 199 For example: 200 SELECT * FROM (SELECT ...) <- that's a derived table 201 202 Returns: 203 list[exp.Subquery]: derived tables 204 """ 205 self._ensure_collected() 206 return self._derived_tables 207 208 @property 209 def udtfs(self): 210 """ 211 List of "User Defined Tabular Functions" in this scope. 212 213 Returns: 214 list[exp.UDTF]: UDTFs 215 """ 216 self._ensure_collected() 217 return self._udtfs 218 219 @property 220 def subqueries(self): 221 """ 222 List of subqueries in this scope. 223 224 For example: 225 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 226 227 Returns: 228 list[exp.Subqueryable]: subqueries 229 """ 230 self._ensure_collected() 231 return self._subqueries 232 233 @property 234 def columns(self): 235 """ 236 List of columns in this scope. 237 238 Returns: 239 list[exp.Column]: Column instances in this scope, plus any 240 Columns that reference this scope from correlated subqueries. 
241 """ 242 if self._columns is None: 243 self._ensure_collected() 244 columns = self._raw_columns 245 246 external_columns = [ 247 column 248 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 249 for column in scope.external_columns 250 ] 251 252 named_selects = set(self.expression.named_selects) 253 254 self._columns = [] 255 for column in columns + external_columns: 256 ancestor = column.find_ancestor( 257 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 258 ) 259 if ( 260 not ancestor 261 or column.table 262 or isinstance(ancestor, exp.Select) 263 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 264 or ( 265 isinstance(ancestor, exp.Order) 266 and ( 267 isinstance(ancestor.parent, exp.Window) 268 or column.name not in named_selects 269 ) 270 ) 271 ): 272 self._columns.append(column) 273 274 return self._columns 275 276 @property 277 def selected_sources(self): 278 """ 279 Mapping of nodes and sources that are actually selected from in this scope. 280 281 That is, all tables in a schema are selectable at any point. But a 282 table only becomes a selected source if it's included in a FROM or JOIN clause. 
283 284 Returns: 285 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 286 """ 287 if self._selected_sources is None: 288 result = {} 289 290 for name, node in self.references: 291 if name in result: 292 raise OptimizeError(f"Alias already used: {name}") 293 if name in self.sources: 294 result[name] = (node, self.sources[name]) 295 296 self._selected_sources = result 297 return self._selected_sources 298 299 @property 300 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 301 if self._references is None: 302 self._references = [] 303 304 for table in self.tables: 305 self._references.append((table.alias_or_name, table)) 306 for expression in itertools.chain(self.derived_tables, self.udtfs): 307 self._references.append( 308 ( 309 expression.alias, 310 expression if expression.args.get("pivots") else expression.unnest(), 311 ) 312 ) 313 314 return self._references 315 316 @property 317 def external_columns(self): 318 """ 319 Columns that appear to reference sources in outer scopes. 320 321 Returns: 322 list[exp.Column]: Column instances that don't reference 323 sources in the current scope. 324 """ 325 if self._external_columns is None: 326 self._external_columns = [ 327 c for c in self.columns if c.table not in self.selected_sources 328 ] 329 return self._external_columns 330 331 @property 332 def unqualified_columns(self): 333 """ 334 Unqualified columns in the current scope. 
335 336 Returns: 337 list[exp.Column]: Unqualified columns 338 """ 339 return [c for c in self.columns if not c.table] 340 341 @property 342 def join_hints(self): 343 """ 344 Hints that exist in the scope that reference tables 345 346 Returns: 347 list[exp.JoinHint]: Join hints that are referenced within the scope 348 """ 349 if self._join_hints is None: 350 return [] 351 return self._join_hints 352 353 @property 354 def pivots(self): 355 if not self._pivots: 356 self._pivots = [ 357 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 358 ] 359 360 return self._pivots 361 362 def source_columns(self, source_name): 363 """ 364 Get all columns in the current scope for a particular source. 365 366 Args: 367 source_name (str): Name of the source 368 Returns: 369 list[exp.Column]: Column instances that reference `source_name` 370 """ 371 return [column for column in self.columns if column.table == source_name] 372 373 @property 374 def is_subquery(self): 375 """Determine if this scope is a subquery""" 376 return self.scope_type == ScopeType.SUBQUERY 377 378 @property 379 def is_derived_table(self): 380 """Determine if this scope is a derived table""" 381 return self.scope_type == ScopeType.DERIVED_TABLE 382 383 @property 384 def is_union(self): 385 """Determine if this scope is a union""" 386 return self.scope_type == ScopeType.UNION 387 388 @property 389 def is_cte(self): 390 """Determine if this scope is a common table expression""" 391 return self.scope_type == ScopeType.CTE 392 393 @property 394 def is_root(self): 395 """Determine if this is the root scope""" 396 return self.scope_type == ScopeType.ROOT 397 398 @property 399 def is_udtf(self): 400 """Determine if this scope is a UDTF (User Defined Table Function)""" 401 return self.scope_type == ScopeType.UDTF 402 403 @property 404 def is_correlated_subquery(self): 405 """Determine if this scope is a correlated subquery""" 406 return bool( 407 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 408 and self.external_columns 409 ) 410 411 def rename_source(self, old_name, new_name): 412 """Rename a source in this scope""" 413 columns = self.sources.pop(old_name or "", []) 414 self.sources[new_name] = columns 415 416 def add_source(self, name, source): 417 """Add a source to this scope""" 418 self.sources[name] = source 419 self.clear_cache() 420 421 def remove_source(self, name): 422 """Remove a source from this scope""" 423 self.sources.pop(name, None) 424 self.clear_cache() 425 426 def __repr__(self): 427 return f"Scope<{self.expression.sql()}>" 428 429 def traverse(self): 430 """ 431 Traverse the scope tree from this node. 432 433 Yields: 434 Scope: scope instances in depth-first-search post-order 435 """ 436 for child_scope in itertools.chain( 437 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 438 ): 439 yield from child_scope.traverse() 440 yield self 441 442 def ref_count(self): 443 """ 444 Count the number of times each scope in this tree is referenced. 445 446 Returns: 447 dict[int, int]: Mapping of Scope instance ID to reference count 448 """ 449 scope_ref_count = defaultdict(lambda: 0) 450 451 for scope in self.traverse(): 452 for _, source in scope.selected_sources.values(): 453 scope_ref_count[id(source)] += 1 454 455 return scope_ref_count 456 457 458def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 459 """ 460 Traverse an expression by its "scopes". 461 462 "Scope" represents the current context of a Select statement. 463 464 This is helpful for optimizing queries, where we need more information than 465 the expression tree itself. For example, we might care about the source 466 names within a subquery. Returns a list because a generator could result in 467 incomplete properties which is confusing. 
468 469 Examples: 470 >>> import sqlglot 471 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 472 >>> scopes = traverse_scope(expression) 473 >>> scopes[0].expression.sql(), list(scopes[0].sources) 474 ('SELECT a FROM x', ['x']) 475 >>> scopes[1].expression.sql(), list(scopes[1].sources) 476 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 477 478 Args: 479 expression (exp.Expression): expression to traverse 480 Returns: 481 list[Scope]: scope instances 482 """ 483 if isinstance(expression, exp.Unionable) or ( 484 isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) 485 ): 486 return list(_traverse_scope(Scope(expression))) 487 488 return [] 489 490 491def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 492 """ 493 Build a scope tree. 494 495 Args: 496 expression (exp.Expression): expression to build the scope tree for 497 Returns: 498 Scope: root scope 499 """ 500 scopes = traverse_scope(expression) 501 if scopes: 502 return scopes[-1] 503 return None 504 505 506def _traverse_scope(scope): 507 if isinstance(scope.expression, exp.Select): 508 yield from _traverse_select(scope) 509 elif isinstance(scope.expression, exp.Union): 510 yield from _traverse_union(scope) 511 elif isinstance(scope.expression, exp.Subquery): 512 if scope.is_root: 513 yield from _traverse_select(scope) 514 else: 515 yield from _traverse_subqueries(scope) 516 elif isinstance(scope.expression, exp.Table): 517 yield from _traverse_tables(scope) 518 elif isinstance(scope.expression, exp.UDTF): 519 yield from _traverse_udtfs(scope) 520 elif isinstance(scope.expression, exp.DDL): 521 yield from _traverse_ddl(scope) 522 else: 523 logger.warning( 524 "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) 525 ) 526 return 527 528 yield scope 529 530 531def _traverse_select(scope): 532 yield from _traverse_ctes(scope) 533 yield from _traverse_tables(scope) 534 yield from _traverse_subqueries(scope) 
def _traverse_union(scope):
    """Yield the scopes of a UNION's CTEs, then of its left branch, then of its right branch."""
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(
        scope.branch(
            scope.expression.left,
            outer_column_list=scope.outer_column_list,
            scope_type=ScopeType.UNION,
        )
    ):
        yield left

    right = None
    for right in _traverse_scope(
        scope.branch(
            scope.expression.right,
            outer_column_list=scope.outer_column_list,
            scope_type=ScopeType.UNION,
        )
    ):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """Yield the scopes of every CTE in `scope` and register each CTE as a source of `scope`."""
    sources = {}

    # Whether the WITH clause is recursive does not depend on the individual CTE.
    with_ = scope.expression.args.get("with")
    is_recursive = bool(with_ and with_.recursive)

    for cte in scope.ctes:
        recursive_scope = None

        # A recursive CTE must be of the form `base_case UNION recursive_part`. Self-references
        # to the CTE name inside it are resolved against a scope built from the union's left
        # side (the base case).
        if is_recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)
                child_scope.cte_sources[alias] = recursive_scope

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))


def _traverse_tables(scope):
    """
    Yield the scopes of derived tables and UDTFs referenced in FROM/JOIN/LATERAL clauses,
    registering every table, derived table and UDTF as a source of `scope`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` may grow while it is iterated, so nested table constructs are visited too.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield the scopes of every subquery in `scope`, recording each top-most one."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """Yield the scopes of aliased subqueries nested directly in an UNNEST or LATERAL."""
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.DERIVED_TABLE,
                    outer_column_list=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            scope.derived_table_scopes.append(top)
            scope.table_scopes.append(top)

    scope.sources.update(sources)


def _traverse_ddl(scope):
    """Yield the scopes of a DDL statement's CTEs, then of the query it wraps."""
    yield from _traverse_ctes(scope)

    query_scope = scope.branch(
        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
    )
    query_scope._collect()
    # CTEs attached to the DDL are also considered CTEs of the wrapped query.
    query_scope._ctes = scope.ctes + query_scope._ctes

    yield from _traverse_scope(query_scope)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them.
                # The caller's prune callback applies to them as well.
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Build the isinstance() target once instead of once per visited node.
    types = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a Scope, describing how it relates to its parent scope."""

    # The outermost scope of a traversal.
    ROOT = auto()
    # A subquery, e.g. in a WHERE ... IN (SELECT ...) predicate.
    SUBQUERY = auto()
    # An aliased subquery used as a table in FROM/JOIN.
    DERIVED_TABLE = auto()
    # A common table expression.
    CTE = auto()
    # One branch of a set operation.
    UNION = auto()
    # A user defined tabular function (e.g. UNNEST, LATERAL).
    UDTF = auto()
An enumeration of scope types: the kind of a scope, relative to its parent scope.
Inherited Members
- enum.Enum
- name
- value
26class Scope: 27 """ 28 Selection scope. 29 30 Attributes: 31 expression (exp.Select|exp.Union): Root expression of this scope 32 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 33 a Table expression or another Scope instance. For example: 34 SELECT * FROM x {"x": Table(this="x")} 35 SELECT * FROM x AS y {"y": Table(this="x")} 36 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 37 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 38 For example: 39 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 40 The LATERAL VIEW EXPLODE gets x as a source. 41 cte_sources (dict[str, Scope]): Sources from CTES 42 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 43 defines a column list of it's alias of this scope, this is that list of columns. 44 For example: 45 SELECT * FROM (SELECT ...) AS y(col1, col2) 46 The inner query would have `["col1", "col2"]` for its `outer_column_list` 47 parent (Scope): Parent scope 48 scope_type (ScopeType): Type of this scope, relative to it's parent 49 subquery_scopes (list[Scope]): List of all child scopes for subqueries 50 cte_scopes (list[Scope]): List of all child scopes for CTEs 51 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 52 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 53 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 54 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 55 a list of the left and right child scopes. 
56 """ 57 58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_column_list=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_column_list = outer_column_list or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache() 84 85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None 99 100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 ) 113 114 def _collect(self): 115 self._tables = [] 116 self._ctes = [] 117 self._subqueries = [] 118 self._derived_tables = [] 119 self._udtfs = [] 120 self._raw_columns = [] 121 self._join_hints = [] 122 123 for node, parent, _ in self.walk(bfs=False): 124 if node is self.expression: 125 continue 126 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 127 self._raw_columns.append(node) 
128 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 129 self._tables.append(node) 130 elif isinstance(node, exp.JoinHint): 131 self._join_hints.append(node) 132 elif isinstance(node, exp.UDTF): 133 self._udtfs.append(node) 134 elif isinstance(node, exp.CTE): 135 self._ctes.append(node) 136 elif ( 137 isinstance(node, exp.Subquery) 138 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 139 and _is_derived_table(node) 140 ): 141 self._derived_tables.append(node) 142 elif isinstance(node, exp.Subqueryable): 143 self._subqueries.append(node) 144 145 self._collected = True 146 147 def _ensure_collected(self): 148 if not self._collected: 149 self._collect() 150 151 def walk(self, bfs=True, prune=None): 152 return walk_in_scope(self.expression, bfs=bfs, prune=None) 153 154 def find(self, *expression_types, bfs=True): 155 return find_in_scope(self.expression, expression_types, bfs=bfs) 156 157 def find_all(self, *expression_types, bfs=True): 158 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 159 160 def replace(self, old, new): 161 """ 162 Replace `old` with `new`. 163 164 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 165 166 Args: 167 old (exp.Expression): old node 168 new (exp.Expression): new node 169 """ 170 old.replace(new) 171 self.clear_cache() 172 173 @property 174 def tables(self): 175 """ 176 List of tables in this scope. 177 178 Returns: 179 list[exp.Table]: tables 180 """ 181 self._ensure_collected() 182 return self._tables 183 184 @property 185 def ctes(self): 186 """ 187 List of CTEs in this scope. 188 189 Returns: 190 list[exp.CTE]: ctes 191 """ 192 self._ensure_collected() 193 return self._ctes 194 195 @property 196 def derived_tables(self): 197 """ 198 List of derived tables in this scope. 199 200 For example: 201 SELECT * FROM (SELECT ...) 
<- that's a derived table 202 203 Returns: 204 list[exp.Subquery]: derived tables 205 """ 206 self._ensure_collected() 207 return self._derived_tables 208 209 @property 210 def udtfs(self): 211 """ 212 List of "User Defined Tabular Functions" in this scope. 213 214 Returns: 215 list[exp.UDTF]: UDTFs 216 """ 217 self._ensure_collected() 218 return self._udtfs 219 220 @property 221 def subqueries(self): 222 """ 223 List of subqueries in this scope. 224 225 For example: 226 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 227 228 Returns: 229 list[exp.Subqueryable]: subqueries 230 """ 231 self._ensure_collected() 232 return self._subqueries 233 234 @property 235 def columns(self): 236 """ 237 List of columns in this scope. 238 239 Returns: 240 list[exp.Column]: Column instances in this scope, plus any 241 Columns that reference this scope from correlated subqueries. 242 """ 243 if self._columns is None: 244 self._ensure_collected() 245 columns = self._raw_columns 246 247 external_columns = [ 248 column 249 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 250 for column in scope.external_columns 251 ] 252 253 named_selects = set(self.expression.named_selects) 254 255 self._columns = [] 256 for column in columns + external_columns: 257 ancestor = column.find_ancestor( 258 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 259 ) 260 if ( 261 not ancestor 262 or column.table 263 or isinstance(ancestor, exp.Select) 264 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 265 or ( 266 isinstance(ancestor, exp.Order) 267 and ( 268 isinstance(ancestor.parent, exp.Window) 269 or column.name not in named_selects 270 ) 271 ) 272 ): 273 self._columns.append(column) 274 275 return self._columns 276 277 @property 278 def selected_sources(self): 279 """ 280 Mapping of nodes and sources that are actually selected from in this scope. 281 282 That is, all tables in a schema are selectable at any point. 
But a 283 table only becomes a selected source if it's included in a FROM or JOIN clause. 284 285 Returns: 286 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 287 """ 288 if self._selected_sources is None: 289 result = {} 290 291 for name, node in self.references: 292 if name in result: 293 raise OptimizeError(f"Alias already used: {name}") 294 if name in self.sources: 295 result[name] = (node, self.sources[name]) 296 297 self._selected_sources = result 298 return self._selected_sources 299 300 @property 301 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 302 if self._references is None: 303 self._references = [] 304 305 for table in self.tables: 306 self._references.append((table.alias_or_name, table)) 307 for expression in itertools.chain(self.derived_tables, self.udtfs): 308 self._references.append( 309 ( 310 expression.alias, 311 expression if expression.args.get("pivots") else expression.unnest(), 312 ) 313 ) 314 315 return self._references 316 317 @property 318 def external_columns(self): 319 """ 320 Columns that appear to reference sources in outer scopes. 321 322 Returns: 323 list[exp.Column]: Column instances that don't reference 324 sources in the current scope. 325 """ 326 if self._external_columns is None: 327 self._external_columns = [ 328 c for c in self.columns if c.table not in self.selected_sources 329 ] 330 return self._external_columns 331 332 @property 333 def unqualified_columns(self): 334 """ 335 Unqualified columns in the current scope. 
336 337 Returns: 338 list[exp.Column]: Unqualified columns 339 """ 340 return [c for c in self.columns if not c.table] 341 342 @property 343 def join_hints(self): 344 """ 345 Hints that exist in the scope that reference tables 346 347 Returns: 348 list[exp.JoinHint]: Join hints that are referenced within the scope 349 """ 350 if self._join_hints is None: 351 return [] 352 return self._join_hints 353 354 @property 355 def pivots(self): 356 if not self._pivots: 357 self._pivots = [ 358 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 359 ] 360 361 return self._pivots 362 363 def source_columns(self, source_name): 364 """ 365 Get all columns in the current scope for a particular source. 366 367 Args: 368 source_name (str): Name of the source 369 Returns: 370 list[exp.Column]: Column instances that reference `source_name` 371 """ 372 return [column for column in self.columns if column.table == source_name] 373 374 @property 375 def is_subquery(self): 376 """Determine if this scope is a subquery""" 377 return self.scope_type == ScopeType.SUBQUERY 378 379 @property 380 def is_derived_table(self): 381 """Determine if this scope is a derived table""" 382 return self.scope_type == ScopeType.DERIVED_TABLE 383 384 @property 385 def is_union(self): 386 """Determine if this scope is a union""" 387 return self.scope_type == ScopeType.UNION 388 389 @property 390 def is_cte(self): 391 """Determine if this scope is a common table expression""" 392 return self.scope_type == ScopeType.CTE 393 394 @property 395 def is_root(self): 396 """Determine if this is the root scope""" 397 return self.scope_type == ScopeType.ROOT 398 399 @property 400 def is_udtf(self): 401 """Determine if this scope is a UDTF (User Defined Table Function)""" 402 return self.scope_type == ScopeType.UDTF 403 404 @property 405 def is_correlated_subquery(self): 406 """Determine if this scope is a correlated subquery""" 407 return bool( 408 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 409 and self.external_columns 410 ) 411 412 def rename_source(self, old_name, new_name): 413 """Rename a source in this scope""" 414 columns = self.sources.pop(old_name or "", []) 415 self.sources[new_name] = columns 416 417 def add_source(self, name, source): 418 """Add a source to this scope""" 419 self.sources[name] = source 420 self.clear_cache() 421 422 def remove_source(self, name): 423 """Remove a source from this scope""" 424 self.sources.pop(name, None) 425 self.clear_cache() 426 427 def __repr__(self): 428 return f"Scope<{self.expression.sql()}>" 429 430 def traverse(self): 431 """ 432 Traverse the scope tree from this node. 433 434 Yields: 435 Scope: scope instances in depth-first-search post-order 436 """ 437 for child_scope in itertools.chain( 438 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 439 ): 440 yield from child_scope.traverse() 441 yield self 442 443 def ref_count(self): 444 """ 445 Count the number of times each scope in this tree is referenced. 446 447 Returns: 448 dict[int, int]: Mapping of Scope instance ID to reference count 449 """ 450 scope_ref_count = defaultdict(lambda: 0) 451 452 for scope in self.traverse(): 453 for _, source in scope.selected_sources.values(): 454 scope_ref_count[id(source)] += 1 455 456 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` → `{"x": Table(this="x")}`; `SELECT * FROM x AS y` → `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` → `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for this scope's alias, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
):
    """
    Initialize a scope rooted at `expression`.

    Args:
        expression (exp.Select|exp.Union): root expression of this scope.
        sources (dict[str, exp.Table|Scope]): initial source mapping.
            NOTE: a non-empty dict is stored as-is (not copied) and is
            updated in place with the lateral and CTE sources below.
        outer_column_list (list[str]): column list the outer query declares
            for this scope's alias, if any.
        parent (Scope): enclosing scope, if any.
        scope_type (ScopeType): type of this scope relative to its parent.
        lateral_sources (dict[str, exp.Table|Scope]): sources from laterals,
            merged into `sources`.
        cte_sources (dict[str, Scope]): sources from CTEs, merged into
            `sources` after the lateral ones (so a CTE wins on a name clash).
    """
    self.expression = expression
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}

    self.sources = sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    self.outer_column_list = outer_column_list or []
    self.parent = parent
    self.scope_type = scope_type

    # Child scope lists; they all start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """
    Reset every lazily-computed attribute so it is recomputed on next access.

    Marks the scope as not yet collected and sets each cached field back to None.
    """
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    The child inherits this scope's CTE sources, overlaid with `cte_sources`.
    `sources` and `lateral_sources` are shallow-copied so the child never
    mutates the caller's dicts.

    Args:
        expression (exp.Expression): root of the child scope; it is unnested first.
        scope_type (ScopeType): type of the child scope.
        sources (dict): optional initial sources for the child.
        cte_sources (dict): extra CTE sources layered over this scope's.
        lateral_sources (dict): optional lateral sources for the child.
        **kwargs: forwarded to the `Scope` constructor.

    Returns:
        Scope: the new child scope, whose parent is this scope.
    """
    inherited_ctes = dict(self.cte_sources)
    inherited_ctes.update(cte_sources or {})

    return Scope(
        expression=expression.unnest(),
        sources=sources.copy() if sources else None,
        parent=self,
        scope_type=scope_type,
        cte_sources=inherited_ctes,
        lateral_sources=lateral_sources.copy() if lateral_sources else None,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so the
    `Scope` stays up-to-date.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): replacement node
    """
    old.replace(new)
    # Cached tables/columns/sources may refer to the removed node.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables that appear directly in this scope.

    Returns:
        list[exp.Table]: tables (collected lazily on first access)
    """
    self._ensure_collected()
    collected_tables = self._tables
    return collected_tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    Common table expressions that appear directly in this scope.

    Returns:
        list[exp.CTE]: ctes (collected lazily on first access)
    """
    self._ensure_collected()
    collected_ctes = self._ctes
    return collected_ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables that appear directly in this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables (collected lazily on first access)
    """
    self._ensure_collected()
    collected_derived = self._derived_tables
    return collected_derived
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" that appear directly in this scope.

    Returns:
        list[exp.UDTF]: UDTFs (collected lazily on first access)
    """
    self._ensure_collected()
    collected_udtfs = self._udtfs
    return collected_udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries that appear directly in this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Subqueryable]: subqueries (collected lazily on first access)
    """
    self._ensure_collected()
    collected_subqueries = self._subqueries
    return collected_subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
@property
def columns(self):
    """
    List of columns in this scope.

    Combines the columns collected directly in this scope with the external
    columns of child subquery and UDTF scopes. It then keeps only those
    accepted by the ancestor-based filter below, and caches the result.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    own_columns = self._raw_columns

    # Columns of child subquery/UDTF scopes that don't resolve within those children.
    correlated = [
        col
        for child in itertools.chain(self.subquery_scopes, self.udtf_scopes)
        for col in child.external_columns
    ]

    select_names = set(self.expression.named_selects)

    def _belongs_to_scope(col):
        # Decide from the nearest relevant ancestor whether `col` is kept.
        anc = col.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
        )
        if not anc or col.table or isinstance(anc, exp.Select):
            return True
        if isinstance(anc, exp.Table):
            return not isinstance(anc.this, exp.Func)
        if isinstance(anc, exp.Order):
            # Window ORDER BYs always count; otherwise skip names that refer to
            # this scope's own named selects.
            return isinstance(anc.parent, exp.Window) or col.name not in select_names
        return False

    self._columns = [col for col in own_columns + correlated if _belongs_to_scope(col)]
    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    Any table in a schema is selectable at any point, but a table only becomes
    a selected source once it appears in a FROM or JOIN clause.

    Raises:
        OptimizeError: if a name that is a known source is referenced twice.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
    """
    if self._selected_sources is not None:
        return self._selected_sources

    picked = {}
    for ref_name, ref_node in self.references:
        if ref_name in picked:
            raise OptimizeError(f"Alias already used: {ref_name}")
        if ref_name in self.sources:
            picked[ref_name] = (ref_node, self.sources[ref_name])

    self._selected_sources = picked
    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    `(name, node)` pairs for every table, derived table and UDTF in this scope.

    Tables are keyed by `alias_or_name`. Derived tables and UDTFs are keyed by
    their `alias`. Their node is the expression itself when it carries pivots,
    otherwise its `unnest()` result. The list is cached after the first call.
    """
    if self._references is not None:
        return self._references

    refs = [(tbl.alias_or_name, tbl) for tbl in self.tables]
    for node in itertools.chain(self.derived_tables, self.udtfs):
        target = node if node.args.get("pivots") else node.unnest()
        refs.append((node.alias, target))

    self._references = refs
    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    Returns:
        list[exp.Column]: Column instances whose `table` is not one of this
            scope's selected sources (cached after the first call).
    """
    if self._external_columns is None:
        outside = []
        for col in self.columns:
            # `selected_sources` is only consulted when there is at least one column.
            if col.table not in self.selected_sources:
                outside.append(col)
        self._external_columns = outside
    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns in the current scope that have no table qualifier.

    Returns:
        list[exp.Column]: columns whose `table` is falsy
    """
    return list(filter(lambda col: not col.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Like the other collection-backed properties (`tables`, `ctes`, ...), this
    triggers collection first. The result therefore does not depend on whether
    some other property already happened to run `_collect`.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    # Fall back to an empty list in case collection left the cache unset.
    return self._join_hints if self._join_hints is not None else []
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns in the current scope whose table qualifier equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for col in self.columns:
        if col.table == source_name:
            matching.append(col)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    return self.scope_type == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    return self.scope_type == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    return self.scope_type == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    return self.scope_type == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    return self.scope_type == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    return self.scope_type == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The source registered under `old_name` (a falsy `old_name` is looked up
    as `""`) is re-registered under `new_name`. If nothing is registered under
    the old name, `new_name` is mapped to an empty list.

    Cached derived data is cleared afterwards, as `add_source` and
    `remove_source` do. Otherwise properties such as `selected_sources`
    would keep reporting the old name.

    Args:
        old_name (str): current source name
        new_name (str): name to register the source under
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name` in this scope and invalidate cached data.

    Args:
        name (str): source name
        source (exp.Table|Scope): the source to register
    """
    sources = self.sources
    sources[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Drop the source registered under `name`, if any, and invalidate cached data.

    Args:
        name (str): source name; a missing name is ignored
    """
    sources = self.sources
    sources.pop(name, None)
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited in this order: CTE scopes, union scopes, table
    scopes, then subquery scopes. Each child subtree is fully yielded before
    this scope itself.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes)
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how many times each source is selected across this scope tree.

    Every scope yielded by `traverse()` contributes one count per entry in
    its `selected_sources`, keyed by the `id()` of the referenced source.

    Returns:
        dict[int, int]: Mapping of source object ID to reference count
            (missing keys read as 0)
    """
    counts = defaultdict(int)
    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty unless `expression` is a Unionable
            or a DDL statement wrapping a Subqueryable.
    """
    if not isinstance(expression, exp.Unionable) and not (
        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
    ):
        return []

    return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope (the last scope produced by `traverse_scope`), or
            None when `traverse_scope` produces no scopes.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree. It is
            also applied to the joins/laterals/pivots walked for derived tables
            and UDTFs.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them.
                # A distinct loop name avoids shadowing the walk's `key`; forwarding `prune`
                # keeps the caller's pruning in effect for these nested walks too.
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the node to start walking from.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once rather than once per visited node.
    wanted_types = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the node whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the node whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.