sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name 12 13logger = logging.getLogger("sqlglot") 14 15 16class ScopeType(Enum): 17 ROOT = auto() 18 SUBQUERY = auto() 19 DERIVED_TABLE = auto() 20 CTE = auto() 21 UNION = auto() 22 UDTF = auto() 23 24 25class Scope: 26 """ 27 Selection scope. 28 29 Attributes: 30 expression (exp.Select|exp.Union): Root expression of this scope 31 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 32 a Table expression or another Scope instance. For example: 33 SELECT * FROM x {"x": Table(this="x")} 34 SELECT * FROM x AS y {"y": Table(this="x")} 35 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 36 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 37 For example: 38 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 39 The LATERAL VIEW EXPLODE gets x as a source. 40 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 41 defines a column list of it's alias of this scope, this is that list of columns. 42 For example: 43 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 44 The inner query would have `["col1", "col2"]` for its `outer_column_list` 45 parent (Scope): Parent scope 46 scope_type (ScopeType): Type of this scope, relative to it's parent 47 subquery_scopes (list[Scope]): List of all child scopes for subqueries 48 cte_scopes (list[Scope]): List of all child scopes for CTEs 49 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 50 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 51 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 52 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 53 a list of the left and right child scopes. 54 """ 55 56 def __init__( 57 self, 58 expression, 59 sources=None, 60 outer_column_list=None, 61 parent=None, 62 scope_type=ScopeType.ROOT, 63 lateral_sources=None, 64 ): 65 self.expression = expression 66 self.sources = sources or {} 67 self.lateral_sources = lateral_sources.copy() if lateral_sources else {} 68 self.sources.update(self.lateral_sources) 69 self.outer_column_list = outer_column_list or [] 70 self.parent = parent 71 self.scope_type = scope_type 72 self.subquery_scopes = [] 73 self.derived_table_scopes = [] 74 self.table_scopes = [] 75 self.cte_scopes = [] 76 self.union_scopes = [] 77 self.udtf_scopes = [] 78 self.clear_cache() 79 80 def clear_cache(self): 81 self._collected = False 82 self._raw_columns = None 83 self._derived_tables = None 84 self._udtfs = None 85 self._tables = None 86 self._ctes = None 87 self._subqueries = None 88 self._selected_sources = None 89 self._columns = None 90 self._external_columns = None 91 self._join_hints = None 92 self._pivots = None 93 self._references = None 94 95 def branch(self, expression, scope_type, chain_sources=None, **kwargs): 96 """Branch from the current scope to a new, inner scope""" 97 return Scope( 98 expression=expression.unnest(), 99 sources={**self.cte_sources, 
**(chain_sources or {})}, 100 parent=self, 101 scope_type=scope_type, 102 **kwargs, 103 ) 104 105 def _collect(self): 106 self._tables = [] 107 self._ctes = [] 108 self._subqueries = [] 109 self._derived_tables = [] 110 self._udtfs = [] 111 self._raw_columns = [] 112 self._join_hints = [] 113 114 for node, parent, _ in self.walk(bfs=False): 115 if node is self.expression: 116 continue 117 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 118 self._raw_columns.append(node) 119 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 120 self._tables.append(node) 121 elif isinstance(node, exp.JoinHint): 122 self._join_hints.append(node) 123 elif isinstance(node, exp.UDTF): 124 self._udtfs.append(node) 125 elif isinstance(node, exp.CTE): 126 self._ctes.append(node) 127 elif ( 128 isinstance(node, exp.Subquery) 129 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 130 and _is_derived_table(node) 131 ): 132 self._derived_tables.append(node) 133 elif isinstance(node, exp.Subqueryable): 134 self._subqueries.append(node) 135 136 self._collected = True 137 138 def _ensure_collected(self): 139 if not self._collected: 140 self._collect() 141 142 def walk(self, bfs=True, prune=None): 143 return walk_in_scope(self.expression, bfs=bfs, prune=None) 144 145 def find(self, *expression_types, bfs=True): 146 return find_in_scope(self.expression, expression_types, bfs=bfs) 147 148 def find_all(self, *expression_types, bfs=True): 149 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 150 151 def replace(self, old, new): 152 """ 153 Replace `old` with `new`. 154 155 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 156 157 Args: 158 old (exp.Expression): old node 159 new (exp.Expression): new node 160 """ 161 old.replace(new) 162 self.clear_cache() 163 164 @property 165 def tables(self): 166 """ 167 List of tables in this scope. 
168 169 Returns: 170 list[exp.Table]: tables 171 """ 172 self._ensure_collected() 173 return self._tables 174 175 @property 176 def ctes(self): 177 """ 178 List of CTEs in this scope. 179 180 Returns: 181 list[exp.CTE]: ctes 182 """ 183 self._ensure_collected() 184 return self._ctes 185 186 @property 187 def derived_tables(self): 188 """ 189 List of derived tables in this scope. 190 191 For example: 192 SELECT * FROM (SELECT ...) <- that's a derived table 193 194 Returns: 195 list[exp.Subquery]: derived tables 196 """ 197 self._ensure_collected() 198 return self._derived_tables 199 200 @property 201 def udtfs(self): 202 """ 203 List of "User Defined Tabular Functions" in this scope. 204 205 Returns: 206 list[exp.UDTF]: UDTFs 207 """ 208 self._ensure_collected() 209 return self._udtfs 210 211 @property 212 def subqueries(self): 213 """ 214 List of subqueries in this scope. 215 216 For example: 217 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 218 219 Returns: 220 list[exp.Subqueryable]: subqueries 221 """ 222 self._ensure_collected() 223 return self._subqueries 224 225 @property 226 def columns(self): 227 """ 228 List of columns in this scope. 229 230 Returns: 231 list[exp.Column]: Column instances in this scope, plus any 232 Columns that reference this scope from correlated subqueries. 
233 """ 234 if self._columns is None: 235 self._ensure_collected() 236 columns = self._raw_columns 237 238 external_columns = [ 239 column 240 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 241 for column in scope.external_columns 242 ] 243 244 named_selects = set(self.expression.named_selects) 245 246 self._columns = [] 247 for column in columns + external_columns: 248 ancestor = column.find_ancestor( 249 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 250 ) 251 if ( 252 not ancestor 253 or column.table 254 or isinstance(ancestor, exp.Select) 255 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 256 or ( 257 isinstance(ancestor, exp.Order) 258 and ( 259 isinstance(ancestor.parent, exp.Window) 260 or column.name not in named_selects 261 ) 262 ) 263 ): 264 self._columns.append(column) 265 266 return self._columns 267 268 @property 269 def selected_sources(self): 270 """ 271 Mapping of nodes and sources that are actually selected from in this scope. 272 273 That is, all tables in a schema are selectable at any point. But a 274 table only becomes a selected source if it's included in a FROM or JOIN clause. 
275 276 Returns: 277 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 278 """ 279 if self._selected_sources is None: 280 result = {} 281 282 for name, node in self.references: 283 if name in result: 284 raise OptimizeError(f"Alias already used: {name}") 285 if name in self.sources: 286 result[name] = (node, self.sources[name]) 287 288 self._selected_sources = result 289 return self._selected_sources 290 291 @property 292 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 293 if self._references is None: 294 self._references = [] 295 296 for table in self.tables: 297 self._references.append((table.alias_or_name, table)) 298 for expression in itertools.chain(self.derived_tables, self.udtfs): 299 self._references.append( 300 ( 301 expression.alias, 302 expression if expression.args.get("pivots") else expression.unnest(), 303 ) 304 ) 305 306 return self._references 307 308 @property 309 def cte_sources(self): 310 """ 311 Sources that are CTEs. 312 313 Returns: 314 dict[str, Scope]: Mapping of source alias to Scope 315 """ 316 return { 317 alias: scope 318 for alias, scope in self.sources.items() 319 if isinstance(scope, Scope) and scope.is_cte 320 } 321 322 @property 323 def external_columns(self): 324 """ 325 Columns that appear to reference sources in outer scopes. 326 327 Returns: 328 list[exp.Column]: Column instances that don't reference 329 sources in the current scope. 330 """ 331 if self._external_columns is None: 332 self._external_columns = [ 333 c for c in self.columns if c.table not in self.selected_sources 334 ] 335 return self._external_columns 336 337 @property 338 def unqualified_columns(self): 339 """ 340 Unqualified columns in the current scope. 
341 342 Returns: 343 list[exp.Column]: Unqualified columns 344 """ 345 return [c for c in self.columns if not c.table] 346 347 @property 348 def join_hints(self): 349 """ 350 Hints that exist in the scope that reference tables 351 352 Returns: 353 list[exp.JoinHint]: Join hints that are referenced within the scope 354 """ 355 if self._join_hints is None: 356 return [] 357 return self._join_hints 358 359 @property 360 def pivots(self): 361 if not self._pivots: 362 self._pivots = [ 363 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 364 ] 365 366 return self._pivots 367 368 def source_columns(self, source_name): 369 """ 370 Get all columns in the current scope for a particular source. 371 372 Args: 373 source_name (str): Name of the source 374 Returns: 375 list[exp.Column]: Column instances that reference `source_name` 376 """ 377 return [column for column in self.columns if column.table == source_name] 378 379 @property 380 def is_subquery(self): 381 """Determine if this scope is a subquery""" 382 return self.scope_type == ScopeType.SUBQUERY 383 384 @property 385 def is_derived_table(self): 386 """Determine if this scope is a derived table""" 387 return self.scope_type == ScopeType.DERIVED_TABLE 388 389 @property 390 def is_union(self): 391 """Determine if this scope is a union""" 392 return self.scope_type == ScopeType.UNION 393 394 @property 395 def is_cte(self): 396 """Determine if this scope is a common table expression""" 397 return self.scope_type == ScopeType.CTE 398 399 @property 400 def is_root(self): 401 """Determine if this is the root scope""" 402 return self.scope_type == ScopeType.ROOT 403 404 @property 405 def is_udtf(self): 406 """Determine if this scope is a UDTF (User Defined Table Function)""" 407 return self.scope_type == ScopeType.UDTF 408 409 @property 410 def is_correlated_subquery(self): 411 """Determine if this scope is a correlated subquery""" 412 return bool( 413 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 414 and self.external_columns 415 ) 416 417 def rename_source(self, old_name, new_name): 418 """Rename a source in this scope""" 419 columns = self.sources.pop(old_name or "", []) 420 self.sources[new_name] = columns 421 422 def add_source(self, name, source): 423 """Add a source to this scope""" 424 self.sources[name] = source 425 self.clear_cache() 426 427 def remove_source(self, name): 428 """Remove a source from this scope""" 429 self.sources.pop(name, None) 430 self.clear_cache() 431 432 def __repr__(self): 433 return f"Scope<{self.expression.sql()}>" 434 435 def traverse(self): 436 """ 437 Traverse the scope tree from this node. 438 439 Yields: 440 Scope: scope instances in depth-first-search post-order 441 """ 442 for child_scope in itertools.chain( 443 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 444 ): 445 yield from child_scope.traverse() 446 yield self 447 448 def ref_count(self): 449 """ 450 Count the number of times each scope in this tree is referenced. 451 452 Returns: 453 dict[int, int]: Mapping of Scope instance ID to reference count 454 """ 455 scope_ref_count = defaultdict(lambda: 0) 456 457 for scope in self.traverse(): 458 for _, source in scope.selected_sources.values(): 459 scope_ref_count[id(source)] += 1 460 461 return scope_ref_count 462 463 464def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 465 """ 466 Traverse an expression by its "scopes". 467 468 "Scope" represents the current context of a Select statement. 469 470 This is helpful for optimizing queries, where we need more information than 471 the expression tree itself. For example, we might care about the source 472 names within a subquery. Returns a list because a generator could result in 473 incomplete properties which is confusing. 
474 475 Examples: 476 >>> import sqlglot 477 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 478 >>> scopes = traverse_scope(expression) 479 >>> scopes[0].expression.sql(), list(scopes[0].sources) 480 ('SELECT a FROM x', ['x']) 481 >>> scopes[1].expression.sql(), list(scopes[1].sources) 482 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 483 484 Args: 485 expression (exp.Expression): expression to traverse 486 Returns: 487 list[Scope]: scope instances 488 """ 489 if isinstance(expression, exp.Unionable) or ( 490 isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) 491 ): 492 return list(_traverse_scope(Scope(expression))) 493 494 return [] 495 496 497def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 498 """ 499 Build a scope tree. 500 501 Args: 502 expression (exp.Expression): expression to build the scope tree for 503 Returns: 504 Scope: root scope 505 """ 506 scopes = traverse_scope(expression) 507 if scopes: 508 return scopes[-1] 509 return None 510 511 512def _traverse_scope(scope): 513 if isinstance(scope.expression, exp.Select): 514 yield from _traverse_select(scope) 515 elif isinstance(scope.expression, exp.Union): 516 yield from _traverse_union(scope) 517 elif isinstance(scope.expression, exp.Subquery): 518 if scope.is_root: 519 yield from _traverse_select(scope) 520 else: 521 yield from _traverse_subqueries(scope) 522 elif isinstance(scope.expression, exp.Table): 523 yield from _traverse_tables(scope) 524 elif isinstance(scope.expression, exp.UDTF): 525 yield from _traverse_udtfs(scope) 526 elif isinstance(scope.expression, exp.DDL): 527 yield from _traverse_ddl(scope) 528 else: 529 logger.warning( 530 "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) 531 ) 532 return 533 534 yield scope 535 536 537def _traverse_select(scope): 538 yield from _traverse_ctes(scope) 539 yield from _traverse_tables(scope) 540 yield from _traverse_subqueries(scope) 
541 542 543def _traverse_union(scope): 544 yield from _traverse_ctes(scope) 545 546 # The last scope to be yield should be the top most scope 547 left = None 548 for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): 549 yield left 550 551 right = None 552 for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): 553 yield right 554 555 scope.union_scopes = [left, right] 556 557 558def _traverse_ctes(scope): 559 sources = {} 560 561 for cte in scope.ctes: 562 recursive_scope = None 563 564 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 565 # thus the recursive scope is the first section of the union. 566 with_ = scope.expression.args.get("with") 567 if with_ and with_.recursive: 568 union = cte.this 569 570 if isinstance(union, exp.Union): 571 recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) 572 573 child_scope = None 574 575 for child_scope in _traverse_scope( 576 scope.branch( 577 cte.this, 578 chain_sources=sources, 579 outer_column_list=cte.alias_column_names, 580 scope_type=ScopeType.CTE, 581 ) 582 ): 583 yield child_scope 584 585 alias = cte.alias 586 sources[alias] = child_scope 587 588 if recursive_scope: 589 child_scope.add_source(alias, recursive_scope) 590 591 # append the final child_scope yielded 592 if child_scope: 593 scope.cte_scopes.append(child_scope) 594 595 scope.sources.update(sources) 596 597 598def _is_derived_table(expression: exp.Subquery) -> bool: 599 """ 600 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 601 as it doesn't introduce a new scope. If an alias is present, it shadows all names 602 under the Subquery, so that's one exception to this rule. 
603 """ 604 return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) 605 606 607def _traverse_tables(scope): 608 sources = {} 609 610 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 611 expressions = [] 612 from_ = scope.expression.args.get("from") 613 if from_: 614 expressions.append(from_.this) 615 616 for join in scope.expression.args.get("joins") or []: 617 expressions.append(join.this) 618 619 if isinstance(scope.expression, exp.Table): 620 expressions.append(scope.expression) 621 622 expressions.extend(scope.expression.args.get("laterals") or []) 623 624 for expression in expressions: 625 if isinstance(expression, exp.Table): 626 table_name = expression.name 627 source_name = expression.alias_or_name 628 629 if table_name in scope.sources and not expression.db: 630 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 631 # it is pivoted, because then we get back a new table and hence a new source. 632 pivots = expression.args.get("pivots") 633 if pivots: 634 sources[pivots[0].alias] = expression 635 else: 636 sources[source_name] = scope.sources[table_name] 637 elif source_name in sources: 638 sources[find_new_name(sources, table_name)] = expression 639 else: 640 sources[source_name] = expression 641 642 # Make sure to not include the joins twice 643 if expression is not scope.expression: 644 expressions.extend(join.this for join in expression.args.get("joins") or []) 645 646 continue 647 648 if not isinstance(expression, exp.DerivedTable): 649 continue 650 651 if isinstance(expression, exp.UDTF): 652 lateral_sources = sources 653 scope_type = ScopeType.UDTF 654 scopes = scope.udtf_scopes 655 elif _is_derived_table(expression): 656 lateral_sources = None 657 scope_type = ScopeType.DERIVED_TABLE 658 scopes = scope.derived_table_scopes 659 expressions.extend(join.this for join in expression.args.get("joins") or []) 660 else: 661 # Makes sure we check for possible sources in nested table 
constructs 662 expressions.append(expression.this) 663 expressions.extend(join.this for join in expression.args.get("joins") or []) 664 continue 665 666 for child_scope in _traverse_scope( 667 scope.branch( 668 expression, 669 lateral_sources=lateral_sources, 670 outer_column_list=expression.alias_column_names, 671 scope_type=scope_type, 672 ) 673 ): 674 yield child_scope 675 676 # Tables without aliases will be set as "" 677 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 678 # Until then, this means that only a single, unaliased derived table is allowed (rather, 679 # the latest one wins. 680 sources[expression.alias] = child_scope 681 682 # append the final child_scope yielded 683 scopes.append(child_scope) 684 scope.table_scopes.append(child_scope) 685 686 scope.sources.update(sources) 687 688 689def _traverse_subqueries(scope): 690 for subquery in scope.subqueries: 691 top = None 692 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 693 yield child_scope 694 top = child_scope 695 scope.subquery_scopes.append(top) 696 697 698def _traverse_udtfs(scope): 699 if isinstance(scope.expression, exp.Unnest): 700 expressions = scope.expression.expressions 701 elif isinstance(scope.expression, exp.Lateral): 702 expressions = [scope.expression.this] 703 else: 704 expressions = [] 705 706 sources = {} 707 for expression in expressions: 708 if isinstance(expression, exp.Subquery) and _is_derived_table(expression): 709 top = None 710 for child_scope in _traverse_scope( 711 scope.branch( 712 expression, 713 scope_type=ScopeType.DERIVED_TABLE, 714 outer_column_list=expression.alias_column_names, 715 ) 716 ): 717 yield child_scope 718 top = child_scope 719 sources[expression.alias] = child_scope 720 721 scope.derived_table_scopes.append(top) 722 scope.table_scopes.append(top) 723 724 scope.sources.update(sources) 725 726 727def _traverse_ddl(scope): 728 yield from _traverse_ctes(scope) 729 730 
query_scope = scope.branch( 731 scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources 732 ) 733 query_scope._collect() 734 query_scope._ctes = scope.ctes + query_scope._ctes 735 736 yield from _traverse_scope(query_scope) 737 738 739def walk_in_scope(expression, bfs=True, prune=None): 740 """ 741 Returns a generator object which visits all nodes in the syntrax tree, stopping at 742 nodes that start child scopes. 743 744 Args: 745 expression (exp.Expression): 746 bfs (bool): if set to True the BFS traversal order will be applied, 747 otherwise the DFS traversal will be used instead. 748 prune ((node, parent, arg_key) -> bool): callable that returns True if 749 the generator should stop traversing this branch of the tree. 750 751 Yields: 752 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 753 """ 754 # We'll use this variable to pass state into the dfs generator. 755 # Whenever we set it to True, we exclude a subtree from traversal. 
756 crossed_scope_boundary = False 757 758 for node, parent, key in expression.walk( 759 bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) 760 ): 761 crossed_scope_boundary = False 762 763 yield node, parent, key 764 765 if node is expression: 766 continue 767 if ( 768 isinstance(node, exp.CTE) 769 or ( 770 isinstance(node, exp.Subquery) 771 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 772 and _is_derived_table(node) 773 ) 774 or isinstance(node, exp.UDTF) 775 or isinstance(node, exp.Subqueryable) 776 ): 777 crossed_scope_boundary = True 778 779 if isinstance(node, (exp.Subquery, exp.UDTF)): 780 # The following args are not actually in the inner scope, so we should visit them 781 for key in ("joins", "laterals", "pivots"): 782 for arg in node.args.get(key) or []: 783 yield from walk_in_scope(arg, bfs=bfs) 784 785 786def find_all_in_scope(expression, expression_types, bfs=True): 787 """ 788 Returns a generator object which visits all nodes in this scope and only yields those that 789 match at least one of the specified expression types. 790 791 This does NOT traverse into subscopes. 792 793 Args: 794 expression (exp.Expression): 795 expression_types (tuple[type]|type): the expression type(s) to match. 796 bfs (bool): True to use breadth-first search, False to use depth-first. 797 798 Yields: 799 exp.Expression: nodes 800 """ 801 for expression, *_ in walk_in_scope(expression, bfs=bfs): 802 if isinstance(expression, tuple(ensure_collection(expression_types))): 803 yield expression 804 805 806def find_in_scope(expression, expression_types, bfs=True): 807 """ 808 Returns the first node in this scope which matches at least one of the specified types. 809 810 This does NOT traverse into subscopes. 811 812 Args: 813 expression (exp.Expression): 814 expression_types (tuple[type]|type): the expression type(s) to match. 815 bfs (bool): True to use breadth-first search, False to use depth-first. 
816 817 Returns: 818 exp.Expression: the node which matches the criteria or None if no node matching 819 the criteria was found. 820 """ 821 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a `Scope`, relative to its parent scope."""

    # Default type of a freshly constructed Scope (the outermost query).
    ROOT = auto()
    # Assigned to scopes branched for nested subqueries (see `_traverse_subqueries`).
    SUBQUERY = auto()
    # Assigned to scopes branched for derived tables, e.g. SELECT * FROM (SELECT ...).
    DERIVED_TABLE = auto()
    # Assigned to scopes branched for common table expressions (see `_traverse_ctes`).
    CTE = auto()
    # Assigned to the left/right scopes of a UNION (see `_traverse_union`).
    UNION = auto()
    # Assigned to scopes branched for `exp.UDTF` nodes (see `_traverse_tables`).
    UDTF = auto()
An enumeration.
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copy so that later mutations of the caller's mapping don't leak into this scope.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope (depth-first) and bucket the visited nodes by kind.

        Populates the raw column, table, join hint, UDTF, CTE, derived table and
        subquery lists, then marks the scope as collected.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` if it has not been run since the last cache reset."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope's expression via `walk_in_scope`.

        Args:
            bfs (bool): breadth-first if True, depth-first otherwise.
            prune ((node, parent, arg_key) -> bool): optional callable that returns True
                if the walk should stop traversing that branch.

        Returns:
            generator of (node, parent, arg key) tuples
        """
        # Forward the caller's callback; it used to be replaced by a hard-coded `None`.
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
                )
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is referenced more than once.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        List of (name, node) pairs for the tables, derived tables and UDTFs in this scope.

        Tables are keyed by `alias_or_name`; derived tables and UDTFs by `alias`.
        Derived tables and UDTFs are unnested unless they carry a "pivots" arg.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the sibling properties; otherwise the hints would be
        # reported as missing whenever this is the first property read after a cache reset.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        List of the "pivots" args attached to any node in `references`.
        """
        # `is None` (rather than falsiness) so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(lambda: 0)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Initialise a scope rooted at `expression`.

    Args:
        expression: Root expression of this scope.
        sources (dict): Mapping of source name to source. A falsy value is replaced
            by a new empty dict; a non-empty dict is stored as-is and is updated
            in place with `lateral_sources`.
        outer_column_list (list): Column list stored on the scope; a falsy value
            becomes an empty list.
        parent: Parent scope, if any.
        scope_type (ScopeType): Type of this scope.
        lateral_sources (dict): Lateral sources. A copy is stored on the scope and
            its entries are merged into `sources`.
    """
    # Identity of the scope and its position in the scope tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type

    # Source bookkeeping: lateral sources are copied, then folded into `sources`.
    self.sources = sources if sources else {}
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    self.sources.update(self.lateral_sources)
    self.outer_column_list = outer_column_list if outer_column_list else []

    # Child scope lists all start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    # Put every cached attribute into its initial "not computed" state.
    self.clear_cache()
def clear_cache(self):
    """
    Reset all cached state on this scope.

    `_collected` is set back to False and every cached attribute is set to None.
    """
    self._collected = False

    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Create a new inner scope whose parent is this scope.

    Args:
        expression: Expression for the inner scope; the result of
            `expression.unnest()` becomes the new scope's expression.
        scope_type (ScopeType): Type of the new scope.
        chain_sources (dict): Extra sources for the new scope. On a name clash
            these take precedence over this scope's CTE sources.
        **kwargs: Forwarded unchanged to the `Scope` constructor.

    Returns:
        Scope: the new inner scope
    """
    inner_expression = expression.unnest()

    # The inner scope sees this scope's CTE sources plus any chained sources.
    extra_sources = chain_sources or {}
    inner_sources = {**self.cte_sources, **extra_sources}

    return Scope(
        expression=inner_expression,
        sources=inner_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` and drop this scope's cached state.

    Use this rather than calling `exp.Expression.replace` directly so that the
    `Scope` does not keep serving values cached before the change.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)

    # The tree has changed, so anything cached on the scope may be stale.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope that belong to one source.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: the entries of `self.columns` whose `table` equals
            `source_name`, in their original order
    """
    return list(filter(lambda column: column.table == source_name, self.columns))
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name (str|None): Current name of the source. A falsy value is
            treated as the empty-string name.
        new_name (str): Name to register the source under.

    If `old_name` is not a known source, nothing happens. Previously an empty
    list was registered under `new_name` in that case, which created a source
    that never existed.
    """
    old_name = old_name or ""

    # Only move an entry that actually exists; never invent a placeholder source.
    if old_name in self.sources:
        self.sources[new_name] = self.sources.pop(old_name)
    # NOTE(review): unlike add_source/remove_source, this does not call
    # clear_cache(); kept as-is because callers may rely on it. Confirm.
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name` in this scope.

    An existing entry with the same name is overwritten. Cached state is
    cleared afterwards.
    """
    self.sources[name] = source

    # Cached values were computed from the previous set of sources.
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Drop the source registered under `name`, if there is one.

    Removing an unknown name is not an error. Cached state is cleared either way.
    """
    if name in self.sources:
        del self.sources[name]

    # Cleared unconditionally, including when nothing was removed.
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group (CTE scopes, then union scopes, then
    table scopes, then subquery scopes), each one fully traversed before this
    scope itself is produced.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.table_scopes,
        self.subquery_scopes,
    )

    for group in child_groups:
        for child_scope in group:
            yield from child_scope.traverse()

    # Post-order: the scope itself comes after all of its descendants.
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source is selected from across this scope tree.

    Every scope produced by `traverse()` contributes one count per entry in its
    `selected_sources`, keyed by the `id()` of the source object.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        # Each value is a (node, source) pair; only the source is counted.
        for selected in scope.selected_sources.values():
            _, source = selected
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    Scopes carry information the expression tree alone does not, such as the
    source names visible inside a subquery, which is useful when optimizing
    queries. The result is a fully built list rather than a generator, because a
    partially consumed generator could expose scopes with incomplete properties.

    Only `exp.Unionable` expressions, or `exp.DDL` expressions whose `expression`
    is an `exp.Subqueryable`, are traversed; anything else yields an empty list.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        # Not a query itself: accept it only if it is DDL wrapping a query.
        wraps_query = isinstance(expression, exp.DDL) and isinstance(
            expression.expression, exp.Subqueryable
        )
        if not wraps_query:
            return []

    return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, taken as the last scope returned by `traverse_scope`;
            None when `traverse_scope` returns no scopes.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): root of the walk. It is always yielded and is
            never treated as a scope boundary itself.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the prune callback given to `expression.walk`. While it is
    # True the callback reports "prune", which excludes a subtree from traversal.
    crossed_scope_boundary = False

    def _should_prune(*args):
        return crossed_scope_boundary or (prune and prune(*args))

    for node, parent, key in expression.walk(bfs=bfs, prune=_should_prune):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue

        starts_child_scope = isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
            isinstance(node, exp.Subquery)
            and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
            and _is_derived_table(node)
        )
        if not starts_child_scope:
            continue

        crossed_scope_boundary = True

        if isinstance(node, (exp.Subquery, exp.UDTF)):
            # These args are not actually in the inner scope, so they are still
            # visited. The caller's `prune` is not forwarded to the nested walk.
            for arg_name in ("joins", "laterals", "pivots"):
                for arg in node.args.get(arg_name) or []:
                    yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the expression whose scope is walked.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): expression whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        # A single type and a collection of types are both accepted.
        wanted_types = tuple(ensure_collection(expression_types))
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the expression whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): expression whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    return next(matches, None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the expression whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.