# Copyright (c) 2023-2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Tests for anta.cli.exec.utils."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from unittest.mock import call, patch

import pytest
import respx

from anta.cli.exec.utils import clear_counters, collect_commands
from anta.models import AntaCommand
from anta.tools import safe_command

# TODO: add tests for collect_scheduled_show_tech

if TYPE_CHECKING:
    from anta.device import AntaDevice
    from anta.inventory import AntaInventory


# TODO: complete test cases
@pytest.mark.parametrize(
    ("inventory", "inventory_state", "per_device_command_output", "tags"),
    [
        pytest.param(
            {"count": 3},
            {
                "device-0": {"is_online": False},
                "device-1": {"is_online": False},
                "device-2": {"is_online": False},
            },
            {},
            None,
            id="no_connected_device",
        ),
        pytest.param(
            {"count": 3},
            {
                "device-0": {"is_online": True, "hw_model": "cEOSLab"},
                "device-1": {"is_online": True, "hw_model": "vEOS-lab"},
                "device-2": {"is_online": False},
            },
            {},
            None,
            id="cEOSLab and vEOS-lab devices",
        ),
        pytest.param(
            {"count": 3},
            {
                "device-0": {"is_online": True},
                "device-1": {"is_online": True},
                "device-2": {"is_online": False},
            },
            {"device-0": None},  # None means the command failed to collect
            None,
            id="device with error",
        ),
        pytest.param(
            {"count": 3},
            {
                # Only device-0 carries the requested tag, so it is the only device expected to be cleared.
                "device-0": {"is_online": True, "tags": {"spine"}},
                "device-1": {"is_online": True},
                "device-2": {"is_online": True},
            },
            {},
            {"spine"},
            id="tags",
        ),
    ],
    indirect=["inventory"],
)
async def test_clear_counters(
    caplog: pytest.LogCaptureFixture,
    inventory: AntaInventory,
    inventory_state: dict[str, Any],
    per_device_command_output: dict[str, Any],
    tags: set[str] | None,
) -> None:
    """Test anta.cli.exec.utils.clear_counters.

    ``connect_inventory`` is mocked to apply ``inventory_state`` to each device, and
    ``AsyncEOSDevice.collect`` is mocked to set each command's output from
    ``per_device_command_output``: an empty string by default, ``None`` to simulate a
    command that failed to collect.
    """

    async def mock_connect_inventory() -> None:
        """Mock connect_inventory coroutine: apply the per-device state instead of connecting."""
        for name, device in inventory.items():
            state = inventory_state[name]
            device.is_online = state.get("is_online", True)
            device.established = state.get("established", device.is_online)
            device.hw_model = state.get("hw_model", "dummy")
            # Only override tags when the test case defines them; otherwise keep the fixture's tags.
            device.tags = state.get("tags", device.tags)

    async def collect(self: AntaDevice, command: AntaCommand, *args: Any, **kwargs: Any) -> None:  # noqa: ARG001, ANN401
        """Mock collect coroutine."""
        command.output = per_device_command_output.get(self.name, "")

    # Need to patch the child device class
    with (
        patch("anta.device.AsyncEOSDevice.collect", side_effect=collect, autospec=True) as mocked_collect,
        patch(
            "anta.inventory.AntaInventory.connect_inventory",
            side_effect=mock_connect_inventory,
        ) as mocked_connect_inventory,
    ):
        await clear_counters(inventory, tags=tags)

    mocked_connect_inventory.assert_awaited_once()
    devices_established = inventory.get_inventory(established_only=True, tags=tags).devices
    if not devices_established:
        mocked_collect.assert_not_awaited()
        return

    # Build the expected awaits, in order, for every established (and tag-selected) device.
    calls = []
    for device in devices_established:
        expected_commands = ["clear counters"]
        # cEOSLab / vEOS-lab devices are not expected to receive the hardware drop counter command.
        if device.hw_model not in ["cEOSLab", "vEOS-lab"]:
            expected_commands.append("clear hardware counter drop")
        calls.extend(
            call(
                device,
                command=AntaCommand(
                    command=expected_command,
                    version="latest",
                    revision=None,
                    ofmt="json",
                    output=per_device_command_output.get(device.name, ""),
                    errors=[],
                ),
                collection_id=None,
            )
            for expected_command in expected_commands
        )
    mocked_collect.assert_has_awaits(calls)

    # A ``None`` output simulates a failed collection, which must be reported as an error.
    for device_name, output in per_device_command_output.items():
        if output is None:
            assert "ERROR" in caplog.text
            assert f"Could not clear counters on device {device_name}: []" in caplog.text


# TODO: test with changing root_dir, test with failing to write (OSError)
@pytest.mark.parametrize(
    ("inventory", "inventory_state", "commands", "tags"),
    [
        pytest.param(
            {"count": 1},
            {
                "device-0": {"is_online": False},
            },
            {"json_format": ["show version"]},
            None,
            id="no_connected_device",
        ),
        pytest.param(
            {"count": 3},
            {
                "device-0": {"is_online": True},
                "device-1": {"is_online": True},
                "device-2": {"is_online": False},
            },
            {"json_format": ["show version", "show ip interface brief"]},
            None,
            id="JSON commands",
        ),
        pytest.param(
            {"count": 3},
            {
                "device-0": {"is_online": True},
                "device-1": {"is_online": True},
                "device-2": {"is_online": False},
            },
            {"json_format": ["show version"], "text_format": ["show running-config", "show ip interface"]},
            None,
            id="Text commands",
        ),
        pytest.param(
            {"count": 2},
            {
                "device-0": {"is_online": True, "tags": {"spine"}},
                "device-1": {"is_online": True},
            },
            {"json_format": ["show version"]},
            {"spine"},
            id="tags",
        ),
        pytest.param(  # TODO: This test should not be there we should catch the wrong user input with pydantic.
            {"count": 1},
            {
                "device-0": {"is_online": True},
            },
            {"blah_format": ["42"]},
            None,
            id="bad-input",
        ),
        pytest.param(
            {"count": 1},
            {
                "device-0": {"is_online": True},
            },
            {"json_format": ["undefined command", "show version"]},
            None,
            id="command-failed-to-be-collected",
        ),
        pytest.param(
            {"count": 1},
            {
                "device-0": {"is_online": True},
            },
            {"json_format": ["uncaught exception"]},
            None,
            id="uncaught-exception",
        ),
    ],
    indirect=["inventory"],
)
async def test_collect_commands(
    caplog: pytest.LogCaptureFixture,
    tmp_path: Path,
    inventory: AntaInventory,
    inventory_state: dict[str, Any],
    commands: dict[str, list[str]],
    tags: set[str] | None,
) -> None:
    """Test anta.cli.exec.utils.collect_commands.

    ``connect_inventory`` is mocked to apply ``inventory_state`` to each device, and eAPI
    responses are mocked with respx, keyed on the first command of each request. The test
    then checks the files written under ``tmp_path/<device>/{json,text}`` and the log output.
    """
    caplog.set_level(logging.INFO)
    root_dir = tmp_path

    async def mock_connect_inventory() -> None:
        """Mock connect_inventory coroutine."""
        for name, device in inventory.items():
            device.is_online = inventory_state[name].get("is_online", True)
            device.established = inventory_state[name].get("established", device.is_online)
            device.hw_model = inventory_state[name].get("hw_model", "dummy")
            device.tags = inventory_state[name].get("tags", set())

    def assert_collected(output_dir: Path, collected: list[str], extension: str, device_name: str) -> None:
        """Check that ``output_dir`` holds exactly one file per collected command and that each one was logged."""
        assert output_dir.is_dir()
        assert len(list(output_dir.iterdir())) == len(collected)
        for command in collected:
            assert (output_dir / f"{safe_command(command)}.{extension}").is_file()
            assert f"Collected command '{command}' from device {device_name}" in caplog.text

    # Mocked eAPI payloads, keyed on the command sent. "uncaught exception" deliberately has no route.
    mocked_responses: dict[str, dict[str, Any]] = {
        "show version": {"result": [{"toto": 42}]},
        "show ip interface brief": {"result": [{"toto": 42}]},
        "show running-config": {"result": [{"output": "blah"}]},
        "show ip interface": {"result": [{"output": "blah"}]},
        "undefined command": {
            "error": {
                "code": 1002,
                "message": "CLI command 1 of 1 'undefined command' failed: invalid command",
                "data": [{"errors": ["Invalid input (at token 0: 'undefined')"]}],
            }
        },
    }

    # Need to patch the child device class
    # ruff: noqa: C901
    with (
        respx.mock,
        patch(
            "anta.inventory.AntaInventory.connect_inventory",
            side_effect=mock_connect_inventory,
        ) as mocked_connect_inventory,
    ):
        # Mocking responses from devices
        for mocked_command, payload in mocked_responses.items():
            respx.post(path="/command-api", headers={"Content-Type": "application/json-rpc"}, json__params__cmds__0__cmd=mocked_command).respond(
                json=payload
            )
        await collect_commands(inventory, commands, root_dir, tags=tags)

    mocked_connect_inventory.assert_awaited_once()
    devices_established = inventory.get_inventory(established_only=True, tags=tags or None).devices
    if not devices_established:
        assert "INFO" in caplog.text
        assert "No online device found. Exiting" in caplog.text
        return

    # Commands expected to fail: they must be reported and must not produce an output file.
    # The expected lists are derived without mutating ``commands``, which is the shared parametrize object.
    failing_commands = {"undefined command", "uncaught exception"}
    expected_json = [command for command in commands.get("json_format", []) if command not in failing_commands]
    expected_text = commands.get("text_format", [])

    for device in devices_established:
        # Verify tags selection: a selected device must carry at least one requested tag.
        if tags:
            assert device.tags.intersection(tags)
        if "json_format" in commands:
            json_commands = commands["json_format"]
            # Handle undefined command
            if "undefined command" in json_commands:
                assert "ERROR" in caplog.text
                assert f"Command 'undefined command' failed on {device.name}: Invalid input (at token 0: 'undefined')" in caplog.text
                # Verify we don't claim it was collected
                assert f"Collected command 'undefined command' from device {device.name}" not in caplog.text
            # Handle uncaught exception
            if "uncaught exception" in json_commands:
                assert "ERROR" in caplog.text
                assert "Error when collecting commands: " in caplog.text
                # Verify we don't claim it was collected
                assert f"Collected command 'uncaught exception' from device {device.name}" not in caplog.text
            assert_collected(root_dir / device.name / "json", expected_json, "json", device.name)
        if "text_format" in commands:
            assert_collected(root_dir / device.name / "text", expected_text, "log", device.name)