anta/tests/units/anta_tests/__init__.py

81 lines
2.7 KiB
Python
Raw Normal View History

# Copyright (c) 2023-2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Tests for anta.tests module."""
from __future__ import annotations
import asyncio
import sys
from typing import TYPE_CHECKING, Any, Literal, TypedDict
from anta.models import AntaTest
if TYPE_CHECKING:
from anta.device import AntaDevice
from anta.result_manager.models import AntaTestStatus
if sys.version_info >= (3, 11):
from typing import NotRequired
else:
from typing_extensions import NotRequired
class UnitTestResult(TypedDict):
    """Expected result of a unit test of an AntaTest subclass.

    For our AntaTest unit tests we expect only success, failure or skipped.
    Never unset nor error.
    """

    # Expected status of `AntaTest.result.result` after the test has run.
    # `AntaTestStatus` is only imported under TYPE_CHECKING; this annotation is safe at
    # runtime because `from __future__ import annotations` keeps it as an unevaluated string.
    result: Literal[AntaTestStatus.SUCCESS, AntaTestStatus.FAILURE, AntaTestStatus.SKIPPED]
    # Optional expected messages. When present, the generic `test()` function in this module
    # checks that the result has the same number of messages and that each expected string is a
    # substring of the corresponding result message. When absent, the result must have no messages.
    messages: NotRequired[list[str]]
class AntaUnitTest(TypedDict):
    """The parameters required for a unit test of an AntaTest subclass.

    Consumed by the generic `test()` function of this module.
    """

    # Optional inputs passed as `inputs=` to the AntaTest subclass constructor (None when omitted).
    inputs: NotRequired[dict[str, Any]]
    # Passed as `eos_data=` to the AntaTest subclass constructor.
    # NOTE(review): presumably one entry per command of the test, in command order — confirm in README.md.
    eos_data: list[dict[str, Any] | str]
    # Expected outcome of the test (status and optional messages).
    expected: UnitTestResult
if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    # Use the typing_extensions backport, as done for NotRequired, instead of aliasing `type`:
    # with `TypeAlias = type`, static type checkers targeting Python < 3.10 would not treat
    # AntaUnitTestDataDict as a type alias.
    from typing_extensions import TypeAlias

# Unit test data keyed by (AntaTest subclass, unit test name).
AntaUnitTestDataDict: TypeAlias = dict[tuple[type[AntaTest], str], AntaUnitTest]
def test(device: AntaDevice, data: tuple[tuple[type[AntaTest], str], AntaUnitTest]) -> None:
    """Generic test function for AntaTest subclass.

    Generate unit tests for each AntaTest subclass.

    See `tests/units/anta_tests/README.md` for more information on how to use it.

    Parameters
    ----------
    device
        Device passed as the first argument to the AntaTest subclass constructor.
    data
        Nested tuple `((anta_test_class, unit_test_name), unit_test_data)` where
        `unit_test_data` provides the optional `inputs`, the `eos_data` and the `expected` result.
    """
    # Extract the test class and test data from a nested tuple structure:
    # `data: tuple[tuple[type[AntaTest], str], AntaUnitTest]` (the unit test name is not needed here)
    (anta_test, _), test_data = data
    expected = test_data["expected"]

    # Instantiate the AntaTest subclass with the unit test inputs and EOS data
    test_instance = anta_test(device, inputs=test_data.get("inputs"), eos_data=test_data["eos_data"])

    # Run the test() coroutine to completion
    asyncio.run(test_instance.test())
    test_result = test_instance.result

    # Every assertion carries an explicit message: pytest may not rewrite asserts in a package
    # `__init__.py`, so a bare assert could fail with an empty AssertionError.

    # Assert expected result
    assert test_result.result == expected["result"], f"Expected '{expected['result']}' result, got '{test_result.result}'"

    if "messages" in expected:
        # We expect messages in test result
        assert len(test_result.messages) == len(expected["messages"]), (
            f"Expected {len(expected['messages'])} message(s) {expected['messages']}, got {len(test_result.messages)}: {test_result.messages}"
        )
        # Test will pass if the expected message is included in the test result message
        for message, expected_message in zip(test_result.messages, expected["messages"]):  # NOTE: zip(strict=True) has been added in Python 3.10
            assert expected_message in message, f"Expected message '{expected_message}' not found in '{message}'"
    else:
        # Test result should not have messages
        assert test_result.messages == [], f"Expected no messages, got {test_result.messages}"