sqlglot.optimizer.scope
import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        # Lateral sources are also visible as regular sources of this scope.
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """
        Branch from the current scope to a new, inner scope.

        The new scope's sources are this scope's CTE sources plus `chain_sources`;
        any extra keyword arguments are forwarded to the `Scope` constructor.
        """
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope (without entering child scopes) and bucket its nodes into
        tables, CTEs, subqueries, derived tables, UDTFs, raw columns and join hints.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                # Tables named inside a join hint are hint references, not sources.
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope only (see `walk_in_scope`)."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
                )
                # Unqualified names under QUALIFY / ORDER BY (outside a window) / HAVING that
                # match a projection alias refer to that alias, not to a source column, and
                # unqualified names inside hints are skipped as well.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same alias.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (alias, node) pairs for every table, derived table and UDTF referenced in this scope.

        Derived tables and UDTFs are unnested, unless they carry pivots, in which case
        the wrapping node itself is kept.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other collected properties; previously this
        # returned [] whenever collection had not run yet (e.g. after clear_cache()).
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the nodes referenced in this scope.

        Returns:
            list[exp.Pivot]: pivots, in reference order
        """
        # `is None` (not falsiness) so that an empty result is cached too.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Keep cached views (selected_sources, columns, ...) in sync, as add_source
        # and remove_source already do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        return []
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, or None if `expression` produces no scopes
    """
    scopes = traverse_scope(expression)
    if scopes:
        # traverse_scope yields in post-order, so the root scope comes last.
        return scopes[-1]
    return None


def _traverse_scope(scope):
    """Yield the child scopes of `scope` (post-order), then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    yield from _traverse_ctes(scope)

    # The last scope to be yield should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        # chain_sources makes the CTEs defined earlier in this WITH visible to this one.
        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        alias = cte.alias
        sources[alias] = child_scope

        if recursive_scope:
            child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_tables(scope):
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression
            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs may reference the sources defined before them in this FROM clause.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

        if child_scope is None:
            # _traverse_scope logged a warning and produced no scope for this expression.
            # Without this guard the previous iteration's scope would be registered under
            # this alias (or an UnboundLocalError raised for the first derived table).
            continue

        # Tables without aliases will be set as ""
        # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
        # Until then, this means that only a single, unaliased derived table is allowed (rather,
        # the latest one wins).
        alias = expression.alias
        sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the node to start from; it is always visited and
            never pruned, even if it starts a scope itself.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    # The prune lambda closes over this name, so it sees the value assigned below.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Explicit values equal to what auto() assigns for a plain Enum (1, 2, 3, ...).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
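Enumeration of scope types (ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, UDTF); a scope's type describes how it relates to its parent scope.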
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        # Lateral sources are also visible as regular sources of this scope.
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """
        Branch from the current scope to a new, inner scope.

        The new scope's sources are this scope's CTE sources plus `chain_sources`;
        any extra keyword arguments are forwarded to the `Scope` constructor.
        """
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope (without entering child scopes) and bucket its nodes into
        tables, CTEs, subqueries, derived tables, UDTFs, raw columns and join hints.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                # Tables named inside a join hint are hint references, not sources.
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope only (see `walk_in_scope`)."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
                )
                # Unqualified names under QUALIFY / ORDER BY (outside a window) / HAVING that
                # match a projection alias refer to that alias, not to a source column, and
                # unqualified names inside hints are skipped as well.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same alias.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (alias, node) pairs for every table, derived table and UDTF referenced in this scope.

        Derived tables and UDTFs are unnested, unless they carry pivots, in which case
        the wrapping node itself is kept.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other collected properties; previously this
        # returned [] whenever collection had not run yet (e.g. after clear_cache()).
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the nodes referenced in this scope.

        Returns:
            list[exp.Pivot]: pivots, in reference order
        """
        # `is None` (not falsiness) so that an empty result is cached too.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Keep cached views (selected_sources, columns, ...) in sync, as add_source
        # and remove_source already do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` gives {"x": Table(this="x")}; `SELECT * FROM x AS y` gives {"y": Table(this="x")}; `SELECT * FROM (SELECT ...) AS y` gives {"y": Scope(...)}.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c`, the LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for this scope's alias, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Initialize a scope rooted at `expression`.

    Args:
        expression: root expression of this scope.
        sources: mapping of source name to a Table expression or a Scope. A
            non-empty dict is used directly (it is NOT copied), so the lateral
            sources merged in below also appear in the caller's dict. A falsy
            value is replaced by a new empty dict.
        outer_column_list: column names the outer query declares for this
            scope's alias, if any.
        parent: the enclosing Scope, or None.
        scope_type: the ScopeType of this scope relative to its parent.
        lateral_sources: sources from laterals; they are copied and then merged
            into `sources`.
    """
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.outer_column_list = outer_column_list if outer_column_list else []

    # Keep a private copy of the lateral sources, but expose them through
    # `sources` as well so lookups see both.
    self.lateral_sources = dict(lateral_sources.copy()) if lateral_sources else {}
    self.sources = sources if sources else {}
    self.sources.update(self.lateral_sources)

    # Child scopes; these start empty and are filled in by whoever builds the tree.
    self.cte_scopes = []
    self.union_scopes = []
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.udtf_scopes = []
    self.table_scopes = []

    # Put every lazily computed property into its "not yet computed" state.
    self.clear_cache()
def clear_cache(self):
    """Reset every lazily computed attribute so it is recomputed on next access."""
    # `_collected` is a flag rather than a cached value, so it goes back to False.
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Create a child Scope of this scope.

    The child inherits this scope's CTE sources, overridden by `chain_sources`
    when given. `expression` is unnested before it becomes the child's root.
    Extra keyword arguments are forwarded to the Scope constructor.
    """
    inner_expression = expression.unnest()

    inherited_sources = dict(self.cte_sources)
    if chain_sources:
        inherited_sources.update(chain_sources)

    return Scope(
        expression=inner_expression,
        sources=inherited_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of `expression_types`.

    Child scopes are not searched.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True for breadth-first order, False for depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    for match in self.find_all(*expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that is an instance of any of `expression_types`.

    Child scopes are not searched.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True for breadth-first order, False for depth-first.

    Yields:
        exp.Expression: matching nodes, in walk order.
    """
    # `walk` yields (node, parent, key) tuples; only the node matters here.
    visited = (node for node, *_ in self.walk(bfs=bfs))
    yield from (node for node in visited if isinstance(node, expression_types))
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def replace(self, old, new):
    """
    Swap `old` for `new` in the expression tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so the Scope
    does not keep serving cached data computed from the old tree.

    Args:
        old (exp.Expression): node to remove
        new (exp.Expression): node to put in its place
    """
    swap = old.replace
    swap(new)
    # Cached tables/columns/sources may now reference the detached node.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Select expressions of this scope.
For example, for the following expression: SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect this scope's columns whose table qualifier equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: matching Column instances, in the order of `self.columns`
    """
    return list(filter(lambda column: column.table == source_name, self.columns))
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The entry registered under `old_name` is moved to `new_name`. A falsy
    `old_name` is looked up as the empty string. If no entry exists under that
    name, an empty list is stored under `new_name`; this is kept for backward
    compatibility.

    Like `add_source` and `remove_source`, this clears the scope's cached
    properties so they are rebuilt against the renamed source.
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Cached properties such as `selected_sources` were computed against the old
    # name (and possibly against now-replaced alias nodes), so drop them.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any existing entry) and invalidate caches."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source registered as `name` (a no-op if absent) and invalidate caches."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group (CTEs, unions, tables, subqueries),
    each recursively, and this scope itself is yielded last.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes)
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how often each source is selected from across this scope tree.

    Returns:
        dict[int, int]: Mapping of source object ID to reference count. It is a
        defaultdict, so unseen IDs read as 0.
    """
    scope_ref_count = defaultdict(int)
    selected = (
        source
        for scope in self.traverse()
        for _node, source in scope.selected_sources.values()
    )
    for source in selected:
        scope_ref_count[id(source)] += 1
    return scope_ref_count
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty when `expression` is not an exp.Unionable
    """
    if isinstance(expression, exp.Unionable):
        return list(_traverse_scope(Scope(expression)))
    return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: the root scope (the last scope produced by `traverse_scope`),
        or None when no scopes were produced
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the root node to walk from. It is always
            yielded and its children are always explored.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # This flag is how state is passed into the walk generator through its
    # `prune` callback: whenever it is True, the current node's subtree is
    # excluded from traversal.
    skip_subtree = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: skip_subtree):
        skip_subtree = False

        yield node, parent, key

        if node is not expression:
            # CTEs, UDTFs, subqueryables and derived tables (subqueries under
            # FROM/JOIN) open their own scopes, so don't descend into them.
            skip_subtree = isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
                isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))
            )
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key