anta/anta/device.py
Daniel Baumann a1777afd4b
Adding upstream version 0.15.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-05 11:39:42 +01:00

443 lines
19 KiB
Python

# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""ANTA Device Abstraction Module."""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Literal
import asyncssh
import httpcore
from aiocache import Cache
from aiocache.plugins import HitMissRatioPlugin
from asyncssh import SSHClientConnection, SSHClientConnectionOptions
from httpx import ConnectError, HTTPError, TimeoutException
import asynceapi
from anta import __DEBUG__
from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaCommand
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
# The default SSH client keypairs are loaded a single time, at import, and shared by every device:
# loading them repeatedly is slow since cryptography 37.0, see
# https://github.com/pyca/cryptography/issues/7236#issuecomment-1131908472
CLIENT_KEYS = asyncssh.public_key.load_default_keypairs()

# Module logger, named after this module (`__name__`).
logger = logging.getLogger(__name__)
class AntaDevice(ABC):
    """Abstract class representing a device in ANTA.

    An implementation of this class must override the abstract coroutines `_collect()` and
    `refresh()`, as well as the abstract `_keys` property used for hashing and equality.

    Attributes
    ----------
    name: Device name
    is_online: True if the device IP is reachable and a port can be open.
    established: True if remote command execution succeeds.
    hw_model: Hardware model of the device.
    tags: Tags for this device.
    cache: In-memory cache from aiocache library for this device (None if cache is disabled).
    cache_locks: Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled.

    """

    def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bool = False) -> None:
        """Initialize an AntaDevice.

        Args:
        ----
        name: Device name.
        tags: Tags for this device. The given collection is copied and never modified.
        disable_cache: Disable caching for all commands for this device.

        """
        self.name: str = name
        self.hw_model: str | None = None
        # Copy the caller's tags before adding the device name: mutating the argument in place
        # would leak this device's name into every other device built from the same set object.
        self.tags: set[str] = set(tags) if tags is not None else set()
        # A device always has its own name as tag
        self.tags.add(self.name)
        self.is_online: bool = False
        self.established: bool = False
        self.cache: Cache | None = None
        self.cache_locks: defaultdict[str, asyncio.Lock] | None = None

        # Initialize cache if not disabled
        if not disable_cache:
            self._init_cache()

    @property
    @abstractmethod
    def _keys(self) -> tuple[Any, ...]:
        """Read-only property to implement hashing and equality for AntaDevice classes."""

    def __eq__(self, other: object) -> bool:
        """Implement equality for AntaDevice objects.

        Two devices are equal only when `other` is an instance of this device's class
        and both expose the same `_keys`.
        """
        return self._keys == other._keys if isinstance(other, self.__class__) else False

    def __hash__(self) -> int:
        """Implement hashing for AntaDevice objects, consistently with `__eq__` (based on `_keys`)."""
        return hash(self._keys)

    def _init_cache(self) -> None:
        """Initialize cache for the device, can be overridden by subclasses to manipulate how it works.

        Creates an in-memory aiocache `Cache` (ttl=60, namespaced by the device name, with a
        HitMissRatioPlugin for statistics) and a registry creating one `asyncio.Lock` per key on demand.
        """
        self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()])
        self.cache_locks = defaultdict(asyncio.Lock)

    @property
    def cache_statistics(self) -> dict[str, Any] | None:
        """Return the device cache statistics for logging purposes.

        Returns
        -------
        A dictionary with `total_commands_sent`, `cache_hits` and `cache_hit_ratio` (a percentage
        string with two decimals), or None when the cache is disabled.

        """
        if self.cache is not None:
            # `hit_miss_ratio` may be missing on the cache (hence the zeroed default); it is expected to be
            # populated by HitMissRatioPlugin. Cache is a proxy class, so static analysis cannot see the member:
            # https://github.com/pylint-dev/pylint/issues/7258
            stats = getattr(self.cache, "hit_miss_ratio", {"total": 0, "hits": 0, "hit_ratio": 0})
            return {"total_commands_sent": stats["total"], "cache_hits": stats["hits"], "cache_hit_ratio": f"{stats['hit_ratio'] * 100:.2f}%"}
        return None

    def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
        """Implement Rich Repr Protocol.

        https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol.
        """
        yield "name", self.name
        yield "tags", self.tags
        yield "hw_model", self.hw_model
        yield "is_online", self.is_online
        yield "established", self.established
        yield "disable_cache", self.cache is None

    @abstractmethod
    async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
        """Collect device command output.

        This abstract coroutine can be used to implement any command collection method
        for a device in ANTA.

        The `_collect()` implementation needs to populate the `output` attribute
        of the `AntaCommand` object passed as argument.

        If a failure occurs, the `_collect()` implementation is expected to catch the
        exception and implement proper logging, the `output` attribute of the
        `AntaCommand` object passed as argument would be `None` in this case.

        Args:
        ----
        command: The command to collect.
        collection_id: An identifier used to build the eAPI request ID.

        """

    async def collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
        """Collect the output for a specified command.

        When caching is activated on both the device and the command,
        this method prioritizes retrieving the output from the cache. In cases where the output isn't cached yet,
        it will be freshly collected and then stored in the cache for future access.
        Only successful collections (non-None output) are stored in the cache.
        The method employs asynchronous locks based on the command's UID to guarantee exclusive access to the cache.

        When caching is NOT enabled, either at the device or command level, the method directly collects the output
        via the private `_collect` method without interacting with the cache.

        Args:
        ----
        command: The command to collect.
        collection_id: An identifier used to build the eAPI request ID.

        """
        if self.cache is not None and self.cache_locks is not None and command.use_cache:
            # Serialize access per command UID: concurrent collections of the same command wait for the
            # first one and then reuse its cached output instead of querying the device again.
            async with self.cache_locks[command.uid]:
                # Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
                # https://github.com/pylint-dev/pylint/issues/7258
                cached_output = await self.cache.get(command.uid)  # pylint: disable=no-member

                if cached_output is not None:
                    logger.debug("Cache hit for %s on %s", command.command, self.name)
                    command.output = cached_output
                else:
                    await self._collect(command=command, collection_id=collection_id)
                    # A failed `_collect()` leaves `output` as None; storing it would be useless since
                    # `get()` above cannot distinguish a cached None from a cache miss.
                    if command.output is not None:
                        await self.cache.set(command.uid, command.output)  # pylint: disable=no-member
        else:
            await self._collect(command=command, collection_id=collection_id)

    async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None:
        """Collect multiple commands concurrently.

        Args:
        ----
        commands: The commands to collect.
        collection_id: An identifier used to build the eAPI request ID.

        """
        await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands))

    @abstractmethod
    async def refresh(self) -> None:
        """Update attributes of an AntaDevice instance.

        This coroutine must update the following attributes of AntaDevice:
        - `is_online`: When the device IP is reachable and a port can be open
        - `established`: When a command execution succeeds
        - `hw_model`: The hardware model of the device
        """

    async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None:
        """Copy files to and from the device, usually through SCP.

        It is not mandatory to implement this for a valid AntaDevice subclass.

        Args:
        ----
        sources: List of files to copy to or from the device.
        destination: Local or remote destination when copying the files. Can be a folder.
        direction: Defines if this coroutine copies files to or from the device.

        Raises
        ------
        NotImplementedError: Always, in this base implementation.

        """
        _ = (sources, destination, direction)
        msg = f"copy() method has not been implemented in {self.__class__.__name__} definition"
        raise NotImplementedError(msg)
class AsyncEOSDevice(AntaDevice):
    """Implementation of AntaDevice for EOS using aio-eapi.

    Attributes
    ----------
    name: Device name
    is_online: True if the device IP is reachable and a port can be open
    established: True if remote command execution succeeds
    hw_model: Hardware model of the device
    tags: Tags for this device

    """

    # pylint: disable=R0913
    def __init__(
        self,
        host: str,
        username: str,
        password: str,
        name: str | None = None,
        enable_password: str | None = None,
        port: int | None = None,
        ssh_port: int | None = 22,
        tags: set[str] | None = None,
        timeout: float | None = None,
        proto: Literal["http", "https"] = "https",
        *,
        enable: bool = False,
        insecure: bool = False,
        disable_cache: bool = False,
    ) -> None:
        """Instantiate an AsyncEOSDevice.

        Args:
        ----
        host: Device FQDN or IP.
        username: Username to connect to eAPI and SSH.
        password: Password to connect to eAPI and SSH.
        name: Device name. Defaults to `host`, suffixed with `:<port>` when `port` is set.
        enable: Collect commands using privileged mode.
        enable_password: Password used to gain privileged access on EOS.
        port: eAPI port. Defaults to 80 if proto is 'http' or 443 if proto is 'https'.
        ssh_port: SSH port.
        tags: Tags for this device.
        timeout: Timeout value in seconds for outgoing API calls.
        insecure: Disable SSH Host Key validation.
        proto: eAPI protocol. Value can be 'http' or 'https'.
        disable_cache: Disable caching for all commands for this device.

        Raises
        ------
        ValueError: If `host`, `username` or `password` is None (the error is also logged).

        """
        if host is None:
            message = "'host' is required to create an AsyncEOSDevice"
            logger.error(message)
            raise ValueError(message)
        if name is None:
            name = f"{host}{f':{port}' if port else ''}"
        super().__init__(name, tags, disable_cache=disable_cache)
        if username is None:
            message = f"'username' is required to instantiate device '{self.name}'"
            logger.error(message)
            raise ValueError(message)
        if password is None:
            message = f"'password' is required to instantiate device '{self.name}'"
            logger.error(message)
            raise ValueError(message)
        self.enable = enable
        self._enable_password = enable_password
        self._session: asynceapi.Device = asynceapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout)
        ssh_params: dict[str, Any] = {}
        if insecure:
            # known_hosts=None disables host key validation in asyncssh
            ssh_params["known_hosts"] = None
        self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions(
            host=host, port=ssh_port, username=username, password=password, client_keys=CLIENT_KEYS, **ssh_params
        )

    def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
        """Implement Rich Repr Protocol.

        https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol.
        """
        yield from super().__rich_repr__()
        yield ("host", self._session.host)
        yield ("eapi_port", self._session.port)
        yield ("username", self._ssh_opts.username)
        yield ("enable", self.enable)
        yield ("insecure", self._ssh_opts.known_hosts is None)
        if __DEBUG__:
            _ssh_opts = vars(self._ssh_opts).copy()
            removed_pw = "<removed>"
            _ssh_opts["password"] = removed_pw
            # `vars()` returns the live attribute dict and `.copy()` is shallow: mask the password in a
            # copy of the nested `kwargs` dict. Writing into `_ssh_opts["kwargs"]` directly would overwrite
            # the password stored in `self._ssh_opts.kwargs`, which asyncssh may reuse when `copy()`
            # opens a connection from these options.
            _ssh_opts["kwargs"] = {**_ssh_opts["kwargs"], "password": removed_pw}
            yield ("_session", vars(self._session))
            yield ("_ssh_opts", _ssh_opts)

    @property
    def _keys(self) -> tuple[Any, ...]:
        """Two AsyncEOSDevice objects are equal if the hostname and the port are the same.

        This covers the use case of port forwarding when the host is localhost and the devices have different ports.
        """
        return (self._session.host, self._session.port)

    async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:  # noqa: C901  function is too complex - because of many required except blocks #pylint: disable=line-too-long
        """Collect device command output from EOS using aio-eapi.

        Supports outformat `json` and `text` as output structure.
        Gain privileged access using the `enable_password` attribute
        of the `AntaDevice` instance if populated.

        Errors are not raised: eAPI command errors, timeouts, connection/OS errors and other HTTP errors
        are logged and recorded in `command.errors`, and `command.output` is left untouched.

        Args:
        ----
        command: The command to collect.
        collection_id: An identifier used to build the eAPI request ID.

        """
        commands: list[dict[str, str | int]] = []
        if self.enable and self._enable_password is not None:
            commands.append(
                {
                    "cmd": "enable",
                    "input": str(self._enable_password),
                },
            )
        elif self.enable:
            # No password
            commands.append({"cmd": "enable"})
        commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
        try:
            response: list[dict[str, Any] | str] = await self._session.cli(
                commands=commands,
                ofmt=command.ofmt,
                version=command.version,
                req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
            )  # type: ignore[assignment] # multiple commands returns a list
            # Do not keep response of 'enable' command
            command.output = response[-1]
        except asynceapi.EapiCommandError as e:
            # This block catches exceptions related to EOS issuing an error.
            command.errors = e.errors
            if command.requires_privileges:
                logger.error(
                    "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", command.command, self.name
                )
            if command.supported:
                logger.error("Command '%s' failed on %s: %s", command.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors)
            else:
                logger.debug("Command '%s' is not supported on '%s' (%s)", command.command, self.name, self.hw_model)
        except TimeoutException as e:
            # This block catches Timeout exceptions.
            command.errors = [exc_to_str(e)]
            timeouts = self._session.timeout.as_dict()
            logger.error(
                "%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
                exc_to_str(e),
                self.name,
                timeouts["connect"],
                timeouts["read"],
                timeouts["write"],
                timeouts["pool"],
            )
        except (ConnectError, OSError) as e:
            # This block catches OSError and socket issues related exceptions.
            command.errors = [exc_to_str(e)]
            # Dig out the underlying OSError, either wrapped by httpx/httpcore or raised directly.
            if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError):  # pylint: disable=no-member
                if isinstance(os_error.__cause__, OSError):
                    os_error = os_error.__cause__
                logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
            else:
                anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
        except HTTPError as e:
            # This block catches most of the httpx Exceptions and logs a general message.
            command.errors = [exc_to_str(e)]
            anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
        logger.debug("%s: %s", self.name, command)

    async def refresh(self) -> None:
        """Update attributes of an AsyncEOSDevice instance.

        This coroutine must update the following attributes of AsyncEOSDevice:
        - is_online: When a device IP is reachable and a port can be open
        - established: When a command execution succeeds
        - hw_model: The hardware model of the device

        `established` reflects the outcome of this refresh only: it is False whenever `show version`
        cannot be collected, even if `hw_model` still holds a value from an earlier refresh.
        """
        logger.debug("Refreshing device %s", self.name)
        self.is_online = await self._session.check_connection()
        established = False
        if self.is_online:
            show_version = AntaCommand(command="show version")
            await self._collect(show_version)
            if not show_version.collected:
                logger.warning("Cannot get hardware information from device %s", self.name)
            else:
                self.hw_model = show_version.json_output.get("modelName", None)
                if self.hw_model is None:
                    logger.critical("Cannot parse 'show version' returned by device %s", self.name)
                # The command succeeded; the device is established only with a usable (non-empty) model.
                established = bool(self.hw_model)
        else:
            logger.warning("Could not connect to device %s: cannot open eAPI port", self.name)
        self.established = established

    async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None:
        """Copy files to and from the device using asyncssh.scp().

        An invalid `direction` is logged as critical and nothing is copied.

        Args:
        ----
        sources: List of files to copy to or from the device.
        destination: Local or remote destination when copying the files. Can be a folder.
        direction: Defines if this coroutine copies files to or from the device.

        """
        async with asyncssh.connect(
            host=self._ssh_opts.host,
            port=self._ssh_opts.port,
            tunnel=self._ssh_opts.tunnel,
            family=self._ssh_opts.family,
            local_addr=self._ssh_opts.local_addr,
            options=self._ssh_opts,
        ) as conn:
            src: list[tuple[SSHClientConnection, Path]] | list[Path]
            dst: tuple[SSHClientConnection, Path] | Path
            if direction == "from":
                # Remote sources are (connection, path) pairs for asyncssh.scp()
                src = [(conn, file) for file in sources]
                dst = destination
                for file in sources:
                    message = f"Copying '{file}' from device {self.name} to '{destination}' locally"
                    logger.info(message)

            elif direction == "to":
                src = sources
                dst = conn, destination
                for file in src:
                    message = f"Copying '{file}' to device {self.name} to '{destination}' remotely"
                    logger.info(message)

            else:
                logger.critical("'direction' argument to copy() function is invalid: %s", direction)

                return
            await asyncssh.scp(src, dst)