from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

# Package-level logger; `_traverse_scope` uses it to warn about expressions it cannot traverse.
logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """Kind of a `Scope`, describing how it relates to its parent scope."""

    # Default type of a `Scope` that is constructed without an explicit `scope_type`
    ROOT = auto()
    # Assigned by `_traverse_subqueries` to the scopes it branches
    SUBQUERY = auto()
    # Assigned by `_traverse_tables` / `_traverse_udtfs` to derived tables
    DERIVED_TABLE = auto()
    # Assigned by `_traverse_ctes` to the scopes it branches
    CTE = auto()
    # Assigned by `_traverse_union` to the operands of a union
    UNION = auto()
    # Assigned by `_traverse_tables` to `exp.UDTF` expressions
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTES
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
    ):
        self.expression = expression
        # Note: a non-empty `sources` dict is used as-is (not copied), so the two
        # updates below also modify the dict that the caller passed in.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so that it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Copies keep the child from mutating the dictionaries owned by the caller
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # The child sees this scope's CTEs, optionally overridden by `cte_sources`
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once and bucket the nodes that the cached properties are built from."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            # The order of these checks matters: each node lands in the first matching bucket
            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(
                node.parent, (exp.From, exp.Join, exp.Subquery)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless its results are already cached."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope's expression, see `walk_in_scope`.

        Args:
            bfs (bool): True to use breadth-first search, False to use depth-first.
            prune: optional callable that receives a node and returns True if
                traversal should not continue below that node.
        """
        # `prune` used to be discarded here (a literal None was forwarded instead)
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching `expression_types`, see `find_in_scope`."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield the nodes in this scope matching `expression_types`, see `find_all_in_scope`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.Union]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # Columns that child subquery / UDTF scopes could not resolve themselves
            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # Unqualified names under QUALIFY / HAVING / hints / star, or under an ORDER BY
                # that matches a named select, are skipped here.
                # NOTE(review): presumably because they can refer to select aliases - confirm.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is referenced more than once.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        Pairs of (name, node) for every table, derived table and UDTF in this scope.

        Tables are listed first under their alias or name; derived tables and UDTFs
        follow under their alias, unnested unless they carry pivots.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.Union):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other node lists do; previously this returned []
        # whenever no other property had triggered collection yet.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the nodes returned by `references`.

        Returns:
            list: the items of the "pivots" arg of each referenced node
        """
        # Compare against None so that an empty result is cached as well
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        # NOTE(review): unlike `add_source` / `remove_source`, this does not call
        # `clear_cache`, so an already computed `selected_sources` keeps the old name.
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Parents were recorded before their children, so reversing yields children first
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    query = expression

    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # The DDL wrapper itself is ignored; the scope is built for the query it contains
        ddl_with = expression.args.get("with")
        query = expression.expression

        if ddl_with:
            # Move the CTEs of the DDL onto the query: attach them directly when the query
            # has none of its own, otherwise put them in front of the query's CTEs
            ddl_with.pop()
            existing_ctes = query.ctes
            if existing_ctes:
                query_with = query.args["with"]
                query_with.set("recursive", ddl_with.recursive)
                query_with.set("expressions", [*ddl_with.expressions, *existing_ctes])
            else:
                query.set("with", ddl_with)

    if not isinstance(query, exp.Query):
        return []

    return list(_traverse_scope(Scope(query)))


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`, or
        None when no scope was produced.
    """
    scopes = traverse_scope(expression)
    # The root scope is the final element of the list
    return seq_get(scopes, -1)


def _traverse_scope(scope):
    """
    Dispatch on the type of `scope.expression` and yield its child scopes, then `scope`.

    For unions, `scope` itself is yielded from within `_traverse_union`. Expressions of
    any other unsupported type are logged and produce no scopes at all.
    """
    expression = scope.expression

    if isinstance(expression, exp.Select):
        child_scopes = _traverse_select(scope)
    elif isinstance(expression, exp.Union):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return
    elif isinstance(expression, exp.Subquery):
        if scope.is_root:
            child_scopes = _traverse_select(scope)
        else:
            child_scopes = _traverse_subqueries(scope)
    elif isinstance(expression, exp.Table):
        child_scopes = _traverse_tables(scope)
    elif isinstance(expression, exp.UDTF):
        child_scopes = _traverse_udtfs(scope)
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    # The generators above are lazy, so nothing has been traversed before this point
    yield from child_scopes
    yield scope


def _traverse_select(scope):
    """Yield the child scopes of `scope`: CTEs first, then tables, then subqueries."""
    for traverse in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse(scope)


def _traverse_union(scope):
    """
    Yield the scopes of a union's operands, followed by the union scope(s) themselves.

    Nested unions are handled with two explicit stacks instead of recursion:
    `expression_stack` holds the operands that still need to be visited and
    `union_scope_stack` holds the scopes of the unions that are currently open.
    """
    # Scope of the most recently finished operand that is still waiting for its sibling
    prev_scope = None
    union_scope_stack = [scope]
    # The right operand is pushed first so that the left operand is popped first
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.Union):
            # A nested union: open its scope and schedule its own operands
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        # Note: the loop variable shadows the `scope` parameter; once the loop is done it
        # holds the last scope that was yielded for this operand
        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            # Both operands are available: close the innermost open union
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            # The finished union now acts as an operand of the union enclosing it
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope


def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE defined on `scope`, in definition order.

    The last scope yielded for a CTE is registered as a source under the CTE's alias
    and appended to `scope.cte_scopes`. Every CTE can see the CTEs defined before it.
    """
    cte_sources = {}

    for cte in scope.ctes:
        name = cte.alias

        # A recursive CTE must have the form `base_case UNION recursive`, so the first
        # section of the union is registered up front as the source for the CTE's own name.
        with_ = scope.expression.args.get("with")
        if with_ and with_.recursive and isinstance(cte.this, exp.Union):
            cte_sources[name] = scope.branch(cte.this.this, scope_type=ScopeType.CTE)

        cte_scope = scope.branch(
            cte.this,
            cte_sources=cte_sources,
            outer_columns=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        final_scope = None
        for child_scope in _traverse_scope(cte_scope):
            yield child_scope
            final_scope = child_scope

        # Only the final scope yielded represents the CTE as a whole
        if final_scope:
            cte_sources[name] = final_scope
            scope.cte_scopes.append(final_scope)

    scope.sources.update(cte_sources)
    scope.cte_sources.update(cte_sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether `expression` is a Subquery that introduces a new scope.

    (tbl1 JOIN tbl2) is also represented as a Subquery, but it isn't really a
    "derived table" since it doesn't introduce a new scope. The exception is an aliased
    Subquery: the alias shadows all names underneath it.
    """
    if not isinstance(expression, exp.Subquery):
        return False

    if expression.alias:
        return True

    return isinstance(expression.this, exp.UNWRAPPED_QUERIES)


def _traverse_tables(scope):
    """
    Resolve the sources in the FROM, JOIN and LATERAL clauses of `scope`.

    Tables are registered as sources directly. Derived tables and UDTFs are branched
    into child scopes, which are yielded and then registered as sources under the
    alias of the expression they were created for.

    Yields:
        Scope: the scopes of the derived tables and UDTFs, in the order they are defined
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` is a worklist: it grows while it is being iterated over
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs get access to the sources that were defined before them
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset for every expression: if the traversal below yields nothing, there must be
        # neither an unbound name nor a stale scope left over from an earlier iteration.
        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope`.

    The last scope yielded for a subquery is appended to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope

        # Never register None: `Scope.columns` and `Scope.traverse` dereference every entry
        if top is not None:
            scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """
    Yield the scopes of the derived tables nested inside an UNNEST or LATERAL.

    For `exp.Unnest` the candidates are its `expressions`, for `exp.Lateral` its `this`;
    other expression types have none. The last scope yielded for a derived table is
    registered as a source under its alias and as a derived table scope of `scope`.
    """
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if not _is_derived_table(expression):
            continue

        top = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                scope_type=ScopeType.DERIVED_TABLE,
                outer_columns=expression.alias_column_names,
            )
        ):
            yield child_scope
            top = child_scope
            sources[expression.alias] = child_scope

        # Never register None: `Scope.traverse` dereferences every entry of `table_scopes`
        if top is not None:
            scope.derived_table_scopes.append(top)
            scope.table_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: node
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            # A node that the caller asked to prune must not have its args visited either
            if isinstance(node, (exp.Subquery, exp.UDTF)) and not (prune and prune(node)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Forward `prune`, so that it keeps applying within these args
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the types once instead of once per visited node; this also keeps a
    # one-shot iterable of types from being exhausted after the first node.
    types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        # Stop at the very first match
        return match

    return None