225 lines
6.9 KiB
Python
225 lines
6.9 KiB
Python
import importlib
|
|
import json
|
|
import math
|
|
import os
|
|
import sys
|
|
from collections import namedtuple
|
|
from types import ModuleType
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
List,
|
|
Set,
|
|
Tuple,
|
|
)
|
|
|
|
import pytest
|
|
|
|
import aristaproto
|
|
from tests.inputs import config as test_input_config
|
|
from tests.mocks import MockChannel
|
|
from tests.util import (
|
|
find_module,
|
|
get_directories,
|
|
get_test_case_json_data,
|
|
inputs_path,
|
|
)
|
|
|
|
|
|
# Select the pure-Python protobuf backend rather than the C++ one: with the
# C++ backend, repeated imports of generated modules break because its symbol
# database cannot be reset. The variable is set before the google.protobuf
# import further down so that the import sees it.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")
|
|
|
|
from google.protobuf.json_format import Parse
|
|
|
|
|
|
class TestCases:
    """Partition the test-case directories under ``path`` into pytest params.

    After construction, ``all``, ``services``, ``messages`` and
    ``messages_with_json`` are lists of test-case names in sorted order.
    Names listed in ``xfail`` are wrapped in
    ``pytest.param(name, marks=pytest.mark.xfail)``.

    Parameters
    ----------
    path
        Directory whose sub-directories are the test cases.
    services : Set[str]
        Names of test cases that are services; every other case is a message.
    xfail : Set[str]
        Names of test cases expected to fail.

    Raises
    ------
    Exception
        If ``xfail`` names a test case that does not exist under ``path``.
    """

    # This is a helper, not a test class. Without this flag pytest tries to
    # collect it (its name starts with "Test") and warns about ``__init__``.
    __test__ = False

    def __init__(
        self,
        path,
        services: Set[str],
        xfail: Set[str],
    ):
        _all = set(get_directories(path)) - {"__pycache__"}
        _services = services
        # "__pycache__" has already been removed from _all above.
        _messages = _all - services
        _messages_with_json = {
            test for test in _messages if get_test_case_json_data(test)
        }

        unknown_xfail_tests = xfail - _all
        if unknown_xfail_tests:
            raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}")

        self.all = self.apply_xfail_marks(_all, xfail)
        self.services = self.apply_xfail_marks(_services, xfail)
        self.messages = self.apply_xfail_marks(_messages, xfail)
        self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail)

    @staticmethod
    def apply_xfail_marks(test_set: Set[str], xfail: Set[str]) -> List[Any]:
        """Return the names in ``test_set`` sorted, with ``xfail`` ones marked.

        Sorting matters: string hashing is randomized per process, so raw set
        iteration order differs between runs and between pytest-xdist workers.
        xdist requires every worker to collect identical parametrizations.
        """
        return [
            pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test
            for test in sorted(test_set)
        ]
|
|
|
|
|
|
# Discover every test case under ``inputs_path`` and split it into the
# parametrization lists used by the tests in this module.
test_cases = TestCases(
    path=inputs_path,
    services=test_input_config.services,
    xfail=test_input_config.xfail,
)

# Python packages holding the code generated by the aristaproto plugin and
# (presumably) by the reference protoc Python generator, respectively.
plugin_output_package = "tests.output_aristaproto"
reference_output_package = "tests.output_reference"

# Value yielded by the ``test_data`` fixture: the plugin's entry-point module,
# a zero-argument callable that imports the reference ``*_pb2`` module, and the
# test case's JSON samples.
TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
# namedtuple defines ``__new__``. Without this flag pytest tries to collect
# ``TestData`` as a test class (its name starts with "Test") and emits a
# PytestCollectionWarning.
TestData.__test__ = False
|
|
|
|
|
|
def module_has_entry_point(module: ModuleType):
    """Tell whether *module* exposes a ``Test`` message or a ``TestStub`` service."""
    return hasattr(module, "Test") or hasattr(module, "TestStub")
|
|
|
|
|
|
def list_replace_nans(items: List) -> List[Any]:
    """Replace float("nan") in a list with ``aristaproto.NAN`` (the string "NaN").

    Nested lists and dicts are processed recursively. Every other item is
    kept unchanged, so the result has the same length as ``items``.

    Parameters
    ----------
    items : List
        List to update. It is not modified; a new list is returned.

    Returns
    -------
    List[Any]
        Updated list.
    """
    result = []
    for item in items:
        if isinstance(item, list):
            result.append(list_replace_nans(item))
        elif isinstance(item, dict):
            result.append(dict_replace_nans(item))
        elif isinstance(item, float) and math.isnan(item):
            result.append(aristaproto.NAN)
        else:
            # Previously ordinary values were silently dropped. That made
            # comparisons of repeated scalar fields pass vacuously.
            result.append(item)
    return result
|
|
|
|
|
|
def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """Return a copy of a dictionary with float("nan") values turned into "NaN".

    Nested dicts and lists are cleaned recursively; keys and key order are
    preserved.

    Parameters
    ----------
    input_dict : Dict[Any, Any]
        Dictionary to clean (left unmodified).

    Returns
    -------
    Dict[Any, Any]
        New dictionary with NaN floats replaced by ``aristaproto.NAN``.
    """

    def _clean(value: Any) -> Any:
        # Dispatch on the value's type: recurse into containers, swap NaNs.
        if isinstance(value, dict):
            return dict_replace_nans(value)
        if isinstance(value, list):
            return list_replace_nans(value)
        if isinstance(value, float) and math.isnan(value):
            return aristaproto.NAN
        return value

    return {key: _clean(value) for key, value in input_dict.items()}
|
|
|
|
|
|
@pytest.fixture
def test_data(request, reset_sys_path):
    """Provide the generated and reference modules for one test case.

    ``request.param`` is the test-case name, supplied through indirect
    parametrization. The fixture yields a ``TestData`` with:

    * ``plugin_module`` - the module found by ``find_module`` inside the
      plugin output package that defines ``Test`` or ``TestStub``;
    * ``reference_module`` - a zero-argument callable that imports the
      reference ``<case>_pb2`` module on demand;
    * ``json_data`` - the result of ``get_test_case_json_data`` for the case.

    Raises an ``Exception`` when the plugin output has no entry point.
    """
    case_name = request.param

    # Put the reference output directory for this case on sys.path.
    # NOTE(review): presumably so the generated *_pb2 modules' sibling imports
    # resolve; the path is relative to the CWD. ``reset_sys_path`` is assumed
    # to undo this — confirm.
    sys.path.append(os.path.join(*reference_output_package.split("."), case_name))

    plugin_package = importlib.import_module(f"{plugin_output_package}.{case_name}")
    entry_point = find_module(plugin_package, module_has_entry_point)
    if not entry_point:
        raise Exception(
            f"Test case {case_name!r} has no entry point. "
            "Please add a proto message or service called Test and recompile."
        )

    reference_module_name = f"{reference_output_package}.{case_name}.{case_name}_pb2"

    def load_reference_module():
        # Import lazily: only tests that need the reference module pay for it.
        return importlib.import_module(reference_module_name)

    yield TestData(
        plugin_module=entry_point,
        reference_module=load_reference_module,
        json_data=get_test_case_json_data(case_name),
    )
|
|
|
|
|
|
@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
def test_message_can_instantiated(test_data: TestData) -> None:
    """Each generated message module's ``Test`` can be built with no arguments."""
    test_data.plugin_module.Test()
|
|
|
|
|
|
@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
def test_message_equality(test_data: TestData) -> None:
    """Two independently default-constructed ``Test`` messages compare equal."""
    message_cls = test_data.plugin_module.Test
    assert message_cls() == message_cls()
|
|
|
|
|
|
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_message_json(repeat, test_data: TestData) -> None:
    """JSON samples survive a from_json -> to_json round trip, ``repeat`` times.

    Samples that belong to ``test_input_config.non_symmetrical_json`` are
    skipped. NaN floats are normalised with ``dict_replace_nans`` on both
    sides before comparing, since NaN never equals itself.
    """
    plugin_module = test_data.plugin_module
    samples = test_data.json_data

    for _ in range(repeat):
        for sample in samples:
            if not sample.belongs_to(test_input_config.non_symmetrical_json):
                message: aristaproto.Message = plugin_module.Test()
                message.from_json(sample.json)

                round_tripped = json.loads(message.to_json(0))
                expected = json.loads(sample.json)
                assert dict_replace_nans(round_tripped) == dict_replace_nans(expected)
|
|
|
|
|
|
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
def test_service_can_be_instantiated(test_data: TestData) -> None:
    """Each generated service module's ``TestStub`` accepts a mock channel."""
    stub_cls = test_data.plugin_module.TestStub
    stub_cls(MockChannel())
|
|
|
|
|
|
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(repeat, test_data: TestData) -> None:
    """Plugin messages serialize to the same bytes as the reference ``_pb2`` ones.

    For every JSON sample, the reference message (parsed via
    ``google.protobuf.json_format.Parse``) provides the expected wire bytes.
    The plugin message must then:

    * emit those bytes when built from JSON;
    * emit those bytes when decoded from them;
    * compare equal across both construction paths, and have equal
      NaN-normalised ``to_dict()`` output.
    """
    plugin_module, reference_module, json_data = test_data

    for sample in json_data:
        expected_bytes = Parse(sample.json, reference_module().Test()).SerializeToString()

        for _ in range(repeat):
            from_json: aristaproto.Message = plugin_module.Test().from_json(sample.json)
            from_binary = plugin_module.Test.FromString(expected_bytes)

            # Byte-for-byte equality is not guaranteed by protobuf in general,
            # but here we aim to match the existing Python implementation and
            # do nothing tricky.
            # https://developers.google.com/protocol-buffers/docs/encoding#implications
            assert bytes(from_json) == expected_bytes
            assert bytes(from_binary) == expected_bytes

            assert from_json == from_binary
            json_dict = dict_replace_nans(from_json.to_dict())
            binary_dict = dict_replace_nans(from_binary.to_dict())
            assert json_dict == binary_dict
|