from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie

if t.TYPE_CHECKING:
    # Imported only for static type checking; these names do not exist at runtime, so the
    # annotations that mention them work only thanks to `from __future__ import annotations`.
    from sqlglot.dataframe.sql.types import StructType

    # Accepted shapes for a table's column mapping (see `ensure_column_mapping`).
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Names of the `exp.Table` arguments that qualify a table, from innermost (the table name
# itself, stored under "this") to outermost (the catalog).
TABLE_ARGS = ("this", "db", "catalog")

# Type of the value stored for each table in an `AbstractMappingSchema`.
T = t.TypeVar("T")


class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: if True, return only the columns marked as visible; otherwise return
                all columns.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The table arguments this schema supports, e.g. `("this", "db", "catalog")`.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError


class AbstractMappingSchema(t.Generic[T]):
    """
    Base class for schemas backed by a nested dictionary keyed by table qualifiers.

    The outer levels of `mapping` are, from outermost to innermost, an optional catalog, an
    optional db and the table name; the value stored under each table is of type `T`.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie over every table path (table name first), used by `find` to resolve
        # partially-qualified tables.
        self.mapping_trie = self._build_trie(self.mapping)
        # Lazily filled in by `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _build_trie(self, schema: t.Dict) -> t.Dict:
        """Build a trie over all table paths of `schema`, each reversed so the table name comes first."""
        # The loop variable is `path` rather than `t` so it does not shadow the `typing` alias.
        return new_trie(
            tuple(reversed(path)) for path in flatten_schema(schema, depth=self._depth())
        )

    def _depth(self) -> int:
        """Nesting depth of `mapping`; subclasses may exclude levels that are not table qualifiers."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The `exp.Table` arguments this schema is keyed by, derived from the mapping depth:
        1 -> ("this",), 2 -> ("this", "db"), 3 -> ("this", "db", "catalog").

        The result is computed from a non-empty mapping and cached once it is non-empty.

        Raises:
            SchemaError: if the depth is non-zero and outside the range 1..3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the non-empty qualifiers of `table`, innermost first: `[table, db, catalog]`.

        For a table reading a CSV file (`exp.ReadCSV`), only that expression's name is returned.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the value stored for `table`.

        Qualifiers beyond the schema's depth are ignored. When the trie reports a partial match
        and exactly one table path completes it, the missing qualifiers are filled in from it.

        Args:
            table: the table to look up.
            trie: a trie to search instead of `mapping_trie`.
            raise_on_missing: raise instead of returning None when the table is ambiguous or a
                key along its path is missing.

        Returns:
            The stored value, or None if it was not found.

        Raises:
            SchemaError: if the partial match is ambiguous and `raise_on_missing` is set.
            ValueError: if a key is missing from the mapping and `raise_on_missing` is set.
        """
        # Only the innermost qualifiers the mapping is keyed by take part in the lookup.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
            # Presumably "no table path starts with `parts`" (per sqlglot.trie.in_trie).
            return None
        elif value == 1:
            # Presumably "`parts` is a strict prefix": complete it if the match is unique.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(possibility) for possibility in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
        return self._nested_get(parts, raise_on_missing=raise_on_missing)

    def _nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Walk `d` (by default `mapping`) along the table path given by `parts`.

        Args:
            parts: table qualifiers innermost first, as returned by `table_parts`; qualifiers
                beyond `supported_table_args` (the outermost ones) are ignored.
            d: the nested dictionary to search; `mapping` when None. An explicitly passed empty
                dict is searched as-is rather than being replaced by `mapping`.
            raise_on_missing: raise ValueError instead of returning None when a key is missing.

        Returns:
            The value found at the end of the path, or None.
        """
        # Align argument names with parts from the innermost qualifier (parts[0] is the table),
        # so error messages name the right level; then walk from the outermost key inwards.
        named_parts = list(zip(self.supported_table_args, parts))
        return _nested_get(
            self.mapping if d is None else d,
            *reversed(named_parts),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Identifiers in `schema` are normalized by the constructor: unquoted names are lowercased,
    quoted names keep their case.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: t.Optional[str] = None,
    ) -> None:
        # `dialect` must be set before `_normalize` runs: identifier parsing depends on it.
        self.dialect = dialect
        self.visible = visible or {}
        # Parsed DataTypes keyed by their (upper-cased) type string; see `_to_data_type`.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Create a schema of this class from another schema's mapping, visibility and dialect.

        NOTE(review): the mapping goes through `_normalize` again, which appears to lowercase keys
        that were quoted (and therefore case-preserved) in the original schema — confirm intended.
        """
        return cls(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from this schema's mapping, visibility and dialect.

        Args:
            **kwargs: overrides for the constructor arguments (`schema`, `visible`, `dialect`).
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            A new, normalized schema mapping (the input is not modified).
        """
        # Every table path, excluding the innermost {column: type} level.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.

        NOTE(review): unlike the constructor, neither the table nor the column names are
        normalized here — confirm callers pass already-normalized identifiers.
        """
        table_ = self._ensure_table(table)
        column_mapping = ensure_column_mapping(column_mapping)
        schema = self.find(table_, raise_on_missing=False)

        # A known table is only overwritten when new column information is supplied.
        if schema and not column_mapping:
            return

        _nested_set(
            self.mapping,
            list(reversed(self.table_parts(table_))),
            column_mapping,
        )
        # Rebuild the trie so `find` can resolve the new table.
        self.mapping_trie = self._build_trie(self.mapping)

    def _normalize_name(self, name: str) -> str:
        """
        Normalize an identifier: lowercase it unless it is quoted.

        `name` is parsed as an identifier in `self.dialect`; if parsing fails, it is converted with
        `exp.to_identifier` instead.
        """
        try:
            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
                name, read=self.dialect, into=exp.Identifier
            )
        except Exception:
            # Fall back on parse failures only; a bare `except:` would also swallow
            # KeyboardInterrupt and SystemExit.
            identifier = exp.to_identifier(name)
        assert isinstance(identifier, exp.Identifier)

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Convert `table` into an `exp.Table`.

        Raises:
            SchemaError: if the conversion does not produce a table.
        """
        table_ = exp.to_table(table)

        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names of `table`, in mapping order.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: if True and a `visible` mapping was provided, keep only the columns
                listed in it for this table.

        Returns:
            The list of column names, or an empty list if `find` returns None for the table
            (lookup errors raised by `find` propagate).
        """
        table_ = self._ensure_table(table)
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Use the same innermost qualifiers as `find`, so an over-qualified table name does not
        # shift the lookup path in `visible`.
        parts = self.table_parts(table_)[0 : len(self.supported_table_args)]
        visible = self._nested_get(parts, self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the type of `column` in `table`.

        Returns:
            The column's DataType (string types are parsed), or an UNKNOWN DataType if the table
            is not found.

        Raises:
            SchemaError: if `table` cannot be converted to a table, or if the stored type is
                neither a DataType nor a string (this includes a missing column).
        """
        column_name = column if isinstance(column, str) else column.name
        table_ = exp.to_table(table)
        if table_:
            table_schema = self.find(table_, raise_on_missing=False)
            if table_schema:
                column_type = table_schema.get(column_name)

                if isinstance(column_type, exp.DataType):
                    return column_type
                elif isinstance(column_type, str):
                    return self._to_data_type(column_type.upper())
                raise SchemaError(f"Unknown column type '{column_type}'")
            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
        raise SchemaError(f"Could not convert table '{table}'")

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are cached per type string.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            ValueError: if parsing yields no expression.
            SchemaError: if parsing fails with an AttributeError.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError as e:
                raise SchemaError(f"Failed to convert type {schema_type}") from e

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema: t.Any) -> Schema:
    """
    Coerce `schema` into a `Schema`.

    Args:
        schema: a `Schema` instance (returned unchanged), or anything accepted as the `schema`
            argument of `MappingSchema` (wrapped in a new `MappingSchema`).

    Returns:
        A `Schema` instance.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema)


def _split_top_level(text: str, separator: str) -> t.List[str]:
    """
    Split `text` on a single-character `separator`, ignoring separators nested inside brackets.

    This keeps parameterized types such as `decimal(10, 2)` or `map<string, int>` intact when a
    "name: type, name: type" column string is split on commas.

    Args:
        text: the string to split.
        separator: the single character to split on.

    Returns:
        The pieces of `text` in order; a string without top-level separators yields one piece.
    """
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp at zero so a stray closing bracket cannot disable splitting afterwards.
            depth = max(depth - 1, 0)
        elif char == separator and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping representations into a `{column: type}` dict.

    Accepted inputs:
        - a dict, returned unchanged;
        - a string "name: type, name: type"; commas and colons nested inside brackets
          (e.g. `decimal(10, 2)`, `struct<a:int>`) are kept as part of the type, and empty
          entries (empty string, trailing comma) are skipped;
        - an object exposing `simpleString` (presumably a DataFrame `StructType`), iterated as
          fields with `.name` and `.dataType.simpleString()`;
        - a list of column names, whose types are set to None;
        - None, which yields an empty dict.

    Raises:
        ValueError: if a string entry has no ":" or the input type is not supported.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping: t.Dict = {}
        for column_def in _split_top_level(mapping, ","):
            column_def = column_def.strip()
            if not column_def:
                continue
            # Split on the first colon only: the type itself may contain colons.
            name, colon, column_type = column_def.partition(":")
            if not colon:
                raise ValueError(
                    f"Invalid column definition, expected 'name: type': {column_def!r}"
                )
            column_mapping[name.strip()] = column_type.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths that reach `depth` levels down into a nested dictionary.

    Args:
        schema: the nested dictionary to walk.
        depth: the number of keys in each returned path; 0 or less yields no paths.
        keys: a prefix prepended to every returned path (used by the recursion).

    Returns:
        One list of keys per path, in dictionary iteration order.
    """
    prefix = keys or []
    paths: t.List[t.List[str]] = []
    for name, child in schema.items():
        path = [*prefix, name]
        if depth == 1:
            paths.append(path)
        elif depth > 1:
            paths.extend(flatten_schema(child, depth - 1, path))
    return paths


def _nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Follow a sequence of keys down a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: `(name, key)` pairs, outermost first. `key` is looked up at each level;
            `name` only labels that level in the error message ("this" is reported as "table").
        raise_on_missing: raise instead of returning None when a key is missing.

    Returns:
        The value at the end of the path (`d` itself for an empty path), or None if a key is
        missing and `raise_on_missing` is False.

    Raises:
        ValueError: if a key is missing (or maps to None) and `raise_on_missing` is True.
    """
    node: t.Any = d
    for level_name, key in path:
        node = node.get(key)
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        label = "table" if level_name == "this" else level_name
        raise ValueError(f"Unknown {label}: {key}")
    return node


def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary, creating intermediate dictionaries as needed.

    Example:
        >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`, outermost first.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary `d` itself; empty `keys` leave it unchanged.
    """
    if not keys:
        return d

    subd = d
    for key in keys[:-1]:
        # Descend into the existing level, or create an empty one.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d