sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible as regular sources of this scope.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed property so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # Inner scopes inherit all CTEs visible in this scope, plus any new ones.
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (not entering child scopes) and bucket nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(
                node.parent, (exp.From, exp.Join, exp.Subquery)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope's expression without entering child scopes.

        Args:
            bfs (bool): True to use breadth-first search, False to use depth-first.
            prune ((node) -> bool): optional callable; when it returns True for a node,
                traversal does not descend below that node.

        Returns:
            Generator[exp.Expression]: the nodes in this scope.
        """
        # `prune` used to be accepted here but never forwarded.
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.Union]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # Unqualified columns under QUALIFY/ORDER/HAVING etc. may refer to a
                # projection alias rather than a source column; those are excluded.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same name.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """(name, node) pairs for every table, derived table and UDTF in this scope."""
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        # Keep the wrapper when pivoted so its "pivots" arg stays reachable.
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.Union):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Join hints are gathered by `_collect`. Without this call the property returned
        # an empty list whenever it was read before any other collected property.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """All pivots attached to the nodes referenced in this scope."""
        # Test for None (not falsiness) so an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns
        # Cached properties (e.g. `selected_sources`) are keyed by source name, so they
        # must be rebuilt after a rename, just like in add_source/remove_source.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Reversing the pre-order visit puts every child before its parent.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # We ignore the DDL expression and build a scope for its query instead
        ddl_with = expression.args.get("with")
        expression = expression.expression

        # If the DDL has CTEs attached, we need to add them to the query, or
        # prepend them if the query itself already has CTEs attached to it
        if ddl_with:
            ddl_with.pop()
            query_ctes = expression.ctes
            if not query_ctes:
                expression.set("with", ddl_with)
            else:
                expression.args["with"].set("recursive", ddl_with.recursive)
                expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])

    if isinstance(expression, exp.Query):
        return list(_traverse_scope(Scope(expression)))

    return []


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope
    """
    # traverse_scope yields children before parents, so the root is last.
    return seq_get(traverse_scope(expression), -1)


def _traverse_scope(scope):
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_ctes(scope)
        # _traverse_union yields `scope` itself once both sides are paired.
        yield from _traverse_union(scope)
        return
    elif isinstance(scope.expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    # Iterative walk over nested set operations. Leaves are traversed as full scopes;
    # each completed pair (prev_scope, scope) is attached to the enclosing union scope.
    # NOTE(review): the prev_scope pairing assumes nested bare Union nodes only appear
    # on the left operand - confirm before relying on right-nested unions.
    prev_scope = None
    union_scope_stack = [scope]
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.Union):
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope


def _traverse_ctes(scope):
    sources = {}

    # Loop-invariant: the WITH clause of this scope does not change while iterating.
    with_ = scope.expression.args.get("with")

    for cte in scope.ctes:
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        # append the final child_scope yielded
        if child_scope:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return isinstance(expression, exp.Subquery) and bool(
        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    )


def _traverse_tables(scope):
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` is intentionally extended while iterating: nested joins and table
    # constructs discovered below are appended and visited by this same loop.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        # The last yielded scope is the subquery's own (outermost) scope.
        scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.DERIVED_TABLE,
                    outer_columns=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            scope.derived_table_scopes.append(top)
            scope.table_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: nodes
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    # NOTE(review): this relies on Expression.walk evaluating prune(node) only after
    # resuming from yielding `node` - confirm against Expression.walk.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Forward the caller's prune so it also applies to these branches.
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Build the isinstance() type tuple once rather than once per visited node.
    types = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope, relative to the scope it was branched from.

    Values are spelled out explicitly; they match what ``auto()`` assigns
    (consecutive integers starting at 1, in declaration order).
    """

    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
An enumeration.
Inherited Members
- enum.Enum
- name
- value
26class Scope: 27 """ 28 Selection scope. 29 30 Attributes: 31 expression (exp.Select|exp.Union): Root expression of this scope 32 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 33 a Table expression or another Scope instance. For example: 34 SELECT * FROM x {"x": Table(this="x")} 35 SELECT * FROM x AS y {"y": Table(this="x")} 36 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 37 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 38 For example: 39 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 40 The LATERAL VIEW EXPLODE gets x as a source. 41 cte_sources (dict[str, Scope]): Sources from CTES 42 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 43 defines a column list for the alias of this scope, this is that list of columns. 44 For example: 45 SELECT * FROM (SELECT ...) AS y(col1, col2) 46 The inner query would have `["col1", "col2"]` for its `outer_columns` 47 parent (Scope): Parent scope 48 scope_type (ScopeType): Type of this scope, relative to it's parent 49 subquery_scopes (list[Scope]): List of all child scopes for subqueries 50 cte_scopes (list[Scope]): List of all child scopes for CTEs 51 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 52 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 53 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 54 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 55 a list of the left and right child scopes. 
56 """ 57 58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_columns=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_columns = outer_columns or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache() 84 85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None 99 100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 ) 113 114 def _collect(self): 115 self._tables = [] 116 self._ctes = [] 117 self._subqueries = [] 118 self._derived_tables = [] 119 self._udtfs = [] 120 self._raw_columns = [] 121 self._join_hints = [] 122 123 for node in self.walk(bfs=False): 124 if node is self.expression: 125 continue 126 127 if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 128 self._raw_columns.append(node) 129 elif 
isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 130 self._tables.append(node) 131 elif isinstance(node, exp.JoinHint): 132 self._join_hints.append(node) 133 elif isinstance(node, exp.UDTF): 134 self._udtfs.append(node) 135 elif isinstance(node, exp.CTE): 136 self._ctes.append(node) 137 elif _is_derived_table(node) and isinstance( 138 node.parent, (exp.From, exp.Join, exp.Subquery) 139 ): 140 self._derived_tables.append(node) 141 elif isinstance(node, exp.UNWRAPPED_QUERIES): 142 self._subqueries.append(node) 143 144 self._collected = True 145 146 def _ensure_collected(self): 147 if not self._collected: 148 self._collect() 149 150 def walk(self, bfs=True, prune=None): 151 return walk_in_scope(self.expression, bfs=bfs, prune=None) 152 153 def find(self, *expression_types, bfs=True): 154 return find_in_scope(self.expression, expression_types, bfs=bfs) 155 156 def find_all(self, *expression_types, bfs=True): 157 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 158 159 def replace(self, old, new): 160 """ 161 Replace `old` with `new`. 162 163 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 164 165 Args: 166 old (exp.Expression): old node 167 new (exp.Expression): new node 168 """ 169 old.replace(new) 170 self.clear_cache() 171 172 @property 173 def tables(self): 174 """ 175 List of tables in this scope. 176 177 Returns: 178 list[exp.Table]: tables 179 """ 180 self._ensure_collected() 181 return self._tables 182 183 @property 184 def ctes(self): 185 """ 186 List of CTEs in this scope. 187 188 Returns: 189 list[exp.CTE]: ctes 190 """ 191 self._ensure_collected() 192 return self._ctes 193 194 @property 195 def derived_tables(self): 196 """ 197 List of derived tables in this scope. 198 199 For example: 200 SELECT * FROM (SELECT ...) 
<- that's a derived table 201 202 Returns: 203 list[exp.Subquery]: derived tables 204 """ 205 self._ensure_collected() 206 return self._derived_tables 207 208 @property 209 def udtfs(self): 210 """ 211 List of "User Defined Tabular Functions" in this scope. 212 213 Returns: 214 list[exp.UDTF]: UDTFs 215 """ 216 self._ensure_collected() 217 return self._udtfs 218 219 @property 220 def subqueries(self): 221 """ 222 List of subqueries in this scope. 223 224 For example: 225 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 226 227 Returns: 228 list[exp.Select | exp.Union]: subqueries 229 """ 230 self._ensure_collected() 231 return self._subqueries 232 233 @property 234 def columns(self): 235 """ 236 List of columns in this scope. 237 238 Returns: 239 list[exp.Column]: Column instances in this scope, plus any 240 Columns that reference this scope from correlated subqueries. 241 """ 242 if self._columns is None: 243 self._ensure_collected() 244 columns = self._raw_columns 245 246 external_columns = [ 247 column 248 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 249 for column in scope.external_columns 250 ] 251 252 named_selects = set(self.expression.named_selects) 253 254 self._columns = [] 255 for column in columns + external_columns: 256 ancestor = column.find_ancestor( 257 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 258 ) 259 if ( 260 not ancestor 261 or column.table 262 or isinstance(ancestor, exp.Select) 263 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 264 or ( 265 isinstance(ancestor, exp.Order) 266 and ( 267 isinstance(ancestor.parent, exp.Window) 268 or column.name not in named_selects 269 ) 270 ) 271 ): 272 self._columns.append(column) 273 274 return self._columns 275 276 @property 277 def selected_sources(self): 278 """ 279 Mapping of nodes and sources that are actually selected from in this scope. 
280 281 That is, all tables in a schema are selectable at any point. But a 282 table only becomes a selected source if it's included in a FROM or JOIN clause. 283 284 Returns: 285 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 286 """ 287 if self._selected_sources is None: 288 result = {} 289 290 for name, node in self.references: 291 if name in result: 292 raise OptimizeError(f"Alias already used: {name}") 293 if name in self.sources: 294 result[name] = (node, self.sources[name]) 295 296 self._selected_sources = result 297 return self._selected_sources 298 299 @property 300 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 301 if self._references is None: 302 self._references = [] 303 304 for table in self.tables: 305 self._references.append((table.alias_or_name, table)) 306 for expression in itertools.chain(self.derived_tables, self.udtfs): 307 self._references.append( 308 ( 309 expression.alias, 310 expression if expression.args.get("pivots") else expression.unnest(), 311 ) 312 ) 313 314 return self._references 315 316 @property 317 def external_columns(self): 318 """ 319 Columns that appear to reference sources in outer scopes. 320 321 Returns: 322 list[exp.Column]: Column instances that don't reference 323 sources in the current scope. 324 """ 325 if self._external_columns is None: 326 if isinstance(self.expression, exp.Union): 327 left, right = self.union_scopes 328 self._external_columns = left.external_columns + right.external_columns 329 else: 330 self._external_columns = [ 331 c for c in self.columns if c.table not in self.selected_sources 332 ] 333 334 return self._external_columns 335 336 @property 337 def unqualified_columns(self): 338 """ 339 Unqualified columns in the current scope. 
340 341 Returns: 342 list[exp.Column]: Unqualified columns 343 """ 344 return [c for c in self.columns if not c.table] 345 346 @property 347 def join_hints(self): 348 """ 349 Hints that exist in the scope that reference tables 350 351 Returns: 352 list[exp.JoinHint]: Join hints that are referenced within the scope 353 """ 354 if self._join_hints is None: 355 return [] 356 return self._join_hints 357 358 @property 359 def pivots(self): 360 if not self._pivots: 361 self._pivots = [ 362 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 363 ] 364 365 return self._pivots 366 367 def source_columns(self, source_name): 368 """ 369 Get all columns in the current scope for a particular source. 370 371 Args: 372 source_name (str): Name of the source 373 Returns: 374 list[exp.Column]: Column instances that reference `source_name` 375 """ 376 return [column for column in self.columns if column.table == source_name] 377 378 @property 379 def is_subquery(self): 380 """Determine if this scope is a subquery""" 381 return self.scope_type == ScopeType.SUBQUERY 382 383 @property 384 def is_derived_table(self): 385 """Determine if this scope is a derived table""" 386 return self.scope_type == ScopeType.DERIVED_TABLE 387 388 @property 389 def is_union(self): 390 """Determine if this scope is a union""" 391 return self.scope_type == ScopeType.UNION 392 393 @property 394 def is_cte(self): 395 """Determine if this scope is a common table expression""" 396 return self.scope_type == ScopeType.CTE 397 398 @property 399 def is_root(self): 400 """Determine if this is the root scope""" 401 return self.scope_type == ScopeType.ROOT 402 403 @property 404 def is_udtf(self): 405 """Determine if this scope is a UDTF (User Defined Table Function)""" 406 return self.scope_type == ScopeType.UDTF 407 408 @property 409 def is_correlated_subquery(self): 410 """Determine if this scope is a correlated subquery""" 411 return bool( 412 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 413 and self.external_columns 414 ) 415 416 def rename_source(self, old_name, new_name): 417 """Rename a source in this scope""" 418 columns = self.sources.pop(old_name or "", []) 419 self.sources[new_name] = columns 420 421 def add_source(self, name, source): 422 """Add a source to this scope""" 423 self.sources[name] = source 424 self.clear_cache() 425 426 def remove_source(self, name): 427 """Remove a source from this scope""" 428 self.sources.pop(name, None) 429 self.clear_cache() 430 431 def __repr__(self): 432 return f"Scope<{self.expression.sql()}>" 433 434 def traverse(self): 435 """ 436 Traverse the scope tree from this node. 437 438 Yields: 439 Scope: scope instances in depth-first-search post-order 440 """ 441 stack = [self] 442 result = [] 443 while stack: 444 scope = stack.pop() 445 result.append(scope) 446 stack.extend( 447 itertools.chain( 448 scope.cte_scopes, 449 scope.union_scopes, 450 scope.table_scopes, 451 scope.subquery_scopes, 452 ) 453 ) 454 455 yield from reversed(result) 456 457 def ref_count(self): 458 """ 459 Count the number of times each scope in this tree is referenced. 460 461 Returns: 462 dict[int, int]: Mapping of Scope instance ID to reference count 463 """ 464 scope_ref_count = defaultdict(lambda: 0) 465 466 for scope in self.traverse(): 467 for _, source in scope.selected_sources.values(): 468 scope_ref_count[id(source)] += 1 469 470 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example:
  SELECT * FROM x                    {"x": Table(this="x")}
  SELECT * FROM x AS y               {"y": Table(this="x")}
  SELECT * FROM (SELECT ...) AS y    {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example:
  SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
  The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
):
    """
    Create a scope rooted at `expression`.

    Lateral and CTE sources are kept in their own mappings and are also merged into
    `sources`; CTE sources are merged last, so they win on a name clash.

    Args:
        expression: root expression of this scope
        sources (dict | None): initial name -> source mapping (mutated in place if given)
        outer_columns (list[str] | None): column list from an outer alias, if any
        parent (Scope | None): enclosing scope
        scope_type (ScopeType): type of this scope relative to its parent
        lateral_sources (dict | None): sources coming from laterals
        cte_sources (dict | None): sources coming from CTEs
    """
    self.expression = expression
    self.sources = sources or {}
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)
    self.outer_columns = outer_columns or []
    self.parent = parent
    self.scope_type = scope_type
    # Child-scope lists; all start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []
    self.clear_cache()
def clear_cache(self):
    """Drop every lazily computed attribute so that it is rebuilt on next access."""
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    The child inherits this scope's CTE sources, overlaid with `cte_sources`. The given
    `sources` and `lateral_sources` are shallow-copied so that the child cannot mutate
    the caller's dicts.
    """
    inherited_ctes = dict(self.cte_sources)
    inherited_ctes.update(cte_sources or {})
    return Scope(
        expression=expression.unnest(),
        sources=sources.copy() if sources else None,
        parent=self,
        scope_type=scope_type,
        cte_sources=inherited_ctes,
        lateral_sources=lateral_sources.copy() if lateral_sources else None,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly, so that the `Scope`
    stays in sync with the expression tree.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): replacement node
    """
    swap_node = old.replace
    swap_node(new)
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables found in this scope; collection runs on first access.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    collected_tables = self._tables
    return collected_tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTE nodes found in this scope; collection runs on first access.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    collected_ctes = self._ctes
    return collected_ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables found in this scope; collection runs on first access.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    collected_derived = self._derived_tables
    return collected_derived
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" found in this scope; collection runs on first access.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    collected_udtfs = self._udtfs
    return collected_udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries found in this scope; collection runs on first access.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.Union]: subqueries
    """
    self._ensure_collected()
    collected_subqueries = self._subqueries
    return collected_subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
@property
def columns(self):
    """
    Columns belonging to this scope (computed once and cached).

    Starts from the raw columns collected in this scope plus the external columns of
    child subquery/UDTF scopes (correlated references). It then keeps a column only if it
    is qualified, has no relevant ancestor, sits under a SELECT, sits under a Table whose
    `this` is not a function, or sits under an ORDER BY that either belongs to a window or
    does not refer to one of this scope's named selects.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    raw = self._raw_columns

    correlated = [
        column
        for child in itertools.chain(self.subquery_scopes, self.udtf_scopes)
        for column in child.external_columns
    ]

    select_names = set(self.expression.named_selects)

    def _belongs_here(column):
        # Decide whether a candidate column should be attributed to this scope.
        ancestor = column.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
        )
        if not ancestor or column.table or isinstance(ancestor, exp.Select):
            return True
        if isinstance(ancestor, exp.Table):
            return not isinstance(ancestor.this, exp.Func)
        if isinstance(ancestor, exp.Order):
            # ORDER BY names that match a named select refer to the projection, not a source.
            return isinstance(ancestor.parent, exp.Window) or column.name not in select_names
        return False

    self._columns = [column for column in raw + correlated if _belongs_here(column)]
    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    Every table in a schema is selectable at any point, but a table only becomes a
    selected source once it appears in a FROM or JOIN clause. Only references whose
    alias is a known source are included. The result is cached.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if the same alias is selected twice.
    """
    if self._selected_sources is not None:
        return self._selected_sources

    by_alias = {}
    for alias, node in self.references:
        if alias in by_alias:
            raise OptimizeError(f"Alias already used: {alias}")
        if alias in self.sources:
            by_alias[alias] = (node, self.sources[alias])

    self._selected_sources = by_alias
    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (alias, node) pairs for every table, derived table and UDTF in this scope (cached).

    Tables are keyed by `alias_or_name`. Derived tables and UDTFs are keyed by their alias.
    Their node is the expression itself when it carries pivots, otherwise its unnested form.
    """
    if self._references is None:
        refs = [(table.alias_or_name, table) for table in self.tables]
        for source in itertools.chain(self.derived_tables, self.udtfs):
            alias = source.alias
            node = source if source.args.get("pivots") else source.unnest()
            refs.append((alias, node))
        self._references = refs

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources of an outer scope (cached).

    For a union scope, this is the left branch's external columns followed by the
    right branch's.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if not isinstance(self.expression, exp.Union):
        self._external_columns = [
            column for column in self.columns if column.table not in self.selected_sources
        ]
    else:
        left_scope, right_scope = self.union_scopes
        self._external_columns = left_scope.external_columns + right_scope.external_columns

    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return list(filter(lambda column: not column.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Like the other collection-backed properties, this triggers collection on first
    access, so the result does not depend on whether another property was read first.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    # Defensive fallback: keep returning a list even if collection left no value.
    return self._join_hints if self._join_hints is not None else []
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns of this scope whose table qualifier equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope was created for a subquery."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope was created for a derived table."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope was created for a union branch."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope was created for a common table expression."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this is the root scope."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope was created for a UDTF (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    A falsy `old_name` is looked up as "". If nothing is stored under the old name,
    an empty list is stored under `new_name` (unchanged legacy behavior).

    Args:
        old_name (str | None): current name of the source
        new_name (str): name to store the source under
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Cached properties such as `selected_sources` are keyed by source name, so they must be
    # rebuilt after a rename, consistent with add_source/remove_source.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` and invalidate this scope's cached properties."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Forget the source called `name` (if present) and invalidate cached properties."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    The whole tree is gathered with an explicit stack on the first iteration step, then
    emitted in reverse discovery order, so children always come before their parent.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    discovered = []
    while pending:
        current = pending.pop()
        discovered.append(current)
        for children in (
            current.cte_scopes,
            current.union_scopes,
            current.table_scopes,
            current.subquery_scopes,
        ):
            pending.extend(children)

    for scope in reversed(discovered):
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how often each source is selected across this scope tree.

    Keys are `id()`s of the selected source objects, so a source reached from several
    scopes accumulates all of its references.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)
    selected = (
        source
        for scope in self.traverse()
        for _, source in scope.selected_sources.values()
    )
    for source in selected:
        counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    For a DDL statement wrapping a query, the scopes are built for the query. Any CTEs
    attached to the DDL are moved onto the query; this mutates the input tree.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # We ignore the DDL expression and build a scope for its query instead
        ddl_with = expression.args.get("with")
        expression = expression.expression

        # If the DDL has CTEs attached, we need to add them to the query, or
        # prepend them if the query itself already has CTEs attached to it
        if ddl_with:
            ddl_with.pop()
            query_ctes = expression.ctes
            if not query_ctes:
                expression.set("with", ddl_with)
            else:
                query_with = expression.args["with"]
                # The merged clause must stay RECURSIVE if either side was; previously the
                # DDL's flag overwrote the query's and could drop RECURSIVE from a query CTE.
                query_with.set("recursive", ddl_with.recursive or query_with.recursive)
                query_with.set("expressions", [*ddl_with.expressions, *query_ctes])

    if isinstance(expression, exp.Query):
        return list(_traverse_scope(Scope(expression)))

    return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for `expression` and return its root.

    The root is the last scope produced by `traverse_scope`. None is returned when no
    scopes were produced.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope
    """
    all_scopes = traverse_scope(expression)
    return seq_get(all_scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): root of the scope to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if the generator should stop
            traversing the subtree under `node`. It is also applied when walking the
            joins, laterals and pivots attached to a child-scope Subquery/UDTF.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit
                # them. Pass the caller's prune through so it applies to them as well.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root of the scope to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that returns True if the generator should stop traversing the subtree under that node.
Yields:
exp.Expression: the visited nodes
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the type filter once. Rebuilding it per node was wasteful and silently
    # exhausted one-shot iterables after the first visited node.
    wanted_types = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root of the scope to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    First node in this scope (subscopes excluded) matching any of `expression_types`.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root of the scope to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.