sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name 12 13logger = logging.getLogger("sqlglot") 14 15 16class ScopeType(Enum): 17 ROOT = auto() 18 SUBQUERY = auto() 19 DERIVED_TABLE = auto() 20 CTE = auto() 21 UNION = auto() 22 UDTF = auto() 23 24 25class Scope: 26 """ 27 Selection scope. 28 29 Attributes: 30 expression (exp.Select|exp.Union): Root expression of this scope 31 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 32 a Table expression or another Scope instance. For example: 33 SELECT * FROM x {"x": Table(this="x")} 34 SELECT * FROM x AS y {"y": Table(this="x")} 35 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 36 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 37 For example: 38 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 39 The LATERAL VIEW EXPLODE gets x as a source. 40 cte_sources (dict[str, Scope]): Sources from CTES 41 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 42 defines a column list of it's alias of this scope, this is that list of columns. 43 For example: 44 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 45 The inner query would have `["col1", "col2"]` for its `outer_column_list` 46 parent (Scope): Parent scope 47 scope_type (ScopeType): Type of this scope, relative to it's parent 48 subquery_scopes (list[Scope]): List of all child scopes for subqueries 49 cte_scopes (list[Scope]): List of all child scopes for CTEs 50 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 51 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 52 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 53 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 54 a list of the left and right child scopes. 55 """ 56 57 def __init__( 58 self, 59 expression, 60 sources=None, 61 outer_column_list=None, 62 parent=None, 63 scope_type=ScopeType.ROOT, 64 lateral_sources=None, 65 cte_sources=None, 66 ): 67 self.expression = expression 68 self.sources = sources or {} 69 self.lateral_sources = lateral_sources or {} 70 self.cte_sources = cte_sources or {} 71 self.sources.update(self.lateral_sources) 72 self.sources.update(self.cte_sources) 73 self.outer_column_list = outer_column_list or [] 74 self.parent = parent 75 self.scope_type = scope_type 76 self.subquery_scopes = [] 77 self.derived_table_scopes = [] 78 self.table_scopes = [] 79 self.cte_scopes = [] 80 self.union_scopes = [] 81 self.udtf_scopes = [] 82 self.clear_cache() 83 84 def clear_cache(self): 85 self._collected = False 86 self._raw_columns = None 87 self._derived_tables = None 88 self._udtfs = None 89 self._tables = None 90 self._ctes = None 91 self._subqueries = None 92 self._selected_sources = None 93 self._columns = None 94 self._external_columns = None 95 self._join_hints = None 96 self._pivots = None 97 self._references = None 98 99 def branch( 100 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 101 ): 102 """Branch from the current 
scope to a new, inner scope""" 103 return Scope( 104 expression=expression.unnest(), 105 sources=sources.copy() if sources else None, 106 parent=self, 107 scope_type=scope_type, 108 cte_sources={**self.cte_sources, **(cte_sources or {})}, 109 lateral_sources=lateral_sources.copy() if lateral_sources else None, 110 **kwargs, 111 ) 112 113 def _collect(self): 114 self._tables = [] 115 self._ctes = [] 116 self._subqueries = [] 117 self._derived_tables = [] 118 self._udtfs = [] 119 self._raw_columns = [] 120 self._join_hints = [] 121 122 for node, parent, _ in self.walk(bfs=False): 123 if node is self.expression: 124 continue 125 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 126 self._raw_columns.append(node) 127 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 128 self._tables.append(node) 129 elif isinstance(node, exp.JoinHint): 130 self._join_hints.append(node) 131 elif isinstance(node, exp.UDTF): 132 self._udtfs.append(node) 133 elif isinstance(node, exp.CTE): 134 self._ctes.append(node) 135 elif ( 136 isinstance(node, exp.Subquery) 137 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 138 and _is_derived_table(node) 139 ): 140 self._derived_tables.append(node) 141 elif isinstance(node, exp.Subqueryable): 142 self._subqueries.append(node) 143 144 self._collected = True 145 146 def _ensure_collected(self): 147 if not self._collected: 148 self._collect() 149 150 def walk(self, bfs=True, prune=None): 151 return walk_in_scope(self.expression, bfs=bfs, prune=None) 152 153 def find(self, *expression_types, bfs=True): 154 return find_in_scope(self.expression, expression_types, bfs=bfs) 155 156 def find_all(self, *expression_types, bfs=True): 157 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 158 159 def replace(self, old, new): 160 """ 161 Replace `old` with `new`. 162 163 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 
164 165 Args: 166 old (exp.Expression): old node 167 new (exp.Expression): new node 168 """ 169 old.replace(new) 170 self.clear_cache() 171 172 @property 173 def tables(self): 174 """ 175 List of tables in this scope. 176 177 Returns: 178 list[exp.Table]: tables 179 """ 180 self._ensure_collected() 181 return self._tables 182 183 @property 184 def ctes(self): 185 """ 186 List of CTEs in this scope. 187 188 Returns: 189 list[exp.CTE]: ctes 190 """ 191 self._ensure_collected() 192 return self._ctes 193 194 @property 195 def derived_tables(self): 196 """ 197 List of derived tables in this scope. 198 199 For example: 200 SELECT * FROM (SELECT ...) <- that's a derived table 201 202 Returns: 203 list[exp.Subquery]: derived tables 204 """ 205 self._ensure_collected() 206 return self._derived_tables 207 208 @property 209 def udtfs(self): 210 """ 211 List of "User Defined Tabular Functions" in this scope. 212 213 Returns: 214 list[exp.UDTF]: UDTFs 215 """ 216 self._ensure_collected() 217 return self._udtfs 218 219 @property 220 def subqueries(self): 221 """ 222 List of subqueries in this scope. 223 224 For example: 225 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 226 227 Returns: 228 list[exp.Subqueryable]: subqueries 229 """ 230 self._ensure_collected() 231 return self._subqueries 232 233 @property 234 def columns(self): 235 """ 236 List of columns in this scope. 237 238 Returns: 239 list[exp.Column]: Column instances in this scope, plus any 240 Columns that reference this scope from correlated subqueries. 
241 """ 242 if self._columns is None: 243 self._ensure_collected() 244 columns = self._raw_columns 245 246 external_columns = [ 247 column 248 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 249 for column in scope.external_columns 250 ] 251 252 named_selects = set(self.expression.named_selects) 253 254 self._columns = [] 255 for column in columns + external_columns: 256 ancestor = column.find_ancestor( 257 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 258 ) 259 if ( 260 not ancestor 261 or column.table 262 or isinstance(ancestor, exp.Select) 263 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 264 or ( 265 isinstance(ancestor, exp.Order) 266 and ( 267 isinstance(ancestor.parent, exp.Window) 268 or column.name not in named_selects 269 ) 270 ) 271 ): 272 self._columns.append(column) 273 274 return self._columns 275 276 @property 277 def selected_sources(self): 278 """ 279 Mapping of nodes and sources that are actually selected from in this scope. 280 281 That is, all tables in a schema are selectable at any point. But a 282 table only becomes a selected source if it's included in a FROM or JOIN clause. 
283 284 Returns: 285 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 286 """ 287 if self._selected_sources is None: 288 result = {} 289 290 for name, node in self.references: 291 if name in result: 292 raise OptimizeError(f"Alias already used: {name}") 293 if name in self.sources: 294 result[name] = (node, self.sources[name]) 295 296 self._selected_sources = result 297 return self._selected_sources 298 299 @property 300 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 301 if self._references is None: 302 self._references = [] 303 304 for table in self.tables: 305 self._references.append((table.alias_or_name, table)) 306 for expression in itertools.chain(self.derived_tables, self.udtfs): 307 self._references.append( 308 ( 309 expression.alias, 310 expression if expression.args.get("pivots") else expression.unnest(), 311 ) 312 ) 313 314 return self._references 315 316 @property 317 def external_columns(self): 318 """ 319 Columns that appear to reference sources in outer scopes. 320 321 Returns: 322 list[exp.Column]: Column instances that don't reference 323 sources in the current scope. 324 """ 325 if self._external_columns is None: 326 if isinstance(self.expression, exp.Union): 327 left, right = self.union_scopes 328 self._external_columns = left.external_columns + right.external_columns 329 else: 330 self._external_columns = [ 331 c for c in self.columns if c.table not in self.selected_sources 332 ] 333 334 return self._external_columns 335 336 @property 337 def unqualified_columns(self): 338 """ 339 Unqualified columns in the current scope. 
340 341 Returns: 342 list[exp.Column]: Unqualified columns 343 """ 344 return [c for c in self.columns if not c.table] 345 346 @property 347 def join_hints(self): 348 """ 349 Hints that exist in the scope that reference tables 350 351 Returns: 352 list[exp.JoinHint]: Join hints that are referenced within the scope 353 """ 354 if self._join_hints is None: 355 return [] 356 return self._join_hints 357 358 @property 359 def pivots(self): 360 if not self._pivots: 361 self._pivots = [ 362 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 363 ] 364 365 return self._pivots 366 367 def source_columns(self, source_name): 368 """ 369 Get all columns in the current scope for a particular source. 370 371 Args: 372 source_name (str): Name of the source 373 Returns: 374 list[exp.Column]: Column instances that reference `source_name` 375 """ 376 return [column for column in self.columns if column.table == source_name] 377 378 @property 379 def is_subquery(self): 380 """Determine if this scope is a subquery""" 381 return self.scope_type == ScopeType.SUBQUERY 382 383 @property 384 def is_derived_table(self): 385 """Determine if this scope is a derived table""" 386 return self.scope_type == ScopeType.DERIVED_TABLE 387 388 @property 389 def is_union(self): 390 """Determine if this scope is a union""" 391 return self.scope_type == ScopeType.UNION 392 393 @property 394 def is_cte(self): 395 """Determine if this scope is a common table expression""" 396 return self.scope_type == ScopeType.CTE 397 398 @property 399 def is_root(self): 400 """Determine if this is the root scope""" 401 return self.scope_type == ScopeType.ROOT 402 403 @property 404 def is_udtf(self): 405 """Determine if this scope is a UDTF (User Defined Table Function)""" 406 return self.scope_type == ScopeType.UDTF 407 408 @property 409 def is_correlated_subquery(self): 410 """Determine if this scope is a correlated subquery""" 411 return bool( 412 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 413 and self.external_columns 414 ) 415 416 def rename_source(self, old_name, new_name): 417 """Rename a source in this scope""" 418 columns = self.sources.pop(old_name or "", []) 419 self.sources[new_name] = columns 420 421 def add_source(self, name, source): 422 """Add a source to this scope""" 423 self.sources[name] = source 424 self.clear_cache() 425 426 def remove_source(self, name): 427 """Remove a source from this scope""" 428 self.sources.pop(name, None) 429 self.clear_cache() 430 431 def __repr__(self): 432 return f"Scope<{self.expression.sql()}>" 433 434 def traverse(self): 435 """ 436 Traverse the scope tree from this node. 437 438 Yields: 439 Scope: scope instances in depth-first-search post-order 440 """ 441 for child_scope in itertools.chain( 442 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 443 ): 444 yield from child_scope.traverse() 445 yield self 446 447 def ref_count(self): 448 """ 449 Count the number of times each scope in this tree is referenced. 450 451 Returns: 452 dict[int, int]: Mapping of Scope instance ID to reference count 453 """ 454 scope_ref_count = defaultdict(lambda: 0) 455 456 for scope in self.traverse(): 457 for _, source in scope.selected_sources.values(): 458 scope_ref_count[id(source)] += 1 459 460 return scope_ref_count 461 462 463def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 464 """ 465 Traverse an expression by its "scopes". 466 467 "Scope" represents the current context of a Select statement. 468 469 This is helpful for optimizing queries, where we need more information than 470 the expression tree itself. For example, we might care about the source 471 names within a subquery. Returns a list because a generator could result in 472 incomplete properties which is confusing. 
473 474 Examples: 475 >>> import sqlglot 476 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 477 >>> scopes = traverse_scope(expression) 478 >>> scopes[0].expression.sql(), list(scopes[0].sources) 479 ('SELECT a FROM x', ['x']) 480 >>> scopes[1].expression.sql(), list(scopes[1].sources) 481 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 482 483 Args: 484 expression (exp.Expression): expression to traverse 485 486 Returns: 487 list[Scope]: scope instances 488 """ 489 if isinstance(expression, exp.Unionable) or ( 490 isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable) 491 ): 492 return list(_traverse_scope(Scope(expression))) 493 494 return [] 495 496 497def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 498 """ 499 Build a scope tree. 500 501 Args: 502 expression (exp.Expression): expression to build the scope tree for 503 Returns: 504 Scope: root scope 505 """ 506 scopes = traverse_scope(expression) 507 if scopes: 508 return scopes[-1] 509 return None 510 511 512def _traverse_scope(scope): 513 if isinstance(scope.expression, exp.Select): 514 yield from _traverse_select(scope) 515 elif isinstance(scope.expression, exp.Union): 516 yield from _traverse_union(scope) 517 elif isinstance(scope.expression, exp.Subquery): 518 if scope.is_root: 519 yield from _traverse_select(scope) 520 else: 521 yield from _traverse_subqueries(scope) 522 elif isinstance(scope.expression, exp.Table): 523 yield from _traverse_tables(scope) 524 elif isinstance(scope.expression, exp.UDTF): 525 yield from _traverse_udtfs(scope) 526 elif isinstance(scope.expression, exp.DDL): 527 yield from _traverse_ddl(scope) 528 else: 529 logger.warning( 530 "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) 531 ) 532 return 533 534 yield scope 535 536 537def _traverse_select(scope): 538 yield from _traverse_ctes(scope) 539 yield from _traverse_tables(scope) 540 yield from _traverse_subqueries(scope) 
541 542 543def _traverse_union(scope): 544 yield from _traverse_ctes(scope) 545 546 # The last scope to be yield should be the top most scope 547 left = None 548 for left in _traverse_scope( 549 scope.branch( 550 scope.expression.left, 551 outer_column_list=scope.outer_column_list, 552 scope_type=ScopeType.UNION, 553 ) 554 ): 555 yield left 556 557 right = None 558 for right in _traverse_scope( 559 scope.branch( 560 scope.expression.right, 561 outer_column_list=scope.outer_column_list, 562 scope_type=ScopeType.UNION, 563 ) 564 ): 565 yield right 566 567 scope.union_scopes = [left, right] 568 569 570def _traverse_ctes(scope): 571 sources = {} 572 573 for cte in scope.ctes: 574 recursive_scope = None 575 576 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 577 # thus the recursive scope is the first section of the union. 578 with_ = scope.expression.args.get("with") 579 if with_ and with_.recursive: 580 union = cte.this 581 582 if isinstance(union, exp.Union): 583 recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) 584 585 child_scope = None 586 587 for child_scope in _traverse_scope( 588 scope.branch( 589 cte.this, 590 cte_sources=sources, 591 outer_column_list=cte.alias_column_names, 592 scope_type=ScopeType.CTE, 593 ) 594 ): 595 yield child_scope 596 597 alias = cte.alias 598 sources[alias] = child_scope 599 600 if recursive_scope: 601 child_scope.add_source(alias, recursive_scope) 602 child_scope.cte_sources[alias] = recursive_scope 603 604 # append the final child_scope yielded 605 if child_scope: 606 scope.cte_scopes.append(child_scope) 607 608 scope.sources.update(sources) 609 scope.cte_sources.update(sources) 610 611 612def _is_derived_table(expression: exp.Subquery) -> bool: 613 """ 614 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 615 as it doesn't introduce a new scope. 
If an alias is present, it shadows all names 616 under the Subquery, so that's one exception to this rule. 617 """ 618 return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) 619 620 621def _traverse_tables(scope): 622 sources = {} 623 624 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 625 expressions = [] 626 from_ = scope.expression.args.get("from") 627 if from_: 628 expressions.append(from_.this) 629 630 for join in scope.expression.args.get("joins") or []: 631 expressions.append(join.this) 632 633 if isinstance(scope.expression, exp.Table): 634 expressions.append(scope.expression) 635 636 expressions.extend(scope.expression.args.get("laterals") or []) 637 638 for expression in expressions: 639 if isinstance(expression, exp.Table): 640 table_name = expression.name 641 source_name = expression.alias_or_name 642 643 if table_name in scope.sources and not expression.db: 644 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 645 # it is pivoted, because then we get back a new table and hence a new source. 
646 pivots = expression.args.get("pivots") 647 if pivots: 648 sources[pivots[0].alias] = expression 649 else: 650 sources[source_name] = scope.sources[table_name] 651 elif source_name in sources: 652 sources[find_new_name(sources, table_name)] = expression 653 else: 654 sources[source_name] = expression 655 656 # Make sure to not include the joins twice 657 if expression is not scope.expression: 658 expressions.extend(join.this for join in expression.args.get("joins") or []) 659 660 continue 661 662 if not isinstance(expression, exp.DerivedTable): 663 continue 664 665 if isinstance(expression, exp.UDTF): 666 lateral_sources = sources 667 scope_type = ScopeType.UDTF 668 scopes = scope.udtf_scopes 669 elif _is_derived_table(expression): 670 lateral_sources = None 671 scope_type = ScopeType.DERIVED_TABLE 672 scopes = scope.derived_table_scopes 673 expressions.extend(join.this for join in expression.args.get("joins") or []) 674 else: 675 # Makes sure we check for possible sources in nested table constructs 676 expressions.append(expression.this) 677 expressions.extend(join.this for join in expression.args.get("joins") or []) 678 continue 679 680 for child_scope in _traverse_scope( 681 scope.branch( 682 expression, 683 lateral_sources=lateral_sources, 684 outer_column_list=expression.alias_column_names, 685 scope_type=scope_type, 686 ) 687 ): 688 yield child_scope 689 690 # Tables without aliases will be set as "" 691 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 692 # Until then, this means that only a single, unaliased derived table is allowed (rather, 693 # the latest one wins. 
694 sources[expression.alias] = child_scope 695 696 # append the final child_scope yielded 697 scopes.append(child_scope) 698 scope.table_scopes.append(child_scope) 699 700 scope.sources.update(sources) 701 702 703def _traverse_subqueries(scope): 704 for subquery in scope.subqueries: 705 top = None 706 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 707 yield child_scope 708 top = child_scope 709 scope.subquery_scopes.append(top) 710 711 712def _traverse_udtfs(scope): 713 if isinstance(scope.expression, exp.Unnest): 714 expressions = scope.expression.expressions 715 elif isinstance(scope.expression, exp.Lateral): 716 expressions = [scope.expression.this] 717 else: 718 expressions = [] 719 720 sources = {} 721 for expression in expressions: 722 if isinstance(expression, exp.Subquery) and _is_derived_table(expression): 723 top = None 724 for child_scope in _traverse_scope( 725 scope.branch( 726 expression, 727 scope_type=ScopeType.DERIVED_TABLE, 728 outer_column_list=expression.alias_column_names, 729 ) 730 ): 731 yield child_scope 732 top = child_scope 733 sources[expression.alias] = child_scope 734 735 scope.derived_table_scopes.append(top) 736 scope.table_scopes.append(top) 737 738 scope.sources.update(sources) 739 740 741def _traverse_ddl(scope): 742 yield from _traverse_ctes(scope) 743 744 query_scope = scope.branch( 745 scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources 746 ) 747 query_scope._collect() 748 query_scope._ctes = scope.ctes + query_scope._ctes 749 750 yield from _traverse_scope(query_scope) 751 752 753def walk_in_scope(expression, bfs=True, prune=None): 754 """ 755 Returns a generator object which visits all nodes in the syntrax tree, stopping at 756 nodes that start child scopes. 757 758 Args: 759 expression (exp.Expression): 760 bfs (bool): if set to True the BFS traversal order will be applied, 761 otherwise the DFS traversal will be used instead. 
762 prune ((node, parent, arg_key) -> bool): callable that returns True if 763 the generator should stop traversing this branch of the tree. 764 765 Yields: 766 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 767 """ 768 # We'll use this variable to pass state into the dfs generator. 769 # Whenever we set it to True, we exclude a subtree from traversal. 770 crossed_scope_boundary = False 771 772 for node, parent, key in expression.walk( 773 bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) 774 ): 775 crossed_scope_boundary = False 776 777 yield node, parent, key 778 779 if node is expression: 780 continue 781 if ( 782 isinstance(node, exp.CTE) 783 or ( 784 isinstance(node, exp.Subquery) 785 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 786 and _is_derived_table(node) 787 ) 788 or isinstance(node, exp.UDTF) 789 or isinstance(node, exp.Subqueryable) 790 ): 791 crossed_scope_boundary = True 792 793 if isinstance(node, (exp.Subquery, exp.UDTF)): 794 # The following args are not actually in the inner scope, so we should visit them 795 for key in ("joins", "laterals", "pivots"): 796 for arg in node.args.get(key) or []: 797 yield from walk_in_scope(arg, bfs=bfs) 798 799 800def find_all_in_scope(expression, expression_types, bfs=True): 801 """ 802 Returns a generator object which visits all nodes in this scope and only yields those that 803 match at least one of the specified expression types. 804 805 This does NOT traverse into subscopes. 806 807 Args: 808 expression (exp.Expression): 809 expression_types (tuple[type]|type): the expression type(s) to match. 810 bfs (bool): True to use breadth-first search, False to use depth-first. 
811 812 Yields: 813 exp.Expression: nodes 814 """ 815 for expression, *_ in walk_in_scope(expression, bfs=bfs): 816 if isinstance(expression, tuple(ensure_collection(expression_types))): 817 yield expression 818 819 820def find_in_scope(expression, expression_types, bfs=True): 821 """ 822 Returns the first node in this scope which matches at least one of the specified types. 823 824 This does NOT traverse into subscopes. 825 826 Args: 827 expression (exp.Expression): 828 expression_types (tuple[type]|type): the expression type(s) to match. 829 bfs (bool): True to use breadth-first search, False to use depth-first. 830 831 Returns: 832 exp.Expression: the node which matches the criteria or None if no node matching 833 the criteria was found. 834 """ 835 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Type of a scope, relative to its parent scope."""

    # Values are written out explicitly; they are the same ones `auto()` assigns (1, 2, 3, ...).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
An enumeration.
Inherited Members
- enum.Enum
- name
- value
26class Scope: 27 """ 28 Selection scope. 29 30 Attributes: 31 expression (exp.Select|exp.Union): Root expression of this scope 32 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 33 a Table expression or another Scope instance. For example: 34 SELECT * FROM x {"x": Table(this="x")} 35 SELECT * FROM x AS y {"y": Table(this="x")} 36 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 37 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 38 For example: 39 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 40 The LATERAL VIEW EXPLODE gets x as a source. 41 cte_sources (dict[str, Scope]): Sources from CTES 42 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 43 defines a column list of it's alias of this scope, this is that list of columns. 44 For example: 45 SELECT * FROM (SELECT ...) AS y(col1, col2) 46 The inner query would have `["col1", "col2"]` for its `outer_column_list` 47 parent (Scope): Parent scope 48 scope_type (ScopeType): Type of this scope, relative to it's parent 49 subquery_scopes (list[Scope]): List of all child scopes for subqueries 50 cte_scopes (list[Scope]): List of all child scopes for CTEs 51 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 52 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 53 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 54 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 55 a list of the left and right child scopes. 
56 """ 57 58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_column_list=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_column_list = outer_column_list or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache() 84 85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None 99 100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 ) 113 114 def _collect(self): 115 self._tables = [] 116 self._ctes = [] 117 self._subqueries = [] 118 self._derived_tables = [] 119 self._udtfs = [] 120 self._raw_columns = [] 121 self._join_hints = [] 122 123 for node, parent, _ in self.walk(bfs=False): 124 if node is self.expression: 125 continue 126 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 127 self._raw_columns.append(node) 
128 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 129 self._tables.append(node) 130 elif isinstance(node, exp.JoinHint): 131 self._join_hints.append(node) 132 elif isinstance(node, exp.UDTF): 133 self._udtfs.append(node) 134 elif isinstance(node, exp.CTE): 135 self._ctes.append(node) 136 elif ( 137 isinstance(node, exp.Subquery) 138 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 139 and _is_derived_table(node) 140 ): 141 self._derived_tables.append(node) 142 elif isinstance(node, exp.Subqueryable): 143 self._subqueries.append(node) 144 145 self._collected = True 146 147 def _ensure_collected(self): 148 if not self._collected: 149 self._collect() 150 151 def walk(self, bfs=True, prune=None): 152 return walk_in_scope(self.expression, bfs=bfs, prune=None) 153 154 def find(self, *expression_types, bfs=True): 155 return find_in_scope(self.expression, expression_types, bfs=bfs) 156 157 def find_all(self, *expression_types, bfs=True): 158 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 159 160 def replace(self, old, new): 161 """ 162 Replace `old` with `new`. 163 164 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 165 166 Args: 167 old (exp.Expression): old node 168 new (exp.Expression): new node 169 """ 170 old.replace(new) 171 self.clear_cache() 172 173 @property 174 def tables(self): 175 """ 176 List of tables in this scope. 177 178 Returns: 179 list[exp.Table]: tables 180 """ 181 self._ensure_collected() 182 return self._tables 183 184 @property 185 def ctes(self): 186 """ 187 List of CTEs in this scope. 188 189 Returns: 190 list[exp.CTE]: ctes 191 """ 192 self._ensure_collected() 193 return self._ctes 194 195 @property 196 def derived_tables(self): 197 """ 198 List of derived tables in this scope. 199 200 For example: 201 SELECT * FROM (SELECT ...) 
<- that's a derived table 202 203 Returns: 204 list[exp.Subquery]: derived tables 205 """ 206 self._ensure_collected() 207 return self._derived_tables 208 209 @property 210 def udtfs(self): 211 """ 212 List of "User Defined Tabular Functions" in this scope. 213 214 Returns: 215 list[exp.UDTF]: UDTFs 216 """ 217 self._ensure_collected() 218 return self._udtfs 219 220 @property 221 def subqueries(self): 222 """ 223 List of subqueries in this scope. 224 225 For example: 226 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 227 228 Returns: 229 list[exp.Subqueryable]: subqueries 230 """ 231 self._ensure_collected() 232 return self._subqueries 233 234 @property 235 def columns(self): 236 """ 237 List of columns in this scope. 238 239 Returns: 240 list[exp.Column]: Column instances in this scope, plus any 241 Columns that reference this scope from correlated subqueries. 242 """ 243 if self._columns is None: 244 self._ensure_collected() 245 columns = self._raw_columns 246 247 external_columns = [ 248 column 249 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 250 for column in scope.external_columns 251 ] 252 253 named_selects = set(self.expression.named_selects) 254 255 self._columns = [] 256 for column in columns + external_columns: 257 ancestor = column.find_ancestor( 258 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 259 ) 260 if ( 261 not ancestor 262 or column.table 263 or isinstance(ancestor, exp.Select) 264 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 265 or ( 266 isinstance(ancestor, exp.Order) 267 and ( 268 isinstance(ancestor.parent, exp.Window) 269 or column.name not in named_selects 270 ) 271 ) 272 ): 273 self._columns.append(column) 274 275 return self._columns 276 277 @property 278 def selected_sources(self): 279 """ 280 Mapping of nodes and sources that are actually selected from in this scope. 281 282 That is, all tables in a schema are selectable at any point. 
But a 283 table only becomes a selected source if it's included in a FROM or JOIN clause. 284 285 Returns: 286 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 287 """ 288 if self._selected_sources is None: 289 result = {} 290 291 for name, node in self.references: 292 if name in result: 293 raise OptimizeError(f"Alias already used: {name}") 294 if name in self.sources: 295 result[name] = (node, self.sources[name]) 296 297 self._selected_sources = result 298 return self._selected_sources 299 300 @property 301 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 302 if self._references is None: 303 self._references = [] 304 305 for table in self.tables: 306 self._references.append((table.alias_or_name, table)) 307 for expression in itertools.chain(self.derived_tables, self.udtfs): 308 self._references.append( 309 ( 310 expression.alias, 311 expression if expression.args.get("pivots") else expression.unnest(), 312 ) 313 ) 314 315 return self._references 316 317 @property 318 def external_columns(self): 319 """ 320 Columns that appear to reference sources in outer scopes. 321 322 Returns: 323 list[exp.Column]: Column instances that don't reference 324 sources in the current scope. 325 """ 326 if self._external_columns is None: 327 if isinstance(self.expression, exp.Union): 328 left, right = self.union_scopes 329 self._external_columns = left.external_columns + right.external_columns 330 else: 331 self._external_columns = [ 332 c for c in self.columns if c.table not in self.selected_sources 333 ] 334 335 return self._external_columns 336 337 @property 338 def unqualified_columns(self): 339 """ 340 Unqualified columns in the current scope. 
341 342 Returns: 343 list[exp.Column]: Unqualified columns 344 """ 345 return [c for c in self.columns if not c.table] 346 347 @property 348 def join_hints(self): 349 """ 350 Hints that exist in the scope that reference tables 351 352 Returns: 353 list[exp.JoinHint]: Join hints that are referenced within the scope 354 """ 355 if self._join_hints is None: 356 return [] 357 return self._join_hints 358 359 @property 360 def pivots(self): 361 if not self._pivots: 362 self._pivots = [ 363 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 364 ] 365 366 return self._pivots 367 368 def source_columns(self, source_name): 369 """ 370 Get all columns in the current scope for a particular source. 371 372 Args: 373 source_name (str): Name of the source 374 Returns: 375 list[exp.Column]: Column instances that reference `source_name` 376 """ 377 return [column for column in self.columns if column.table == source_name] 378 379 @property 380 def is_subquery(self): 381 """Determine if this scope is a subquery""" 382 return self.scope_type == ScopeType.SUBQUERY 383 384 @property 385 def is_derived_table(self): 386 """Determine if this scope is a derived table""" 387 return self.scope_type == ScopeType.DERIVED_TABLE 388 389 @property 390 def is_union(self): 391 """Determine if this scope is a union""" 392 return self.scope_type == ScopeType.UNION 393 394 @property 395 def is_cte(self): 396 """Determine if this scope is a common table expression""" 397 return self.scope_type == ScopeType.CTE 398 399 @property 400 def is_root(self): 401 """Determine if this is the root scope""" 402 return self.scope_type == ScopeType.ROOT 403 404 @property 405 def is_udtf(self): 406 """Determine if this scope is a UDTF (User Defined Table Function)""" 407 return self.scope_type == ScopeType.UDTF 408 409 @property 410 def is_correlated_subquery(self): 411 """Determine if this scope is a correlated subquery""" 412 return bool( 413 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 414 and self.external_columns 415 ) 416 417 def rename_source(self, old_name, new_name): 418 """Rename a source in this scope""" 419 columns = self.sources.pop(old_name or "", []) 420 self.sources[new_name] = columns 421 422 def add_source(self, name, source): 423 """Add a source to this scope""" 424 self.sources[name] = source 425 self.clear_cache() 426 427 def remove_source(self, name): 428 """Remove a source from this scope""" 429 self.sources.pop(name, None) 430 self.clear_cache() 431 432 def __repr__(self): 433 return f"Scope<{self.expression.sql()}>" 434 435 def traverse(self): 436 """ 437 Traverse the scope tree from this node. 438 439 Yields: 440 Scope: scope instances in depth-first-search post-order 441 """ 442 for child_scope in itertools.chain( 443 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 444 ): 445 yield from child_scope.traverse() 446 yield self 447 448 def ref_count(self): 449 """ 450 Count the number of times each scope in this tree is referenced. 451 452 Returns: 453 dict[int, int]: Mapping of Scope instance ID to reference count 454 """ 455 scope_ref_count = defaultdict(lambda: 0) 456 457 for scope in self.traverse(): 458 for _, source in scope.selected_sources.values(): 459 scope_ref_count[id(source)] += 1 460 461 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; Here the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have ["col1", "col2"] for its outer_column_list.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression: root expression of this scope
        sources (dict): mapping of source name to source. When a non-empty dict is
            passed it is used as-is (not copied) and updated in place with the
            lateral and CTE sources.
        outer_column_list (list): column list defined by the outer query, if any
        parent (Scope): parent scope
        scope_type (ScopeType): type of this scope, relative to its parent
        lateral_sources (dict): sources from laterals
        cte_sources (dict): sources from CTEs
    """
    # Position of this scope in the scope tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.outer_column_list = outer_column_list or []

    # Lateral and CTE sources are kept on their own and are also merged into
    # `sources` (laterals first, then CTEs, so a CTE wins on a name clash).
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    self.sources = sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    # Child scopes; all start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    # Initialise every lazily computed attribute.
    self.clear_cache()
def clear_cache(self):
    """Reset all lazily computed state so that it is rebuilt on next access."""
    self._collected = False

    # Every cached collection goes back to None, which is the "not computed" marker.
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    The new scope has this scope as its parent and inherits this scope's CTE
    sources; entries in `cte_sources` override inherited ones with the same name.
    `sources` and `lateral_sources` are copied before being handed to the new scope.

    Args:
        expression: expression of the inner scope (it is unnested first)
        scope_type (ScopeType): type of the new scope
        sources (dict): sources of the new scope
        cte_sources (dict): additional CTE sources
        lateral_sources (dict): lateral sources of the new scope
        **kwargs: forwarded to the `Scope` constructor

    Returns:
        Scope: the new child scope
    """
    inner_expression = expression.unnest()
    inner_sources = sources.copy() if sources else None
    inner_laterals = lateral_sources.copy() if lateral_sources else None

    visible_ctes = dict(self.cte_sources)
    visible_ctes.update(cte_sources or {})

    return Scope(
        expression=inner_expression,
        sources=inner_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=visible_ctes,
        lateral_sources=inner_laterals,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` and invalidate this scope's cached state.

    Use this rather than calling `exp.Expression.replace` directly so that the
    `Scope` does not keep serving collections computed before the replacement.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)

    # Anything collected earlier may still refer to `old`.
    self.clear_cache()
Replace old with new.
This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables found in this scope.

    The scope's nodes are collected first if that has not happened yet.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    found = self._tables
    return found
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTEs found in this scope.

    The scope's nodes are collected first if that has not happened yet.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    found = self._ctes
    return found
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables found in this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    The scope's nodes are collected first if that has not happened yet.

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    found = self._derived_tables
    return found
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" found in this scope.

    The scope's nodes are collected first if that has not happened yet.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    found = self._udtfs
    return found
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries found in this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    The scope's nodes are collected first if that has not happened yet.

    Returns:
        list[exp.Subqueryable]: subqueries
    """
    self._ensure_collected()
    found = self._subqueries
    return found
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
@property
def columns(self):
    """
    Columns that belong to this scope.

    The result is computed once and cached until `clear_cache` is called.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    own_columns = self._raw_columns

    # Columns of child subquery / UDTF scopes that point outside those scopes.
    outer_refs = []
    for child in itertools.chain(self.subquery_scopes, self.udtf_scopes):
        outer_refs.extend(child.external_columns)

    select_names = set(self.expression.named_selects)

    self._columns = []
    for column in own_columns + outer_refs:
        ancestor = column.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
        )

        if not ancestor or column.table or isinstance(ancestor, exp.Select):
            keep = True
        elif isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func):
            keep = True
        elif isinstance(ancestor, exp.Order):
            # An unqualified ORDER BY column is kept when the ORDER BY belongs to a
            # window, or when its name is not one of this scope's named selects.
            keep = isinstance(ancestor.parent, exp.Window) or column.name not in select_names
        else:
            keep = False

        if keep:
            self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    All tables in a schema are selectable at any point, but a table only
    becomes a selected source once it is included in a FROM or JOIN clause.
    References whose name is not in `sources` are left out. The result is
    cached until `clear_cache` is called.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if a name that is already selected is referenced again
    """
    if self._selected_sources is not None:
        return self._selected_sources

    selected = {}
    for name, node in self.references:
        if name in selected:
            raise OptimizeError(f"Alias already used: {name}")
        if name in self.sources:
            selected[name] = (node, self.sources[name])

    self._selected_sources = selected
    return selected
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for everything this scope selects from.

    Tables come first, under their alias or name, followed by derived tables
    and UDTFs under their alias. A derived table or UDTF is unnested unless
    it carries pivots, in which case the node itself is kept. The result is
    cached until `clear_cache` is called.
    """
    if self._references is not None:
        return self._references

    self._references = []

    for table in self.tables:
        self._references.append((table.alias_or_name, table))

    for aliased in itertools.chain(self.derived_tables, self.udtfs):
        node = aliased if aliased.args.get("pivots") else aliased.unnest()
        self._references.append((aliased.alias, node))

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a union, this is the left scope's external columns followed by the
    right scope's. Otherwise it is every column of this scope whose table is
    not one of the selected sources. The result is cached until `clear_cache`
    is called.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.Union):
        left, right = self.union_scopes
        outside = left.external_columns + right.external_columns
    else:
        outside = []
        for column in self.columns:
            if column.table not in self.selected_sources:
                outside.append(column)

    self._external_columns = outside
    return outside
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of the current scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    unqualified = []
    for column in self.columns:
        if not column.table:
            unqualified.append(column)
    return unqualified
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    This does not trigger collection: while the cached value is still None,
    an empty list is returned.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    hints = self._join_hints
    return [] if hints is None else hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of the current scope that belong to one source.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances whose table equals `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference source_name
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The entry stored under `old_name` is moved to `new_name`. If there is no
    such entry (or `old_name` is falsy and no source is stored under ""), an
    empty list is stored under `new_name` instead.

    Cached state is reset afterwards, as `add_source` and `remove_source` do,
    so that values derived from `sources` do not keep referring to the old name.

    Args:
        old_name (str): current name of the source
        new_name (str): name to store the source under
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """
    Add a source to this scope.

    Any source already stored under `name` is replaced, and cached state is reset.

    Args:
        name (str): name of the source
        source: the source to register under `name`
    """
    self.sources[name] = source

    # Values derived from `sources` have to be recomputed.
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Remove a source from this scope.

    A name that is not present is ignored. Cached state is reset either way.

    Args:
        name (str): name of the source to remove
    """
    self.sources.pop(name, None)

    # Values derived from `sources` have to be recomputed.
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree below and including this scope.

    Child scopes are visited in the order CTE scopes, union scopes, table
    scopes, subquery scopes; each child's whole subtree is yielded before
    this scope itself.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.table_scopes,
        self.subquery_scopes,
    )
    for group in child_groups:
        for child_scope in group:
            yield from child_scope.traverse()

    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count the number of times each scope in this tree is referenced.

    Every scope reached by `traverse` contributes one count per entry of its
    selected sources, keyed by the `id()` of the selected source.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        for selected in scope.selected_sources.values():
            _, source = selected
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Only `exp.Unionable` expressions, or `exp.DDL` expressions whose `expression`
    is `exp.Unionable`, are traversed; anything else yields an empty list.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse

    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        wraps_query = isinstance(expression, exp.DDL) and isinstance(
            expression.expression, exp.Unionable
        )
        if not wraps_query:
            return []

    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, i.e. the last scope produced by `traverse_scope`,
            or None when `traverse_scope` produces no scopes
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is itself yielded, but its subtree is not
    descended into. The `joins`, `laterals` and `pivots` args of a boundary
    `exp.Subquery` / `exp.UDTF` are still walked, with the same `bfs` and `prune`
    settings, because they are not part of the inner scope.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        # The root of the walk never counts as a boundary.
        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for arg_name in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_name) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # The set of wanted types does not change during the walk, so build it once.
    wanted_types = tuple(ensure_collection(expression_types))

    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    first_match = next(matches, None)
    return first_match
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.