anta/anta/tools.py
Daniel Baumann 3254dea030
Merging upstream version 1.4.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-05-15 09:34:30 +02:00

417 lines
12 KiB
Python

# Copyright (c) 2023-2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Common functions used in ANTA tests."""
from __future__ import annotations
import cProfile
import os
import pstats
import re
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from anta.constants import ACRONYM_CATEGORIES
from anta.custom_types import REGEXP_PATH_MARKERS
from anta.logger import format_td
if TYPE_CHECKING:
import sys
from logging import Logger
from types import TracebackType
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
# Generic callable type used by decorators in this module (e.g. `cprofile`, whose factory is
# typed `Callable[[F], F]`) so static type checkers see the decorated function keep its own type.
F = TypeVar("F", bound=Callable[..., Any])
def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, Any]) -> str:
    """Build a human-readable report of every expected entry that the actual output does not satisfy.

    Each key of `expected_output` is looked up in `actual_output`. Matching entries contribute
    nothing. A missing (or None-valued) entry and a differing entry each contribute one line.
    Every line is prefixed with a newline character.

    Parameters
    ----------
    expected_output
        Mapping of element name to the value the test expects.
    actual_output
        Mapping of element name to the value actually observed.

    Returns
    -------
    str
        The concatenated failure lines, or an empty string when everything matches.
    """

    def _describe(element: Any, expected: Any) -> str:
        found = actual_output.get(element)
        # Equality is checked first so an expected None paired with a missing key counts as a match.
        if found == expected:
            return ""
        if found is None:
            return f"\nExpected `{expected}` as the {element}, but it was not found in the actual output."
        return f"\nExpected `{expected}` as the {element}, but found `{found}` instead."

    return "".join(_describe(element, expected) for element, expected in expected_output.items())
def custom_division(numerator: float, denominator: float) -> int | float:
    """Divide two numbers, collapsing whole-number quotients to ``int``.

    Parameters
    ----------
    numerator
        The dividend.
    denominator
        The divisor. A zero divisor lets the ``ZeroDivisionError`` from ``/`` propagate.

    Returns
    -------
    int | float
        ``int`` when the true-division quotient has no fractional part, otherwise the ``float`` quotient.
    """
    quotient = numerator / denominator
    if not quotient.is_integer():
        return quotient
    return int(quotient)
def get_dict_superset(
    list_of_dicts: list[dict[Any, Any]],
    input_dict: dict[Any, Any],
    default: Any | None = None,
    var_name: str | None = None,
    custom_error_msg: str | None = None,
    *,
    required: bool = False,
) -> Any:
    """Return the first dictionary in a list that contains every key/value pair of `input_dict`.

    The search stops at the first match. Non-dict list items are skipped. An input that is not
    a non-empty list, or an `input_dict` that is not a non-empty dict, is treated as "no match".

    Parameters
    ----------
    list_of_dicts
        Candidate dictionaries to search.
    input_dict
        Key/value pairs that a candidate must all contain.
    default
        Value returned when nothing matches and `required` is False.
    var_name
        Name used in the default error message.
    custom_error_msg
        Message used instead of the default one when raising.
    required
        Raise instead of returning `default` when nothing matches.

    Returns
    -------
    Any
        The matching dictionary, or `default`.

    Raises
    ------
    ValueError
        If nothing matches and `required` is True.
    """

    def _fallback() -> Any:
        # Shared "no match" outcome for both invalid input and an exhausted search.
        if required:
            raise ValueError(custom_error_msg or f"{var_name} not found in the provided list.")
        return default

    searchable = (
        isinstance(list_of_dicts, list)
        and bool(list_of_dicts)
        and isinstance(input_dict, dict)
        and bool(input_dict)
    )
    if not searchable:
        return _fallback()

    wanted = input_dict.items()
    for candidate in list_of_dicts:
        # Items-view `<=` is a subset test: every (key, value) of input_dict must appear in candidate.
        if isinstance(candidate, dict) and wanted <= candidate.items():
            return candidate
    return _fallback()
def get_value(
    dictionary: dict[Any, Any],
    key: str,
    default: Any | None = None,
    org_key: str | None = None,
    separator: str = ".",
    *,
    required: bool = False,
) -> Any:
    """Get a value from a dictionary or nested dictionaries.

    Key supports dot-notation like "foo.bar" to do deeper lookups.
    Returns the supplied default value or None if the key is not found and required is False.
    A key whose value is None counts as not found. A path that reaches a non-mapping value
    (e.g. a string or list) before its last segment also counts as not found.

    Parameters
    ----------
    dictionary : dict
        Dictionary to get key from
    key : str
        Dictionary Key - supporting dot-notation for nested dictionaries
    default : any
        Default value returned if the key is not found
    required : bool
        Fail if the key is not found
    org_key : str
        Internal variable used for raising exception with the full key name even when called recursively
    separator: str
        String to use as the separator parameter in the split function. Useful in cases when the key
        can contain variables with "." inside (e.g. hostnames)

    Returns
    -------
    any
        Value or default value

    Raises
    ------
    ValueError
        If the key is not found and required == True. The message is the full original key.
    """
    if org_key is None:
        org_key = key
    keys = key.split(separator)
    value = dictionary.get(keys[0])
    has_remaining_path = len(keys) > 1
    # A leaf value (str, int, list, ...) cannot contain the remaining nested keys. Treat that as
    # "not found" instead of letting the recursive call fail with AttributeError on `.get`.
    if value is None or (has_remaining_path and not callable(getattr(value, "get", None))):
        if required:
            raise ValueError(org_key)
        return default
    if has_remaining_path:
        return get_value(value, separator.join(keys[1:]), default=default, required=required, org_key=org_key, separator=separator)
    return value
def get_item(
    list_of_dicts: list[dict[Any, Any]],
    key: Any,
    value: Any,
    default: Any | None = None,
    var_name: str | None = None,
    custom_error_msg: str | None = None,
    *,
    required: bool = False,
    case_sensitive: bool = False,
) -> Any:
    """Return the first dictionary in a list whose `key` entry equals `value`.

    Non-dict list items are skipped. A list that is not a non-empty list, or a `key`/`value`
    of None, is treated as "no match". When both the searched value and a candidate's value
    are strings and `case_sensitive` is False, they are compared with ``str.casefold``.

    Parameters
    ----------
    list_of_dicts
        Candidate dictionaries to search.
    key
        Dictionary key to compare on.
    value
        Value the key must hold.
    default
        Value returned when nothing matches and `required` is not True.
    var_name
        Text of the raised error when `custom_error_msg` is not given; defaults to `key`.
    custom_error_msg
        Text of the raised error, taking precedence over `var_name`.
    required
        Raise when nothing matches. Only the literal ``True`` enables this.
    case_sensitive
        Disable case-insensitive matching of string values.

    Returns
    -------
    Any
        The matching dictionary, or `default`.

    Raises
    ------
    ValueError
        If nothing matches and `required` is True.
    """
    error_label = key if var_name is None else var_name

    inputs_usable = isinstance(list_of_dicts, list) and len(list_of_dicts) > 0 and value is not None and key is not None
    if inputs_usable:
        fold_case = not case_sensitive and isinstance(value, str)
        target = value.casefold() if fold_case else value
        for candidate in list_of_dicts:
            if not isinstance(candidate, dict):
                continue
            found = candidate.get(key)
            if fold_case and isinstance(found, str):
                is_match = found.casefold() == target
            else:
                is_match = found == value
            if is_match:
                return candidate

    # Identity check on purpose: only an explicit True makes a miss fatal.
    if required is True:
        raise ValueError(custom_error_msg or error_label)
    return default
class Catchtime:
    """Context manager that measures the wall-clock duration of its ``with`` body.

    On exit, the raw duration (seconds, from ``perf_counter``) is stored in ``raw_time`` and a
    formatted version in ``time``. When both a logger and a message are given, a debug line is
    emitted on entry and on exit.
    """

    # perf_counter() reading taken on entry.
    start: float
    # Elapsed seconds between entry and exit.
    raw_time: float
    # `raw_time` rendered by format_td with precision 3.
    time: str

    def __init__(self, logger: Logger | None = None, message: str | None = None) -> None:
        self.logger = logger
        self.message = message

    def __enter__(self) -> Self:
        """Start the timer and, when configured, log the starting message."""
        self.start = perf_counter()
        logger, message = self.logger, self.message
        if logger and message:
            logger.debug("%s ...", message)
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None:
        """Stop the timer, record the durations and, when configured, log the completion message."""
        elapsed = perf_counter() - self.start
        self.raw_time = elapsed
        self.time = format_td(elapsed, 3)
        logger, message = self.logger, self.message
        if logger and message:
            logger.debug("%s completed in: %s.", message, self.time)
def cprofile(sort_by: str = "cumtime") -> Callable[[F], F]:
    """Build a decorator that optionally profiles an async function with cProfile.

    The ``ANTA_CPROFILE`` environment variable is read on every call. When it is set, the call
    runs under a ``cProfile.Profile``, and the stats, sorted by `sort_by`, are dumped to the file
    it names. When it is unset, the function is simply awaited.

    Parameters
    ----------
    sort_by
        The pstats sort key for the dumped results. Default is 'cumtime'.

    Returns
    -------
    Callable
        A decorator that wraps an async function with conditional profiling.
    """

    def decorator(func: F) -> F:
        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Await the wrapped coroutine function, profiling it when `ANTA_CPROFILE` is set.

            Parameters
            ----------
            *args
                Positional arguments forwarded to the wrapped function.
            **kwargs
                Keyword arguments forwarded to the wrapped function.

            Returns
            -------
            Any
                Whatever the wrapped function returns.
            """
            dump_path = os.environ.get("ANTA_CPROFILE")
            if dump_path is None:
                return await func(*args, **kwargs)

            profiler = cProfile.Profile()
            profiler.enable()
            try:
                return await func(*args, **kwargs)
            finally:
                # Stats are dumped even when the wrapped call raises.
                profiler.disable()
                pstats.Stats(profiler).sort_stats(sort_by).dump_stats(dump_path)

        return cast("F", wrapper)

    return decorator
def safe_command(command: str) -> str:
    """Replace every match of the ``REGEXP_PATH_MARKERS`` pattern in a command with an underscore.

    Parameters
    ----------
    command
        The command string to sanitize.

    Returns
    -------
    str
        The command with each pattern match replaced by ``_``.
    """
    marker_pattern = re.compile(rf"{REGEXP_PATH_MARKERS}")
    return marker_pattern.sub("_", command)
def convert_categories(categories: list[str]) -> list[str]:
    """Normalize category names for reports.

    Each category is split on whitespace. A word whose lowercase form is in
    ``ACRONYM_CATEGORIES`` is upper-cased; every other word is title-cased. The words
    are then re-joined with single spaces.

    Parameters
    ----------
    categories
        A list of categories.

    Returns
    -------
    list[str]
        The converted categories, in input order.

    Raises
    ------
    TypeError
        If `categories` is not a list.
    """
    if not isinstance(categories, list):
        msg = f"Wrong input type '{type(categories)}' for convert_categories."
        raise TypeError(msg)

    def _convert_word(word: str) -> str:
        return word.upper() if word.lower() in ACRONYM_CATEGORIES else word.title()

    return [" ".join(map(_convert_word, category.split())) for category in categories]
def format_data(data: dict[str, Any]) -> str:
    """Format a data dictionary for logging purposes.

    Each key is passed through ``str.capitalize``, which upper-cases the first character and
    lower-cases the rest. Each key is paired with the string form of its value, and the pairs
    are joined with ", ". Keys keep the dictionary's iteration order.

    Parameters
    ----------
    data
        A dictionary containing the data to format. Values of any type are rendered with
        their string representation.

    Returns
    -------
    str
        The formatted data, or an empty string when `data` is empty.

    Example
    -------
    >>> format_data({"advertised": True, "received": True, "enabled": True})
    'Advertised: True, Received: True, Enabled: True'
    """
    return ", ".join(f"{k.capitalize()}: {v}" for k, v in data.items())