Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name
 12
 13logger = logging.getLogger("sqlglot")
 14
 15
 16class ScopeType(Enum):
 17    ROOT = auto()
 18    SUBQUERY = auto()
 19    DERIVED_TABLE = auto()
 20    CTE = auto()
 21    UNION = auto()
 22    UDTF = auto()
 23
 24
 25class Scope:
 26    """
 27    Selection scope.
 28
 29    Attributes:
 30        expression (exp.Select|exp.Union): Root expression of this scope
 31        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 32            a Table expression or another Scope instance. For example:
 33                SELECT * FROM x                     {"x": Table(this="x")}
 34                SELECT * FROM x AS y                {"y": Table(this="x")}
 35                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 36        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 37            For example:
 38                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 39            The LATERAL VIEW EXPLODE gets x as a source.
 40        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 41            defines a column list of it's alias of this scope, this is that list of columns.
 42            For example:
 43                SELECT * FROM (SELECT ...) AS y(col1, col2)
 44            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 45        parent (Scope): Parent scope
 46        scope_type (ScopeType): Type of this scope, relative to it's parent
 47        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 48        cte_scopes (list[Scope]): List of all child scopes for CTEs
 49        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 50        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 51        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 52        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 53            a list of the left and right child scopes.
 54    """
 55
 56    def __init__(
 57        self,
 58        expression,
 59        sources=None,
 60        outer_column_list=None,
 61        parent=None,
 62        scope_type=ScopeType.ROOT,
 63        lateral_sources=None,
 64    ):
 65        self.expression = expression
 66        self.sources = sources or {}
 67        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 68        self.sources.update(self.lateral_sources)
 69        self.outer_column_list = outer_column_list or []
 70        self.parent = parent
 71        self.scope_type = scope_type
 72        self.subquery_scopes = []
 73        self.derived_table_scopes = []
 74        self.table_scopes = []
 75        self.cte_scopes = []
 76        self.union_scopes = []
 77        self.udtf_scopes = []
 78        self.clear_cache()
 79
 80    def clear_cache(self):
 81        self._collected = False
 82        self._raw_columns = None
 83        self._derived_tables = None
 84        self._udtfs = None
 85        self._tables = None
 86        self._ctes = None
 87        self._subqueries = None
 88        self._selected_sources = None
 89        self._columns = None
 90        self._external_columns = None
 91        self._join_hints = None
 92        self._pivots = None
 93        self._references = None
 94
 95    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 96        """Branch from the current scope to a new, inner scope"""
 97        return Scope(
 98            expression=expression.unnest(),
 99            sources={**self.cte_sources, **(chain_sources or {})},
100            parent=self,
101            scope_type=scope_type,
102            **kwargs,
103        )
104
105    def _collect(self):
106        self._tables = []
107        self._ctes = []
108        self._subqueries = []
109        self._derived_tables = []
110        self._udtfs = []
111        self._raw_columns = []
112        self._join_hints = []
113
114        for node, parent, _ in self.walk(bfs=False):
115            if node is self.expression:
116                continue
117            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
118                self._raw_columns.append(node)
119            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
120                self._tables.append(node)
121            elif isinstance(node, exp.JoinHint):
122                self._join_hints.append(node)
123            elif isinstance(node, exp.UDTF):
124                self._udtfs.append(node)
125            elif isinstance(node, exp.CTE):
126                self._ctes.append(node)
127            elif (
128                isinstance(node, exp.Subquery)
129                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
130                and _is_derived_table(node)
131            ):
132                self._derived_tables.append(node)
133            elif isinstance(node, exp.Subqueryable):
134                self._subqueries.append(node)
135
136        self._collected = True
137
138    def _ensure_collected(self):
139        if not self._collected:
140            self._collect()
141
142    def walk(self, bfs=True, prune=None):
143        return walk_in_scope(self.expression, bfs=bfs, prune=None)
144
145    def find(self, *expression_types, bfs=True):
146        return find_in_scope(self.expression, expression_types, bfs=bfs)
147
148    def find_all(self, *expression_types, bfs=True):
149        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
150
151    def replace(self, old, new):
152        """
153        Replace `old` with `new`.
154
155        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
156
157        Args:
158            old (exp.Expression): old node
159            new (exp.Expression): new node
160        """
161        old.replace(new)
162        self.clear_cache()
163
164    @property
165    def tables(self):
166        """
167        List of tables in this scope.
168
169        Returns:
170            list[exp.Table]: tables
171        """
172        self._ensure_collected()
173        return self._tables
174
175    @property
176    def ctes(self):
177        """
178        List of CTEs in this scope.
179
180        Returns:
181            list[exp.CTE]: ctes
182        """
183        self._ensure_collected()
184        return self._ctes
185
186    @property
187    def derived_tables(self):
188        """
189        List of derived tables in this scope.
190
191        For example:
192            SELECT * FROM (SELECT ...) <- that's a derived table
193
194        Returns:
195            list[exp.Subquery]: derived tables
196        """
197        self._ensure_collected()
198        return self._derived_tables
199
200    @property
201    def udtfs(self):
202        """
203        List of "User Defined Tabular Functions" in this scope.
204
205        Returns:
206            list[exp.UDTF]: UDTFs
207        """
208        self._ensure_collected()
209        return self._udtfs
210
211    @property
212    def subqueries(self):
213        """
214        List of subqueries in this scope.
215
216        For example:
217            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
218
219        Returns:
220            list[exp.Subqueryable]: subqueries
221        """
222        self._ensure_collected()
223        return self._subqueries
224
225    @property
226    def columns(self):
227        """
228        List of columns in this scope.
229
230        Returns:
231            list[exp.Column]: Column instances in this scope, plus any
232                Columns that reference this scope from correlated subqueries.
233        """
234        if self._columns is None:
235            self._ensure_collected()
236            columns = self._raw_columns
237
238            external_columns = [
239                column
240                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
241                for column in scope.external_columns
242            ]
243
244            named_selects = set(self.expression.named_selects)
245
246            self._columns = []
247            for column in columns + external_columns:
248                ancestor = column.find_ancestor(
249                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
250                )
251                if (
252                    not ancestor
253                    or column.table
254                    or isinstance(ancestor, exp.Select)
255                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
256                    or (
257                        isinstance(ancestor, exp.Order)
258                        and (
259                            isinstance(ancestor.parent, exp.Window)
260                            or column.name not in named_selects
261                        )
262                    )
263                ):
264                    self._columns.append(column)
265
266        return self._columns
267
268    @property
269    def selected_sources(self):
270        """
271        Mapping of nodes and sources that are actually selected from in this scope.
272
273        That is, all tables in a schema are selectable at any point. But a
274        table only becomes a selected source if it's included in a FROM or JOIN clause.
275
276        Returns:
277            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
278        """
279        if self._selected_sources is None:
280            result = {}
281
282            for name, node in self.references:
283                if name in result:
284                    raise OptimizeError(f"Alias already used: {name}")
285                if name in self.sources:
286                    result[name] = (node, self.sources[name])
287
288            self._selected_sources = result
289        return self._selected_sources
290
291    @property
292    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
293        if self._references is None:
294            self._references = []
295
296            for table in self.tables:
297                self._references.append((table.alias_or_name, table))
298            for expression in itertools.chain(self.derived_tables, self.udtfs):
299                self._references.append(
300                    (
301                        expression.alias,
302                        expression if expression.args.get("pivots") else expression.unnest(),
303                    )
304                )
305
306        return self._references
307
308    @property
309    def cte_sources(self):
310        """
311        Sources that are CTEs.
312
313        Returns:
314            dict[str, Scope]: Mapping of source alias to Scope
315        """
316        return {
317            alias: scope
318            for alias, scope in self.sources.items()
319            if isinstance(scope, Scope) and scope.is_cte
320        }
321
322    @property
323    def external_columns(self):
324        """
325        Columns that appear to reference sources in outer scopes.
326
327        Returns:
328            list[exp.Column]: Column instances that don't reference
329                sources in the current scope.
330        """
331        if self._external_columns is None:
332            self._external_columns = [
333                c for c in self.columns if c.table not in self.selected_sources
334            ]
335        return self._external_columns
336
337    @property
338    def unqualified_columns(self):
339        """
340        Unqualified columns in the current scope.
341
342        Returns:
343             list[exp.Column]: Unqualified columns
344        """
345        return [c for c in self.columns if not c.table]
346
347    @property
348    def join_hints(self):
349        """
350        Hints that exist in the scope that reference tables
351
352        Returns:
353            list[exp.JoinHint]: Join hints that are referenced within the scope
354        """
355        if self._join_hints is None:
356            return []
357        return self._join_hints
358
359    @property
360    def pivots(self):
361        if not self._pivots:
362            self._pivots = [
363                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
364            ]
365
366        return self._pivots
367
368    def source_columns(self, source_name):
369        """
370        Get all columns in the current scope for a particular source.
371
372        Args:
373            source_name (str): Name of the source
374        Returns:
375            list[exp.Column]: Column instances that reference `source_name`
376        """
377        return [column for column in self.columns if column.table == source_name]
378
379    @property
380    def is_subquery(self):
381        """Determine if this scope is a subquery"""
382        return self.scope_type == ScopeType.SUBQUERY
383
384    @property
385    def is_derived_table(self):
386        """Determine if this scope is a derived table"""
387        return self.scope_type == ScopeType.DERIVED_TABLE
388
389    @property
390    def is_union(self):
391        """Determine if this scope is a union"""
392        return self.scope_type == ScopeType.UNION
393
394    @property
395    def is_cte(self):
396        """Determine if this scope is a common table expression"""
397        return self.scope_type == ScopeType.CTE
398
399    @property
400    def is_root(self):
401        """Determine if this is the root scope"""
402        return self.scope_type == ScopeType.ROOT
403
404    @property
405    def is_udtf(self):
406        """Determine if this scope is a UDTF (User Defined Table Function)"""
407        return self.scope_type == ScopeType.UDTF
408
409    @property
410    def is_correlated_subquery(self):
411        """Determine if this scope is a correlated subquery"""
412        return bool(
413            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
414            and self.external_columns
415        )
416
417    def rename_source(self, old_name, new_name):
418        """Rename a source in this scope"""
419        columns = self.sources.pop(old_name or "", [])
420        self.sources[new_name] = columns
421
422    def add_source(self, name, source):
423        """Add a source to this scope"""
424        self.sources[name] = source
425        self.clear_cache()
426
427    def remove_source(self, name):
428        """Remove a source from this scope"""
429        self.sources.pop(name, None)
430        self.clear_cache()
431
432    def __repr__(self):
433        return f"Scope<{self.expression.sql()}>"
434
435    def traverse(self):
436        """
437        Traverse the scope tree from this node.
438
439        Yields:
440            Scope: scope instances in depth-first-search post-order
441        """
442        for child_scope in itertools.chain(
443            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
444        ):
445            yield from child_scope.traverse()
446        yield self
447
448    def ref_count(self):
449        """
450        Count the number of times each scope in this tree is referenced.
451
452        Returns:
453            dict[int, int]: Mapping of Scope instance ID to reference count
454        """
455        scope_ref_count = defaultdict(lambda: 0)
456
457        for scope in self.traverse():
458            for _, source in scope.selected_sources.values():
459                scope_ref_count[id(source)] += 1
460
461        return scope_ref_count
462
463
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Only queries (`exp.Unionable`) and DDL statements wrapping a query are traversed;
    any other expression yields an empty list.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    traversable = isinstance(expression, exp.Unionable) or (
        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
    )
    if not traversable:
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
495
496
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and return its root.

    Scopes are produced in post-order (children before their parent), so the root
    is the last scope returned by `traverse_scope`.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, or None when the expression has no scopes
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
510
511
def _traverse_scope(scope):
    """
    Yield every scope nested in `scope`, followed by `scope` itself.

    Dispatches on the type of `scope.expression`. Unsupported expression types are
    logged and produce no scopes at all.
    """
    node = scope.expression

    if isinstance(node, exp.Select):
        children = _traverse_select(scope)
    elif isinstance(node, exp.Union):
        children = _traverse_union(scope)
    elif isinstance(node, exp.Subquery):
        # A root-level subquery is treated like a SELECT; a nested one only has subqueries
        children = _traverse_select(scope) if scope.is_root else _traverse_subqueries(scope)
    elif isinstance(node, exp.Table):
        children = _traverse_tables(scope)
    elif isinstance(node, exp.UDTF):
        children = _traverse_udtfs(scope)
    elif isinstance(node, exp.DDL):
        children = _traverse_ddl(scope)
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", node, type(node))
        return

    yield from children
    yield scope
535
536
def _traverse_select(scope):
    """Yield the scopes of a SELECT's CTEs, then its table sources, then its subqueries."""
    for traverse_part in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_part(scope)
541
542
def _traverse_union(scope):
    """
    Yield the scopes of a UNION: its CTEs, then the left branch, then the right branch.

    Afterwards `scope.union_scopes` holds the top-most scope of each branch
    (None for a branch that produced no scopes).
    """
    yield from _traverse_ctes(scope)

    branch_tops = []
    for side in ("left", "right"):
        # Each branch is created only after the previous one has been fully traversed
        branch = scope.branch(getattr(scope.expression, side), scope_type=ScopeType.UNION)
        top = None
        for top in _traverse_scope(branch):
            yield top
        # The last scope yielded by a traversal is that branch's own (top-most) scope
        branch_tops.append(top)

    scope.union_scopes = branch_tops
556
557
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE defined on `scope.expression`, in definition order.

    Each CTE body can see the CTEs defined before it. After traversal, every CTE's
    top-most scope is registered in `scope.sources` under the CTE's alias and
    appended to `scope.cte_scopes`.
    """
    defined = {}

    # For a WITH RECURSIVE, a CTE must have the shape `base_case UNION recursive`: the
    # union's first section (the base case) is registered under the CTE's own name as a
    # source for the scopes produced while traversing that CTE.
    with_ = scope.expression.args.get("with")
    is_recursive = bool(with_ and with_.recursive)

    for cte in scope.ctes:
        base_case_scope = None
        if is_recursive and isinstance(cte.this, exp.Union):
            base_case_scope = scope.branch(cte.this.this, scope_type=ScopeType.CTE)

        alias = cte.alias
        body_scope = scope.branch(
            cte.this,
            chain_sources=defined,
            outer_column_list=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        last = None
        for last in _traverse_scope(body_scope):
            yield last

            # Registered after yielding; the final assignment leaves the top-most scope
            defined[alias] = last

            if base_case_scope is not None:
                last.add_source(alias, base_case_scope)

        if last is not None:
            scope.cte_scopes.append(last)

    scope.sources.update(defined)
596
597
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether a Subquery introduces a new scope (a "derived table").

    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule. A Subquery directly
    wrapping a query is always a derived table.
    """
    if expression.alias:
        return True
    return isinstance(expression.this, exp.Subqueryable)
605
606
def _traverse_tables(scope):
    """
    Traverse the table-like sources of `scope`: FROM, JOINs, LATERALs and nested table
    constructs, plus `scope.expression` itself when it is a Table.

    Child scopes of derived tables and UDTFs are yielded and recorded in
    `scope.derived_table_scopes` / `scope.udtf_scopes` and `scope.table_scopes`. Every
    discovered source is finally merged into `scope.sources`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined.
    # `expressions` also acts as a worklist: nested joins and table constructs found
    # below are appended to it while it is being iterated.
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # Name already taken by an earlier source: register under a fresh name
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # A UDTF sees the sources collected so far as lateral sources
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset per expression so a traversal that yields nothing cannot leave this unbound
        # or reuse the scope from a previous iteration.
        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins.
            sources[expression.alias] = child_scope

        # Append the final (top-most) child_scope yielded, if any
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
687
688
def _traverse_subqueries(scope):
    """Yield the scopes of every subquery in `scope`, recording each one's top-most scope."""
    for node in scope.subqueries:
        root = None
        for root in _traverse_scope(scope.branch(node, scope_type=ScopeType.SUBQUERY)):
            yield root
        # The last scope yielded by a traversal is the subquery's own (top-most) scope
        scope.subquery_scopes.append(root)
696
697
def _traverse_udtfs(scope):
    """
    Yield the scopes of derived tables nested in a UDTF scope (UNNEST arguments or a
    LATERAL's body), registering each under its alias in `scope.sources`.
    """
    udtf = scope.expression
    if isinstance(udtf, exp.Unnest):
        candidates = udtf.expressions
    elif isinstance(udtf, exp.Lateral):
        candidates = [udtf.this]
    else:
        candidates = []

    found = {}
    for candidate in candidates:
        if not (isinstance(candidate, exp.Subquery) and _is_derived_table(candidate)):
            continue

        child = scope.branch(
            candidate,
            scope_type=ScopeType.DERIVED_TABLE,
            outer_column_list=candidate.alias_column_names,
        )
        top = None
        for top in _traverse_scope(child):
            yield top
            found[candidate.alias] = top

        scope.derived_table_scopes.append(top)
        scope.table_scopes.append(top)

    scope.sources.update(found)
725
726
def _traverse_ddl(scope):
    """
    Yield the scopes of a DDL statement whose `expression` is a query.

    The DDL's own CTEs are traversed first. The wrapped query then gets a derived-table
    scope that chains the DDL's sources and lists the DDL's CTEs ahead of its own.
    """
    yield from _traverse_ctes(scope)

    inner = scope.branch(
        scope.expression.expression,
        scope_type=ScopeType.DERIVED_TABLE,
        chain_sources=scope.sources,
    )
    # Collect eagerly so the CTE list can be patched before traversal reads it
    inner._collect()
    inner._ctes = [*scope.ctes, *inner._ctes]

    yield from _traverse_scope(inner)
737
738
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    The node that starts a child scope is itself yielded, but its subtree is not
    entered, except for its joins, laterals and pivots, which belong to the current scope.

    Args:
        expression (exp.Expression): root node to walk from
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree. It also
            applies to the nested walks over joins/laterals/pivots.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator through its prune callback.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them.
                # Forward `prune` so the caller's pruning also applies inside these nested walks.
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
784
785
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root node of the scope to search
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the wanted types once instead of rebuilding the tuple for every visited node
    wanted = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted):
            yield node
804
805
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope matching at least one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root node of the scope to search
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the matching node, or None if no node in the scope matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
logger = <Logger sqlglot (WARNING)>
class ScopeType(enum.Enum):
17class ScopeType(Enum):
18    ROOT = auto()
19    SUBQUERY = auto()
20    DERIVED_TABLE = auto()
21    CTE = auto()
22    UNION = auto()
23    UDTF = auto()

An enumeration of scope kinds; each value describes how a scope relates to its parent scope.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
Inherited Members
enum.Enum
name
value
class Scope:
 26class Scope:
 27    """
 28    Selection scope.
 29
 30    Attributes:
 31        expression (exp.Select|exp.Union): Root expression of this scope
 32        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 33            a Table expression or another Scope instance. For example:
 34                SELECT * FROM x                     {"x": Table(this="x")}
 35                SELECT * FROM x AS y                {"y": Table(this="x")}
 36                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 37        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 38            For example:
 39                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 40            The LATERAL VIEW EXPLODE gets x as a source.
 41        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 42            defines a column list of it's alias of this scope, this is that list of columns.
 43            For example:
 44                SELECT * FROM (SELECT ...) AS y(col1, col2)
 45            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 46        parent (Scope): Parent scope
 47        scope_type (ScopeType): Type of this scope, relative to it's parent
 48        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 49        cte_scopes (list[Scope]): List of all child scopes for CTEs
 50        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 51        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 52        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 53        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 54            a list of the left and right child scopes.
 55    """
 56
 57    def __init__(
 58        self,
 59        expression,
 60        sources=None,
 61        outer_column_list=None,
 62        parent=None,
 63        scope_type=ScopeType.ROOT,
 64        lateral_sources=None,
 65    ):
 66        self.expression = expression
 67        self.sources = sources or {}
 68        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 69        self.sources.update(self.lateral_sources)
 70        self.outer_column_list = outer_column_list or []
 71        self.parent = parent
 72        self.scope_type = scope_type
 73        self.subquery_scopes = []
 74        self.derived_table_scopes = []
 75        self.table_scopes = []
 76        self.cte_scopes = []
 77        self.union_scopes = []
 78        self.udtf_scopes = []
 79        self.clear_cache()
 80
 81    def clear_cache(self):
 82        self._collected = False
 83        self._raw_columns = None
 84        self._derived_tables = None
 85        self._udtfs = None
 86        self._tables = None
 87        self._ctes = None
 88        self._subqueries = None
 89        self._selected_sources = None
 90        self._columns = None
 91        self._external_columns = None
 92        self._join_hints = None
 93        self._pivots = None
 94        self._references = None
 95
 96    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 97        """Branch from the current scope to a new, inner scope"""
 98        return Scope(
 99            expression=expression.unnest(),
100            sources={**self.cte_sources, **(chain_sources or {})},
101            parent=self,
102            scope_type=scope_type,
103            **kwargs,
104        )
105
106    def _collect(self):
107        self._tables = []
108        self._ctes = []
109        self._subqueries = []
110        self._derived_tables = []
111        self._udtfs = []
112        self._raw_columns = []
113        self._join_hints = []
114
115        for node, parent, _ in self.walk(bfs=False):
116            if node is self.expression:
117                continue
118            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
119                self._raw_columns.append(node)
120            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
121                self._tables.append(node)
122            elif isinstance(node, exp.JoinHint):
123                self._join_hints.append(node)
124            elif isinstance(node, exp.UDTF):
125                self._udtfs.append(node)
126            elif isinstance(node, exp.CTE):
127                self._ctes.append(node)
128            elif (
129                isinstance(node, exp.Subquery)
130                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
131                and _is_derived_table(node)
132            ):
133                self._derived_tables.append(node)
134            elif isinstance(node, exp.Subqueryable):
135                self._subqueries.append(node)
136
137        self._collected = True
138
139    def _ensure_collected(self):
140        if not self._collected:
141            self._collect()
142
143    def walk(self, bfs=True, prune=None):
144        return walk_in_scope(self.expression, bfs=bfs, prune=None)
145
146    def find(self, *expression_types, bfs=True):
147        return find_in_scope(self.expression, expression_types, bfs=bfs)
148
149    def find_all(self, *expression_types, bfs=True):
150        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
151
152    def replace(self, old, new):
153        """
154        Replace `old` with `new`.
155
156        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
157
158        Args:
159            old (exp.Expression): old node
160            new (exp.Expression): new node
161        """
162        old.replace(new)
163        self.clear_cache()
164
165    @property
166    def tables(self):
167        """
168        List of tables in this scope.
169
170        Returns:
171            list[exp.Table]: tables
172        """
173        self._ensure_collected()
174        return self._tables
175
176    @property
177    def ctes(self):
178        """
179        List of CTEs in this scope.
180
181        Returns:
182            list[exp.CTE]: ctes
183        """
184        self._ensure_collected()
185        return self._ctes
186
187    @property
188    def derived_tables(self):
189        """
190        List of derived tables in this scope.
191
192        For example:
193            SELECT * FROM (SELECT ...) <- that's a derived table
194
195        Returns:
196            list[exp.Subquery]: derived tables
197        """
198        self._ensure_collected()
199        return self._derived_tables
200
201    @property
202    def udtfs(self):
203        """
204        List of "User Defined Tabular Functions" in this scope.
205
206        Returns:
207            list[exp.UDTF]: UDTFs
208        """
209        self._ensure_collected()
210        return self._udtfs
211
212    @property
213    def subqueries(self):
214        """
215        List of subqueries in this scope.
216
217        For example:
218            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
219
220        Returns:
221            list[exp.Subqueryable]: subqueries
222        """
223        self._ensure_collected()
224        return self._subqueries
225
226    @property
227    def columns(self):
228        """
229        List of columns in this scope.
230
231        Returns:
232            list[exp.Column]: Column instances in this scope, plus any
233                Columns that reference this scope from correlated subqueries.
234        """
235        if self._columns is None:
236            self._ensure_collected()
237            columns = self._raw_columns
238
239            external_columns = [
240                column
241                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
242                for column in scope.external_columns
243            ]
244
245            named_selects = set(self.expression.named_selects)
246
247            self._columns = []
248            for column in columns + external_columns:
249                ancestor = column.find_ancestor(
250                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
251                )
252                if (
253                    not ancestor
254                    or column.table
255                    or isinstance(ancestor, exp.Select)
256                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
257                    or (
258                        isinstance(ancestor, exp.Order)
259                        and (
260                            isinstance(ancestor.parent, exp.Window)
261                            or column.name not in named_selects
262                        )
263                    )
264                ):
265                    self._columns.append(column)
266
267        return self._columns
268
269    @property
270    def selected_sources(self):
271        """
272        Mapping of nodes and sources that are actually selected from in this scope.
273
274        That is, all tables in a schema are selectable at any point. But a
275        table only becomes a selected source if it's included in a FROM or JOIN clause.
276
277        Returns:
278            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
279        """
280        if self._selected_sources is None:
281            result = {}
282
283            for name, node in self.references:
284                if name in result:
285                    raise OptimizeError(f"Alias already used: {name}")
286                if name in self.sources:
287                    result[name] = (node, self.sources[name])
288
289            self._selected_sources = result
290        return self._selected_sources
291
292    @property
293    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
294        if self._references is None:
295            self._references = []
296
297            for table in self.tables:
298                self._references.append((table.alias_or_name, table))
299            for expression in itertools.chain(self.derived_tables, self.udtfs):
300                self._references.append(
301                    (
302                        expression.alias,
303                        expression if expression.args.get("pivots") else expression.unnest(),
304                    )
305                )
306
307        return self._references
308
309    @property
310    def cte_sources(self):
311        """
312        Sources that are CTEs.
313
314        Returns:
315            dict[str, Scope]: Mapping of source alias to Scope
316        """
317        return {
318            alias: scope
319            for alias, scope in self.sources.items()
320            if isinstance(scope, Scope) and scope.is_cte
321        }
322
323    @property
324    def external_columns(self):
325        """
326        Columns that appear to reference sources in outer scopes.
327
328        Returns:
329            list[exp.Column]: Column instances that don't reference
330                sources in the current scope.
331        """
332        if self._external_columns is None:
333            self._external_columns = [
334                c for c in self.columns if c.table not in self.selected_sources
335            ]
336        return self._external_columns
337
338    @property
339    def unqualified_columns(self):
340        """
341        Unqualified columns in the current scope.
342
343        Returns:
344             list[exp.Column]: Unqualified columns
345        """
346        return [c for c in self.columns if not c.table]
347
348    @property
349    def join_hints(self):
350        """
351        Hints that exist in the scope that reference tables
352
353        Returns:
354            list[exp.JoinHint]: Join hints that are referenced within the scope
355        """
356        if self._join_hints is None:
357            return []
358        return self._join_hints
359
360    @property
361    def pivots(self):
362        if not self._pivots:
363            self._pivots = [
364                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
365            ]
366
367        return self._pivots
368
369    def source_columns(self, source_name):
370        """
371        Get all columns in the current scope for a particular source.
372
373        Args:
374            source_name (str): Name of the source
375        Returns:
376            list[exp.Column]: Column instances that reference `source_name`
377        """
378        return [column for column in self.columns if column.table == source_name]
379
380    @property
381    def is_subquery(self):
382        """Determine if this scope is a subquery"""
383        return self.scope_type == ScopeType.SUBQUERY
384
385    @property
386    def is_derived_table(self):
387        """Determine if this scope is a derived table"""
388        return self.scope_type == ScopeType.DERIVED_TABLE
389
390    @property
391    def is_union(self):
392        """Determine if this scope is a union"""
393        return self.scope_type == ScopeType.UNION
394
395    @property
396    def is_cte(self):
397        """Determine if this scope is a common table expression"""
398        return self.scope_type == ScopeType.CTE
399
400    @property
401    def is_root(self):
402        """Determine if this is the root scope"""
403        return self.scope_type == ScopeType.ROOT
404
405    @property
406    def is_udtf(self):
407        """Determine if this scope is a UDTF (User Defined Table Function)"""
408        return self.scope_type == ScopeType.UDTF
409
410    @property
411    def is_correlated_subquery(self):
412        """Determine if this scope is a correlated subquery"""
413        return bool(
414            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
415            and self.external_columns
416        )
417
418    def rename_source(self, old_name, new_name):
419        """Rename a source in this scope"""
420        columns = self.sources.pop(old_name or "", [])
421        self.sources[new_name] = columns
422
423    def add_source(self, name, source):
424        """Add a source to this scope"""
425        self.sources[name] = source
426        self.clear_cache()
427
428    def remove_source(self, name):
429        """Remove a source from this scope"""
430        self.sources.pop(name, None)
431        self.clear_cache()
432
433    def __repr__(self):
434        return f"Scope<{self.expression.sql()}>"
435
436    def traverse(self):
437        """
438        Traverse the scope tree from this node.
439
440        Yields:
441            Scope: scope instances in depth-first-search post-order
442        """
443        for child_scope in itertools.chain(
444            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
445        ):
446            yield from child_scope.traverse()
447        yield self
448
449    def ref_count(self):
450        """
451        Count the number of times each scope in this tree is referenced.
452
453        Returns:
454            dict[int, int]: Mapping of Scope instance ID to reference count
455        """
456        scope_ref_count = defaultdict(lambda: 0)
457
458        for scope in self.traverse():
459            for _, source in scope.selected_sources.values():
460                scope_ref_count[id(source)] += 1
461
462        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.Union): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x → {"x": Table(this="x")}; SELECT * FROM x AS y → {"y": Table(this="x")}; SELECT * FROM (SELECT ...) AS y → {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; the LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list for its alias of this scope, this is that list of columns. For example, given SELECT * FROM (SELECT ...) AS y(col1, col2), the inner query would have ["col1", "col2"] as its outer_column_list.
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Initialize a scope rooted at `expression`.

    `sources` is used as-is when non-empty (not copied); a copy of `lateral_sources`
    is merged into it. All child-scope lists start empty and the caches are reset.
    """
    self.expression = expression
    self.sources = sources if sources else {}
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    self.sources.update(self.lateral_sources)
    self.outer_column_list = outer_column_list if outer_column_list else []
    self.parent = parent
    self.scope_type = scope_type
    for child_list in (
        "subquery_scopes",
        "derived_table_scopes",
        "table_scopes",
        "cte_scopes",
        "union_scopes",
        "udtf_scopes",
    ):
        setattr(self, child_list, [])
    self.clear_cache()
expression
sources
lateral_sources
outer_column_list
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
def clear_cache(self):
def clear_cache(self):
    """Forget every lazily computed attribute so the next access rebuilds it."""
    self._collected = False
    cached_attrs = (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    )
    for attr in cached_attrs:
        setattr(self, attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 96    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 97        """Branch from the current scope to a new, inner scope"""
 98        return Scope(
 99            expression=expression.unnest(),
100            sources={**self.cte_sources, **(chain_sources or {})},
101            parent=self,
102            scope_type=scope_type,
103            **kwargs,
104        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
def walk(self, bfs=True, prune=None):
    """
    Walk the nodes of this scope without descending into child scopes.

    Args:
        bfs (bool): True for breadth-first order, False for depth-first.
        prune ((node, parent, arg_key) -> bool): optional callable; returning True stops
            traversal of that branch.

    Returns:
        Generator of (node, parent, arg_key) tuples.
    """
    # `prune` was previously accepted but discarded (passed as None); forward it.
    return walk_in_scope(self.expression, bfs=bfs, prune=prune)
def find(self, *expression_types, bfs=True):
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of `expression_types`,
    or None when there is no match. Child scopes are not searched.
    """
    root = self.expression
    return find_in_scope(root, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
def find_all(self, *expression_types, bfs=True):
    """
    Yield every node in this scope that is an instance of any of `expression_types`.
    Child scopes are not searched.
    """
    root = self.expression
    return find_all_in_scope(root, expression_types, bfs=bfs)
def replace(self, old, new):
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so the cached
    node lists of the `Scope` do not go stale.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Cached lists may still point at `old`; force them to be rebuilt.
    self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Subqueryable]: subqueries

columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
cte_sources

Sources that are CTEs.

Returns:

dict[str, Scope]: Mapping of source alias to Scope

external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
def source_columns(self, source_name):
def source_columns(self, source_name):
    """
    Collect the columns of this scope that are qualified with `source_name`.

    Args:
        source_name (str): the source (table alias) to match against `column.table`
    Returns:
        list[exp.Column]: matching columns, in the order of `self.columns`
    """
    return list(filter(lambda col: col.table == source_name, self.columns))

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery

Determine if this scope is a subquery

is_derived_table

Determine if this scope is a derived table

is_union

Determine if this scope is a union

is_cte

Determine if this scope is a common table expression

is_root

Determine if this is the root scope

is_udtf

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
def rename_source(self, old_name, new_name):
    """
    Move the source stored under `old_name` to `new_name`.

    A falsy `old_name` is looked up as ""; when the key is absent an empty list is stored.
    """
    lookup_key = old_name if old_name else ""
    self.sources[new_name] = self.sources.pop(lookup_key, [])

Rename a source in this scope

def add_source(self, name, source):
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any existing entry) and reset caches."""
    registry = self.sources
    registry[name] = source
    self.clear_cache()

Add a source to this scope

def remove_source(self, name):
def remove_source(self, name):
    """Drop the source registered under `name`, if any, and reset caches."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()

Remove a source from this scope

def traverse(self):
def traverse(self):
    """
    Yield every scope in the tree rooted here, children first (post-order DFS).

    Children are visited in the order: CTE, union, table, then subquery scopes.
    """
    child_groups = (self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes)
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
def ref_count(self):
    """
    Tally how often each source object is selected from across this scope tree.

    Returns:
        dict[int, int]: `id()` of each selected source mapped to its number of references
            (a defaultdict, so unseen ids read as 0).
    """
    tally = defaultdict(int)
    for scope in self.traverse():
        for _node, selected in scope.selected_sources.values():
            tally[id(selected)] += 1
    return tally

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances, or an empty list when `expression` is neither a
            query nor a DDL statement wrapping a query.
    """
    if not isinstance(expression, exp.Unionable) and not (
        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
    ):
        return []

    return list(_traverse_scope(Scope(expression)))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression (exp.Expression): expression to traverse
Returns:

list[Scope]: scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: the root scope (the last one produced by `traverse_scope`), or None
            when `traverse_scope` yields no scopes.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None

Build a scope tree.

Arguments:
  • expression (exp.Expression): expression to build the scope tree for
Returns:

Scope: root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree. It is also
            applied while walking the joins/laterals/pivots of derived tables and UDTFs.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        # Forward the caller's prune so it also governs these branches.
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:

tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root node of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of rebuilding the tuple for every node.
    types = tuple(ensure_collection(expression_types))
    # Use a distinct loop name so the `expression` argument is not shadowed.
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of `expression_types`.

    Child scopes (subqueries, CTEs, derived tables, UDTFs) are not searched.

    Args:
        expression (exp.Expression): root node of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.