197 lines
5.9 KiB
Python
197 lines
5.9 KiB
Python
|
#!/usr/bin/env python
|
||
|
import asyncio
|
||
|
import os
|
||
|
import platform
|
||
|
import shutil
|
||
|
import sys
|
||
|
from pathlib import Path
|
||
|
from typing import Set
|
||
|
|
||
|
from tests.util import (
|
||
|
get_directories,
|
||
|
inputs_path,
|
||
|
output_path_aristaproto,
|
||
|
output_path_aristaproto_pydantic,
|
||
|
output_path_reference,
|
||
|
protoc,
|
||
|
)
|
||
|
|
||
|
|
||
|
# Select protobuf's pure-Python backend rather than the C++ extension:
# with the C++ backend, imports break because the symbol database cannot
# be properly reset between them.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")
|
||
|
|
||
|
|
||
|
def clear_directory(dir_path: Path):
    """
    Delete everything inside ``dir_path`` while keeping the directory itself.

    Real sub-directories are removed recursively. Symbolic links -- including
    links that point at directories -- are unlinked rather than followed, so
    nothing outside ``dir_path`` is deleted and ``shutil.rmtree`` is never
    handed a symlink (which it refuses with ``OSError``).
    """
    # Materialise the listing first so entries are not removed while the
    # directory is still being scanned.
    for entry in list(dir_path.iterdir()):
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()
|
||
|
|
||
|
|
||
|
async def generate(whitelist: Set[str], verbose: bool):
    """
    Generate reference and plugin output for the selected test cases.

    :param whitelist: Test-case names and/or directories (relative or
        absolute) to restrict generation to. An empty set selects every
        test case found under ``inputs_path``.
    :param verbose: Forwarded to ``generate_test_case_output``.

    If any selected test case fails, the failing names are listed on stderr
    and the process exits with status 1.
    """
    test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}

    # Split the whitelist into known test-case names and directory paths.
    # Paths are resolved so that relative arguments such as ``inputs/bool``
    # compare equal to the resolved test-case directories below.
    path_whitelist = set()
    name_whitelist = set()
    for item in whitelist:
        if item in test_case_names:
            name_whitelist.add(item)
        else:
            path_whitelist.add(str(Path(item).resolve()))

    # Record each scheduled test case's name alongside its task so results
    # can be matched back to the right name even when some are skipped.
    selected_names = []
    generation_tasks = []
    for test_case_name in sorted(test_case_names):
        test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
        if (
            whitelist
            and str(test_case_input_path) not in path_whitelist
            and test_case_name not in name_whitelist
        ):
            continue
        selected_names.append(test_case_name)
        generation_tasks.append(
            generate_test_case_output(test_case_input_path, test_case_name, verbose)
        )

    # asyncio.gather returns results in the order the awaitables were given.
    results = await asyncio.gather(*generation_tasks)
    failed_test_cases = [
        name for name, result in zip(selected_names, results) if result != 0
    ]

    if failed_test_cases:
        sys.stderr.write(
            "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
        )
        for failed_test_case in failed_test_cases:
            sys.stderr.write(f"- {failed_test_case}\n")

        sys.exit(1)
|
||
|
|
||
|
|
||
|
def _write_captured(stream, data) -> None:
    """Write raw captured subprocess output to ``stream``'s binary buffer, after any pending text."""
    # Labels are printed through the text layer, which may still be holding
    # them; flush it first so a label cannot land after the raw bytes.
    sys.stdout.flush()
    stream.flush()
    stream.buffer.write(data)
    stream.buffer.flush()


def _report_protoc_result(
    test_case_name: str,
    description: str,
    label: str,
    code: int,
    out,
    err,
    verbose: bool,
) -> None:
    """
    Print the success/failure banner for one protoc run and, when verbose,
    its captured stdout/stderr.

    :param description: Output kind used in the banner (e.g. ``"reference"``).
    :param label: Prefix for the verbose ``stdout:``/``stderr:`` headings.
    :param code: Return code of the run; 0 is reported as success.
    :param out: Captured stdout (written raw to the binary buffer).
    :param err: Captured stderr (written raw to the binary buffer).
    """
    if code == 0:
        print(f"\033[31;1;4mGenerated {description} output for {test_case_name!r}\033[0m")
    else:
        print(
            f"\033[31;1;4mFailed to generate {description} output for {test_case_name!r}\033[0m"
        )

    if not verbose:
        return

    if out:
        print(f"{label} stdout:")
        _write_captured(sys.stdout, out)

    if err:
        print(f"{label} stderr:")
        _write_captured(sys.stderr, err)


async def generate_test_case_output(
    test_case_input_path: Path, test_case_name: str, verbose: bool
) -> int:
    """
    Run protoc three times for one test case, concurrently: reference
    output, plugin output, and plugin output in pydantic-compatible mode.

    :param test_case_input_path: Resolved input directory of the test case.
    :param test_case_name: Name of the reference-output sub-directory; also
        used in messages.
    :param verbose: If true, echo each run's captured stdout/stderr.
    :return: The largest-magnitude return code of the three runs, so any
        non-zero code (including a negative one) yields a non-zero result.
    """
    test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
    test_case_output_path_aristaproto = output_path_aristaproto
    test_case_output_path_aristaproto_pyd = output_path_aristaproto_pydantic

    os.makedirs(test_case_output_path_reference, exist_ok=True)
    os.makedirs(test_case_output_path_aristaproto, exist_ok=True)
    os.makedirs(test_case_output_path_aristaproto_pyd, exist_ok=True)

    # NOTE(review): the two plugin output dirs are shared by all test cases
    # and are cleared here, synchronously, before the first await. Concurrent
    # test cases presumably rely on every clear running before any protoc
    # run starts; consider clearing them once in the caller instead.
    clear_directory(test_case_output_path_reference)
    clear_directory(test_case_output_path_aristaproto)
    clear_directory(test_case_output_path_aristaproto_pyd)

    (
        (ref_out, ref_err, ref_code),
        (plg_out, plg_err, plg_code),
        (plg_out_pyd, plg_err_pyd, plg_code_pyd),
    ) = await asyncio.gather(
        protoc(test_case_input_path, test_case_output_path_reference, True),
        protoc(test_case_input_path, test_case_output_path_aristaproto, False),
        protoc(
            test_case_input_path, test_case_output_path_aristaproto_pyd, False, True
        ),
    )

    _report_protoc_result(
        test_case_name, "reference", "Reference", ref_code, ref_out, ref_err, verbose
    )
    _report_protoc_result(
        test_case_name, "plugin", "Plugin", plg_code, plg_out, plg_err, verbose
    )
    _report_protoc_result(
        test_case_name,
        "plugin (pydantic compatible)",
        "Plugin",
        plg_code_pyd,
        plg_out_pyd,
        plg_err_pyd,
        verbose,
    )

    # Subprocess return codes are negative on POSIX when the process is
    # killed by a signal (presumably what protoc reports here); a plain max()
    # would let such a failure hide behind a 0 from a successful run.
    return max((ref_code, plg_code, plg_code_pyd), key=abs)
|
||
|
|
||
|
|
||
|
# Usage text printed by main() for -h/--help.
HELP = "\n".join(
    (
        "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
        "Generate python classes for standard tests.",
        "",
        "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
        " python generate.py inputs/bool inputs/double inputs/enum",
        "",
        "NAMES One or more test-case names to generate classes for.",
        # Same test cases as the DIRECTORIES example above (inputs/enum -> enum).
        " python generate.py bool double enum",
    )
)
|
||
|
|
||
|
|
||
|
def main():
    """
    Command-line entry point.

    ``-h``/``--help`` anywhere on the command line prints ``HELP`` and
    returns. ``-v`` anywhere enables verbose output. Every other argument is
    a test-case name or directory to restrict generation to; with none,
    every test case is generated.
    """
    args = sys.argv[1:]
    if {"-h", "--help"}.intersection(args):
        print(HELP)
        return

    # Honour -v in any position; previously it only counted as the first
    # argument and was otherwise silently treated as a whitelist entry.
    verbose = "-v" in args
    whitelist = {arg for arg in args if arg != "-v"}

    if platform.system() == "Windows":
        # Presumably needed because protoc runs as an asyncio subprocess,
        # which the selector event loop does not support on Windows.
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    asyncio.run(generate(whitelist, verbose))
|
||
|
|
||
|
|
||
|
# Run the generator only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|