1
0
Fork 0
python-aristaproto/tests/test_pickling.py
Daniel Baumann 8512f66c5a
Adding upstream version 1.2+20240521.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-05 14:24:15 +01:00

203 lines
5.6 KiB
Python

import pickle
from copy import (
copy,
deepcopy,
)
from dataclasses import dataclass
from typing import (
Dict,
List,
)
from unittest.mock import ANY
import cachelib
import aristaproto
from aristaproto.lib.google import protobuf as google
def unpickled(message):
    """Return a copy of ``message`` produced by a pickle dump/load round trip."""
    serialized = pickle.dumps(message)
    return pickle.loads(serialized)
# eq=False / repr=False stop the dataclass decorator from generating
# __eq__ and __repr__ for this class.
@dataclass(eq=False, repr=False)
class Fe(aristaproto.Message):
    """Minimal test message with one string field; type of ``Complex.fe``."""

    # String field with field number 1.
    abc: str = aristaproto.string_field(1)
# eq=False / repr=False stop the dataclass decorator from generating
# __eq__ and __repr__ for this class.
@dataclass(eq=False, repr=False)
class Fi(aristaproto.Message):
    """Minimal test message with one string field; type of ``Complex.fi``."""

    # String field with field number 1.
    abc: str = aristaproto.string_field(1)
# eq=False / repr=False stop the dataclass decorator from generating
# __eq__ and __repr__ for this class.
@dataclass(eq=False, repr=False)
class Fo(aristaproto.Message):
    """Minimal test message with one string field; type of ``Complex.fo``."""

    # String field with field number 1.
    abc: str = aristaproto.string_field(1)
# eq=False / repr=False stop the dataclass decorator from generating
# __eq__ and __repr__ for this class.
@dataclass(eq=False, repr=False)
class NestedData(aristaproto.Message):
    """Test message holding two string-keyed maps of well-known google types."""

    # Map field 1: string keys to google.Struct message values.
    struct_foo: Dict[str, "google.Struct"] = aristaproto.map_field(
        1, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
    )
    # Map field 2: string keys to google.Any message values.
    map_str_any_bar: Dict[str, "google.Any"] = aristaproto.map_field(
        2, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
    )
# eq=False / repr=False stop the dataclass decorator from generating
# __eq__ and __repr__ for this class.
@dataclass(eq=False, repr=False)
class Complex(aristaproto.Message):
    """Composite test message: a string, grouped sub-messages, nested data and a map."""

    # String field with field number 1.
    foo_str: str = aristaproto.string_field(1)
    # Fields 3-5 all belong to the group "grp".
    # NOTE(review): presumably a protobuf oneof group - confirm against aristaproto.
    fe: "Fe" = aristaproto.message_field(3, group="grp")
    fi: "Fi" = aristaproto.message_field(4, group="grp")
    fo: "Fo" = aristaproto.message_field(5, group="grp")
    # Message field 6 carrying the nested map container.
    nested_data: "NestedData" = aristaproto.message_field(6)
    # Map field 7: string keys to google.Any message values.
    mapping: Dict[str, "google.Any"] = aristaproto.map_field(
        7, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
    )
def complex_msg():
    """Build the ``Complex`` fixture message shared by the pickling tests.

    The message sets ``foo_str``, the ``fe`` member, a ``NestedData`` with one
    entry in each of its maps, and a two-entry ``mapping`` of ``google.Any``.
    """
    # Struct value: a list containing the single string "world".
    hello_value = google.Value(
        list_value=google.ListValue(values=[google.Value(string_value="world")])
    )
    nested = NestedData(
        struct_foo={"foo": google.Struct(fields={"hello": hello_value})},
        map_str_any_bar={"key": google.Any(value=b"value")},
    )
    any_by_key = {
        # One entry wraps a serialized Fi message, the other raw bytes.
        "message": google.Any(value=bytes(Fi(abc="hi"))),
        "string": google.Any(value=b"howdy"),
    }
    return Complex(
        foo_str="yep",
        fe=Fe(abc="1"),
        nested_data=nested,
        mapping=any_by_key,
    )
def test_pickling_complex_message():
    """Check that a complex message is unchanged by a pickle round trip.

    The field-level assertions inspect the *unpickled* copy, so the round trip
    is verified directly rather than only through ``Message.__eq__``.
    """
    msg = complex_msg()
    deser = unpickled(msg)

    assert msg == deser

    # Inspect the deserialized message: asserting on ``msg`` would only
    # re-check the fixture that was just built.
    assert deser.fe.abc == "1"
    assert deser.is_set("fi") is not True
    assert deser.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
    assert deser.mapping["string"].value.decode() == "howdy"
    assert (
        deser.nested_data.struct_foo["foo"]
        .fields["hello"]
        .list_value.values[0]
        .string_value
        == "world"
    )
def test_recursive_message():
    """Check that an empty recursive message stays empty after unpickling."""
    from tests.output_aristaproto.recursivemessage import Test as RecursiveMessage

    restored = unpickled(RecursiveMessage())

    assert restored.child == RecursiveMessage()

    # Children created lazily with zero values must leave equality untouched.
    assert restored == RecursiveMessage()

    # Children created lazily with zero values must not be serialized.
    assert bytes(restored) == b""
def test_recursive_message_defaults():
    """Check set values and lazy child creation on an unpickled recursive message."""
    from tests.output_aristaproto.recursivemessage import (
        Intermediate,
        Test as RecursiveMessage,
    )

    restored = unpickled(RecursiveMessage(name="bob", intermediate=Intermediate(42)))

    # The explicitly set values come back as expected.
    assert restored == RecursiveMessage(name="bob", intermediate=Intermediate(42))

    # A message that carries a named child is a different message ...
    with_named_child = RecursiveMessage(
        name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
    )
    assert restored != with_named_child

    # ... and writing through lazily initialized children modifies the message.
    restored.child.child.name = "jude"
    with_named_grandchild = RecursiveMessage(
        name="bob",
        intermediate=Intermediate(42),
        child=RecursiveMessage(child=RecursiveMessage(name="jude")),
    )
    assert restored == with_named_grandchild

    # Lazy initialization recurses as deep as needed.
    assert restored.child.child.child.child.child.child.child == RecursiveMessage()
    assert restored.intermediate.child.intermediate == Intermediate()
# Unlike the other messages in this file, this one uses the plain
# ``@dataclass`` decorator without eq=False / repr=False.
@dataclass
class PickledMessage(aristaproto.Message):
    """Small test message with a bool, an int32 and a list-of-strings field."""

    # Bool field with field number 1.
    foo: bool = aristaproto.bool_field(1)
    # int32 field with field number 2.
    bar: int = aristaproto.int32_field(2)
    # Declared with string_field but annotated as a list of strings.
    # NOTE(review): presumably treated as a repeated field via the annotation - confirm.
    baz: List[str] = aristaproto.string_field(3)
def test_copyability():
    """Check ``copy`` and ``deepcopy`` semantics on an unpickled message."""
    restored = unpickled(PickledMessage(bar=12, baz=["hello"]))

    # A shallow copy is a new, equal message that shares the list object.
    shallow = copy(restored)
    assert restored == shallow
    assert restored is not shallow
    assert restored.baz is shallow.baz

    # A deep copy is a new, equal message that owns its own list object.
    deep = deepcopy(restored)
    assert restored == deep
    assert restored is not deep
    assert restored.baz is not deep.baz
def test_message_can_be_cached():
    """Check that a complex message can be stored in and read from a cachelib cache.

    Cachelib uses pickling to cache values.
    """
    cache = cachelib.SimpleCache()

    def use_cache():
        # Return the cached message; on a miss build it, store it and count
        # the build on the function object as ``use_cache.calls``.
        cached = cache.get("message")
        if cached is not None:
            return cached
        use_cache.calls = getattr(use_cache, "calls", 0) + 1
        fresh = complex_msg()
        cache.set("message", fresh)
        return fresh

    for attempt in range(10):
        # Only the very first pass starts with an empty cache.
        if attempt:
            assert cache.has("message")
        else:
            assert not cache.has("message")

        msg = use_cache()
        assert use_cache.calls == 1  # The message is only ever built once
        assert msg.fe.abc == "1"
        assert msg.is_set("fi") is not True
        assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
        assert msg.mapping["string"].value.decode() == "howdy"
        hello_value = msg.nested_data.struct_foo["foo"].fields["hello"]
        assert hello_value.list_value.values[0].string_value == "world"