#!/usr/bin/env python
import asyncio
import os
import platform
import shutil
import sys
from pathlib import Path
from typing import Set

from tests.util import (
    get_directories,
    inputs_path,
    output_path_aristaproto,
    output_path_aristaproto_pydantic,
    output_path_reference,
    protoc,
)


# Protobuf has to run its pure-python implementation rather than the C++ one:
# with the C++ backend the symbol database cannot be reset properly, so
# importing the generated modules breaks things.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")


def clear_directory(dir_path: Path):
    """Delete every entry inside *dir_path*, keeping the directory itself.

    Real sub-directories are removed recursively. Symlinks (including ones
    that point at directories) are unlinked rather than followed, so nothing
    outside *dir_path* is touched.
    """
    for entry in dir_path.iterdir():
        # is_dir() follows symlinks, but shutil.rmtree() refuses to operate on
        # a symlink, so only recurse into genuine directories; links are
        # removed as links.
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()


async def generate(whitelist: Set[str], verbose: bool):
    """Generate reference and plugin output for every selected test case.

    :param whitelist: Test-case names and/or test-case directories (relative
        or absolute) to restrict generation to. An empty set selects every
        test case found under ``inputs_path``.
    :param verbose: Forwarded to ``generate_test_case_output``.

    If any selected test case fails to generate, the failing names are listed
    on stderr and the process exits with status 1.
    """
    test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}

    path_whitelist = set()
    name_whitelist = set()
    for item in whitelist:
        if item in test_case_names:
            name_whitelist.add(item)
        else:
            # Resolve so that both relative (e.g. ``inputs/bool``) and absolute
            # paths compare equal to the resolved input directory checked below.
            path_whitelist.add(str(Path(item).resolve()))

    # Names and coroutines are kept in lock-step so every result can be
    # matched back to the test case that actually produced it.
    selected_names = []
    generation_tasks = []
    for test_case_name in sorted(test_case_names):
        test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
        if (
            whitelist
            and str(test_case_input_path) not in path_whitelist
            and test_case_name not in name_whitelist
        ):
            continue
        selected_names.append(test_case_name)
        generation_tasks.append(
            generate_test_case_output(test_case_input_path, test_case_name, verbose)
        )

    # asyncio.gather() returns results in the same order as its arguments.
    results = await asyncio.gather(*generation_tasks)
    failed_test_cases = [
        name for name, result in zip(selected_names, results) if result != 0
    ]

    if failed_test_cases:
        sys.stderr.write(
            "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
        )
        for failed_test_case in failed_test_cases:
            sys.stderr.write(f"- {failed_test_case}\n")

        sys.exit(1)


def _echo_bytes(stream, header: str, data: bytes) -> None:
    """Write *header*, then the raw *data* bytes, to *stream* (no-op if empty).

    *stream* must be a text stream exposing ``.buffer`` (``sys.stdout`` or
    ``sys.stderr``).
    """
    if not data:
        return
    print(header, file=stream)
    # print() goes through the text layer, which buffers separately from the
    # underlying binary buffer. Flush it first so the header really precedes
    # the raw bytes; otherwise the order can flip when output is piped.
    stream.flush()
    stream.buffer.write(data)
    stream.buffer.flush()


def _report_protoc_result(
    label: str,
    header: str,
    test_case_name: str,
    out: bytes,
    err: bytes,
    code: int,
    verbose: bool,
) -> None:
    """Print the status line for one protoc run and, if *verbose*, its output."""
    if code == 0:
        print(f"\033[31;1;4mGenerated {label} output for {test_case_name!r}\033[0m")
    else:
        print(
            f"\033[31;1;4mFailed to generate {label} output for {test_case_name!r}\033[0m"
        )

    if verbose:
        _echo_bytes(sys.stdout, f"{header} stdout:", out)
        _echo_bytes(sys.stderr, f"{header} stderr:", err)


async def generate_test_case_output(
    test_case_input_path: Path, test_case_name: str, verbose: bool
) -> int:
    """Run the three protoc generations for one test case and report them.

    The three runs are awaited concurrently:
    - reference output (``protoc(..., True)``) into a per-test-case directory
      under ``output_path_reference``;
    - plugin output into the shared ``output_path_aristaproto`` directory;
    - pydantic-compatible plugin output into the shared
      ``output_path_aristaproto_pydantic`` directory.

    Returns 0 if all three return codes are 0. Otherwise returns the nonzero
    code with the largest magnitude, which equals ``max()`` for non-negative
    codes.
    """
    test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
    # NOTE(review): the two plugin output paths are shared by every test case
    # (they are not joined with test_case_name).
    test_case_output_path_aristaproto = output_path_aristaproto
    test_case_output_path_aristaproto_pyd = output_path_aristaproto_pydantic

    os.makedirs(test_case_output_path_reference, exist_ok=True)
    os.makedirs(test_case_output_path_aristaproto, exist_ok=True)
    os.makedirs(test_case_output_path_aristaproto_pyd, exist_ok=True)

    # NOTE(review): unlike the other two, the pydantic output directory is not
    # cleared here; confirm whether stale files there are intended.
    clear_directory(test_case_output_path_reference)
    clear_directory(test_case_output_path_aristaproto)

    (
        (ref_out, ref_err, ref_code),
        (plg_out, plg_err, plg_code),
        (plg_out_pyd, plg_err_pyd, plg_code_pyd),
    ) = await asyncio.gather(
        protoc(test_case_input_path, test_case_output_path_reference, True),
        protoc(test_case_input_path, test_case_output_path_aristaproto, False),
        protoc(
            test_case_input_path, test_case_output_path_aristaproto_pyd, False, True
        ),
    )

    _report_protoc_result(
        "reference", "Reference", test_case_name, ref_out, ref_err, ref_code, verbose
    )
    _report_protoc_result(
        "plugin", "Plugin", test_case_name, plg_out, plg_err, plg_code, verbose
    )
    _report_protoc_result(
        "plugin (pydantic compatible)",
        "Plugin (pydantic compatible)",
        test_case_name,
        plg_out_pyd,
        plg_err_pyd,
        plg_code_pyd,
        verbose,
    )

    # Pick by magnitude so that a negative return code (e.g. a POSIX process
    # killed by a signal) is not masked by a 0 from another run.
    return max((ref_code, plg_code, plg_code_pyd), key=abs)


# Usage text printed by main() when -h or --help is passed.
HELP = "\n".join(
    (
        "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
        "Generate python classes for standard tests.",
        "",
        "-h, --help     Show this help message and exit.",
        "-v             Print the stdout/stderr of every protoc run (must be the first argument).",
        "",
        "DIRECTORIES    One or more relative or absolute directories of test-cases to generate classes for.",
        "               python generate.py inputs/bool inputs/double inputs/enum",
        "",
        "NAMES          One or more test-case names to generate classes for.",
        "               python generate.py bool double enum",
    )
)


def main():
    """Command-line entry point for test-case generation.

    ``-h`` or ``--help`` anywhere on the command line prints ``HELP`` and
    returns. ``-v`` is honoured only as the first argument and enables
    verbose output. Every other argument is a test-case name or directory
    that restricts which test cases are generated.
    """
    args = sys.argv
    if "-h" in args or "--help" in args:
        print(HELP)
        return

    verbose = args[1:2] == ["-v"]
    selection = args[2:] if verbose else args[1:]
    whitelist = set(selection)

    if platform.system() == "Windows":
        # asyncio subprocess support on Windows requires the proactor event
        # loop; presumably protoc() spawns a subprocess per invocation.
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    asyncio.run(generate(whitelist, verbose))


if __name__ == "__main__":
    main()