import pickle
from copy import (
    copy,
    deepcopy,
)
from dataclasses import dataclass
from typing import (
    Dict,
    List,
)
from unittest.mock import ANY

import cachelib

import aristaproto
from aristaproto.lib.google import protobuf as google


def unpickled(message):
    """Return *message* after a full pickle dump/load round-trip."""
    payload = pickle.dumps(message)
    return pickle.loads(payload)


# eq=False / repr=False stop the dataclass decorator from generating its own
# __eq__ / __repr__, so the ones inherited from aristaproto.Message are used.
@dataclass(eq=False, repr=False)
class Fe(aristaproto.Message):
    """Single-string-field message; used as ``Complex.fe`` in group ``grp``."""

    abc: str = aristaproto.string_field(1)


# eq=False / repr=False keep aristaproto.Message's __eq__ / __repr__ instead
# of dataclass-generated ones.
@dataclass(eq=False, repr=False)
class Fi(aristaproto.Message):
    """Single-string-field message; used as ``Complex.fi`` in group ``grp``.

    Its serialized bytes are also embedded in a ``google.Any`` by the fixtures.
    """

    abc: str = aristaproto.string_field(1)


# eq=False / repr=False keep aristaproto.Message's __eq__ / __repr__ instead
# of dataclass-generated ones.
@dataclass(eq=False, repr=False)
class Fo(aristaproto.Message):
    """Single-string-field message; used as ``Complex.fo`` in group ``grp``."""

    abc: str = aristaproto.string_field(1)


@dataclass(eq=False, repr=False)
class NestedData(aristaproto.Message):
    """Message holding map fields whose values are well-known Google types.

    Used as ``Complex.nested_data`` to check that maps of ``google.Struct`` and
    ``google.Any`` values survive pickling.
    """

    # Field 1: map<string, google.protobuf.Struct>.
    struct_foo: Dict[str, "google.Struct"] = aristaproto.map_field(
        1, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
    )
    # Field 2: map<string, google.protobuf.Any>.
    map_str_any_bar: Dict[str, "google.Any"] = aristaproto.map_field(
        2, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
    )


@dataclass(eq=False, repr=False)
class Complex(aristaproto.Message):
    """Composite message used as the main pickling / caching fixture.

    Combines a scalar string, three message fields that share ``group="grp"``,
    a nested message containing map fields, and a string-to-``google.Any`` map.
    """

    foo_str: str = aristaproto.string_field(1)
    # fe / fi / fo share group "grp"; presumably a protobuf oneof, so setting
    # one leaves the others unset (TODO confirm against aristaproto docs).
    fe: "Fe" = aristaproto.message_field(3, group="grp")
    fi: "Fi" = aristaproto.message_field(4, group="grp")
    fo: "Fo" = aristaproto.message_field(5, group="grp")
    nested_data: "NestedData" = aristaproto.message_field(6)
    # Field 7: map<string, google.protobuf.Any>.
    mapping: Dict[str, "google.Any"] = aristaproto.map_field(
        7, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
    )


def complex_msg():
    """Build the ``Complex`` fixture shared by the pickling and caching tests.

    Only ``fe`` from group ``grp`` is set. ``nested_data`` carries a Struct
    containing a list value and an ``Any`` map, and ``mapping`` holds one
    ``Any`` wrapping serialized ``Fi`` bytes and one wrapping raw bytes.
    """
    hello_value = google.Value(
        list_value=google.ListValue(values=[google.Value(string_value="world")])
    )
    nested = NestedData(
        struct_foo={"foo": google.Struct(fields={"hello": hello_value})},
        map_str_any_bar={"key": google.Any(value=b"value")},
    )
    any_by_name = {
        "message": google.Any(value=bytes(Fi(abc="hi"))),
        "string": google.Any(value=b"howdy"),
    }
    return Complex(
        foo_str="yep",
        fe=Fe(abc="1"),
        nested_data=nested,
        mapping=any_by_name,
    )


def test_pickling_complex_message():
    """A complex message survives a pickle round-trip with its contents intact."""
    msg = complex_msg()
    deser = unpickled(msg)

    # The round-trip must produce a new object that compares equal.
    assert deser is not msg
    assert msg == deser

    # Inspect the *unpickled* copy. Asserting on ``msg`` would only re-check the
    # freshly built original and would never exercise the deserialized object.
    assert deser.fe.abc == "1"
    assert deser.is_set("fi") is not True
    assert deser.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
    assert deser.mapping["string"].value.decode() == "howdy"
    assert (
        deser.nested_data.struct_foo["foo"]
        .fields["hello"]
        .list_value.values[0]
        .string_value
        == "world"
    )


def test_recursive_message():
    """Lazily created children of an unpickled recursive message stay invisible."""
    from tests.output_aristaproto.recursivemessage import Test as RecursiveMessage

    restored = unpickled(RecursiveMessage())

    # Reading ``child`` on the unpickled message yields an empty message.
    assert restored.child == RecursiveMessage()

    # That lazily created, zero-valued child must not affect equality...
    assert restored == RecursiveMessage()

    # ...nor the serialized wire form.
    assert bytes(restored) == b""


def test_recursive_message_defaults():
    """Set values survive pickling and lazy defaults still work afterwards."""
    from tests.output_aristaproto.recursivemessage import (
        Intermediate,
        Test as RecursiveMessage,
    )

    def bob(**extra):
        # Fresh "bob" message on every call, with any extra fields layered on.
        return RecursiveMessage(name="bob", intermediate=Intermediate(42), **extra)

    restored = unpickled(bob())

    # Explicitly set values come back unchanged.
    assert restored == bob()

    # Before mutation, the message differs from one with a populated child...
    assert restored != bob(child=RecursiveMessage(name="jude"))
    # ...and assigning through lazily created children updates the message.
    restored.child.child.name = "jude"
    assert restored == bob(
        child=RecursiveMessage(child=RecursiveMessage(name="jude")),
    )

    # Lazy initialization recurses as deep as it is accessed.
    assert restored.child.child.child.child.child.child.child == RecursiveMessage()
    assert restored.intermediate.child.intermediate == Intermediate()


# NOTE(review): unlike the other messages in this file, this uses a plain
# ``@dataclass`` (eq=True, repr=True). The decorator therefore generates
# ``__eq__`` and ``__repr__`` that override the ones inherited from
# aristaproto.Message, so ``==`` here compares the field tuple. Confirm that
# this is intentional for test_copyability.
@dataclass
class PickledMessage(aristaproto.Message):
    """Flat message with scalar fields and a list of strings, used for copy tests."""

    foo: bool = aristaproto.bool_field(1)
    bar: int = aristaproto.int32_field(2)
    # Declared via string_field but annotated List[str]; presumably aristaproto
    # infers "repeated" from the List annotation (TODO confirm).
    baz: List[str] = aristaproto.string_field(3)


def test_copyability():
    """An unpickled message supports both ``copy.copy`` and ``copy.deepcopy``."""
    original = unpickled(PickledMessage(bar=12, baz=["hello"]))

    shallow = copy(original)
    assert original == shallow
    assert shallow is not original
    # A shallow copy shares the very same list object for the repeated field.
    assert shallow.baz is original.baz

    deep = deepcopy(original)
    assert original == deep
    assert deep is not original
    # A deep copy gets its own list.
    assert deep.baz is not original.baz


def test_message_can_be_cached():
    """A message can be stored in and read back from cachelib (which pickles values)."""
    cache = cachelib.SimpleCache()
    builds = 0

    def fetch():
        # Return the cached message, building and storing it only on a miss.
        nonlocal builds
        hit = cache.get("message")
        if hit is not None:
            return hit
        builds += 1
        fresh = complex_msg()
        cache.set("message", fresh)
        return fresh

    def check_contents(message):
        assert message.fe.abc == "1"
        assert message.is_set("fi") is not True
        assert message.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
        assert message.mapping["string"].value.decode() == "howdy"
        leaf = (
            message.nested_data.struct_foo["foo"]
            .fields["hello"]
            .list_value.values[0]
        )
        assert leaf.string_value == "world"

    for attempt in range(10):
        # Only the very first lookup should find the cache empty.
        already_cached = attempt > 0
        assert bool(cache.has("message")) == already_cached

        message = fetch()
        assert builds == 1  # the message is built exactly once
        check_contents(message)