import importlib
import json
import math
import os
import sys
from collections import namedtuple
from types import ModuleType
from typing import (
    Any,
    Dict,
    List,
    Set,
    Tuple,
)

import pytest

import aristaproto
from tests.inputs import config as test_input_config
from tests.mocks import MockChannel
from tests.util import (
    find_module,
    get_directories,
    get_test_case_json_data,
    inputs_path,
)


# Force protobuf's pure-Python implementation instead of the C++ backend.
# The tests import many generated *_pb2 reference modules, and with the native
# backend those imports break things because we can't properly reset the
# symbol database.
# This assignment must run before google.protobuf is first imported (the
# `Parse` import below relies on it). NOTE(review): it has no effect if one of
# the imports above (aristaproto, tests.*) already loaded google.protobuf —
# confirm they do not.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from google.protobuf.json_format import Parse


class TestCases:
    """Discover test-case directories and group them for pytest parametrization.

    Each attribute (``all``, ``services``, ``messages``, ``messages_with_json``)
    is a list suitable for ``pytest.mark.parametrize``: plain test names, or
    ``pytest.param(name, marks=pytest.mark.xfail)`` for names listed in
    ``xfail``. Lists are sorted by name so collection order is reproducible
    across runs and across pytest-xdist workers.
    """

    # The name starts with "Test", so pytest would otherwise try to collect this
    # helper as a test class and warn that it has an __init__ constructor.
    __test__ = False

    def __init__(
        self,
        path,
        services: Set[str],
        xfail: Set[str],
    ):
        """Scan ``path`` for test-case directories and classify them.

        Parameters
        ----------
        path
            Directory whose sub-directories are the individual test cases.
        services : Set[str]
            Names of test cases that define a service (``TestStub``); all
            other discovered cases are treated as message tests.
        xfail : Set[str]
            Names of test cases expected to fail.

        Raises
        ------
        Exception
            If ``xfail`` names a test case that does not exist under ``path``.
        """
        _all = set(get_directories(path)) - {"__pycache__"}
        _services = services
        # `_all` already excludes __pycache__, so no second subtraction is needed.
        _messages = _all - services
        _messages_with_json = {
            test for test in _messages if get_test_case_json_data(test)
        }

        unknown_xfail_tests = xfail - _all
        if unknown_xfail_tests:
            raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}")

        self.all = self.apply_xfail_marks(_all, xfail)
        self.services = self.apply_xfail_marks(_services, xfail)
        self.messages = self.apply_xfail_marks(_messages, xfail)
        self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail)

    @staticmethod
    def apply_xfail_marks(test_set: Set[str], xfail: Set[str]) -> List[Any]:
        """Return ``test_set`` as a name-sorted list, wrapping xfail names.

        Set iteration order depends on string hash randomization, so the set
        is sorted to keep parametrization IDs and order deterministic.
        """
        return [
            pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test
            for test in sorted(test_set)
        ]


# All test-case directories under the inputs tree, classified into services /
# messages (with and without JSON samples) and marked xfail per tests.inputs.config.
test_cases = TestCases(
    inputs_path,
    test_input_config.services,
    test_input_config.xfail,
)

# Dotted package names of the generated code under test and of the reference
# protobuf (*_pb2) output it is compared against.
plugin_output_package = "tests.output_aristaproto"
reference_output_package = "tests.output_reference"

# Bundle yielded by the `test_data` fixture:
#   plugin_module    - generated aristaproto module exposing Test / TestStub
#   reference_module - zero-argument callable that imports the reference *_pb2 module
#   json_data        - JSON samples for the test case
TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
# The name matches pytest's "Test*" class pattern; opt out of collection so
# pytest does not warn about the namedtuple's __new__ constructor.
TestData.__test__ = False


def module_has_entry_point(module: ModuleType):
    """Return True when ``module`` defines a ``Test`` message or a ``TestStub`` service."""
    return hasattr(module, "Test") or hasattr(module, "TestStub")


def list_replace_nans(items: List) -> List[Any]:
    """Replace float("nan") in a list with ``aristaproto.NAN`` (the string "NaN").

    Nested lists are processed recursively and nested dicts via
    ``dict_replace_nans``; every other element is copied through unchanged.
    The input list is not modified.

    Parameters
    ----------
    items : List
            List to update

    Returns
    -------
    List[Any]
        New list with the same length and order as ``items``
    """
    result = []
    for item in items:
        if isinstance(item, list):
            result.append(list_replace_nans(item))
        elif isinstance(item, dict):
            result.append(dict_replace_nans(item))
        elif isinstance(item, float) and math.isnan(item):
            result.append(aristaproto.NAN)
        else:
            # Keep ordinary values: dropping them would make unrelated lists
            # (e.g. [1, 2] and [3]) compare equal after normalisation.
            result.append(item)
    return result


def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """Return a copy of ``input_dict`` with NaN floats replaced by ``aristaproto.NAN``.

    Nested dicts are handled recursively and nested lists via
    ``list_replace_nans``; all other values are kept as they are. The input
    dictionary is left untouched.

    Parameters
    ----------
    input_dict : Dict[Any, Any]
            Dictionary to normalise

    Returns
    -------
    Dict[Any, Any]
        New dictionary with the same keys
    """

    def _normalise(entry: Any) -> Any:
        # Dispatch on the value's type; ordering mirrors the container checks
        # before the scalar NaN check.
        if isinstance(entry, dict):
            return dict_replace_nans(entry)
        if isinstance(entry, list):
            return list_replace_nans(entry)
        if isinstance(entry, float) and math.isnan(entry):
            return aristaproto.NAN
        return entry

    return {key: _normalise(entry) for key, entry in input_dict.items()}


@pytest.fixture
def test_data(request, reset_sys_path):
    """Load the generated and reference modules for one parametrized test case.

    ``request.param`` is the test-case name (tests parametrize this fixture
    with ``indirect=True``). The case's reference output directory is appended
    to ``sys.path`` before anything is imported; the ``reset_sys_path`` fixture
    is requested alongside, presumably to undo that afterwards — TODO confirm.

    Yields a ``TestData`` whose ``reference_module`` is a zero-argument
    callable, so the reference *_pb2 module is imported only when a test
    actually calls it.

    Raises ``Exception`` when the generated package exposes no module with a
    ``Test`` message or ``TestStub`` service.
    """
    case_name = request.param

    sys.path.append(os.path.join(*reference_output_package.split("."), case_name))

    generated_package = importlib.import_module(f"{plugin_output_package}.{case_name}")
    entry_point = find_module(generated_package, module_has_entry_point)
    if not entry_point:
        raise Exception(
            f"Test case {repr(case_name)} has no entry point. "
            "Please add a proto message or service called Test and recompile."
        )

    reference_name = f"{reference_output_package}.{case_name}.{case_name}_pb2"

    def load_reference():
        return importlib.import_module(reference_name)

    yield TestData(
        plugin_module=entry_point,
        reference_module=load_reference,
        json_data=get_test_case_json_data(case_name),
    )


@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
def test_message_can_instantiated(test_data: TestData) -> None:
    """Default-constructing the generated ``Test`` message must not raise."""
    test_data.plugin_module.Test()


@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
def test_message_equality(test_data: TestData) -> None:
    """Two independently default-constructed ``Test`` messages compare equal."""
    message_cls = test_data.plugin_module.Test
    assert message_cls() == message_cls()


@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_message_json(repeat, test_data: TestData) -> None:
    """JSON -> message -> JSON must reproduce each sample, ``repeat`` times over.

    Samples listed in ``test_input_config.non_symmetrical_json`` are skipped.
    NaN floats are normalised on both sides before comparing, since NaN never
    equals itself.
    """
    plugin_module, _, json_data = test_data
    asymmetric = test_input_config.non_symmetrical_json

    for _ in range(repeat):
        for sample in json_data:
            if sample.belongs_to(asymmetric):
                continue

            message: aristaproto.Message = plugin_module.Test()
            message.from_json(sample.json)

            actual = dict_replace_nans(json.loads(message.to_json(0)))
            expected = dict_replace_nans(json.loads(sample.json))
            assert actual == expected


@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
def test_service_can_be_instantiated(test_data: TestData) -> None:
    """The generated ``TestStub`` must accept a (mock) channel on construction."""
    stub_cls = test_data.plugin_module.TestStub
    stub_cls(MockChannel())


@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(repeat, test_data: TestData) -> None:
    """Plugin output must be wire-compatible with the reference protobuf module.

    For each JSON sample, the reference implementation's serialized bytes are
    the oracle. Plugin messages built from that JSON and parsed from those
    bytes must both re-serialize to the same bytes and agree with each other.
    """
    plugin_module, reference_module, json_data = test_data
    message_cls = plugin_module.Test

    for sample in json_data:
        expected_bytes = Parse(sample.json, reference_module().Test()).SerializeToString()

        for _ in range(repeat):
            from_json: aristaproto.Message = message_cls().from_json(sample.json)
            from_binary = message_cls.FromString(expected_bytes)

            # Generally this can't be relied on, but here we are aiming to match the
            # existing Python implementation and aren't doing anything tricky.
            # https://developers.google.com/protocol-buffers/docs/encoding#implications
            for candidate in (from_json, from_binary):
                assert bytes(candidate) == expected_bytes

            assert from_json == from_binary
            json_view = dict_replace_nans(from_json.to_dict())
            binary_view = dict_replace_nans(from_binary.to_dict())
            assert json_view == binary_view