1
0
Fork 0
python-aristaproto/tests/grpc/test_grpclib_client.py
Daniel Baumann 8512f66c5a
Adding upstream version 1.2+20240521.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-05 14:24:15 +01:00

298 lines
10 KiB
Python

import asyncio
import sys
import uuid
import grpclib
import grpclib.client
import grpclib.metadata
import grpclib.server
import pytest
from grpclib.testing import ChannelFor
from aristaproto.grpc.util.async_channel import AsyncChannel
from tests.output_aristaproto.service import (
DoThingRequest,
DoThingResponse,
GetThingRequest,
TestStub as ThingServiceClient,
)
from .thing_service import ThingService
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(DoThingRequest(name=name), **kwargs)
assert response.names == [name]
def _assert_request_meta_received(deadline, metadata):
def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1
), "The provided deadline should be received serverside"
assert (
stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be received serverside"
return server_side_test
@pytest.fixture
def handler_trailer_only_unauthenticated():
async def handler(stream: grpclib.server.Stream):
await stream.recv_message()
await stream.send_initial_metadata()
await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
return handler
@pytest.mark.asyncio
async def test_simple_service_call():
async with ChannelFor([ThingService()]) as channel:
await _test_client(ThingServiceClient(channel))
@pytest.mark.asyncio
async def test_trailer_only_error_unary_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_thing",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
async def test_trailer_only_error_stream_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_many_things",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things(
do_thing_request_iterator=[DoThingRequest(name="something")]
)
await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
spy = mocker.spy(client, "_unary_unary")
await _test_client(client)
comments = spy.call_args_list[-1].args[1].comments
await _test_client(client)
assert spy.call_args_list[-1].args[1].comments is not comments
@pytest.mark.asyncio
async def test_service_call_with_upfront_request_params():
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
await _test_client(
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
)
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
await _test_client(
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
)
@pytest.mark.asyncio
async def test_service_call_lower_level_with_overrides():
THING_TO_DO = "get milk"
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary(
"/service.Test/DoThing",
DoThingRequest(THING_TO_DO),
DoThingResponse,
deadline=kwarg_deadline,
metadata=kwarg_metadata,
)
assert response.names == [THING_TO_DO]
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
kwarg_timeout = 9000
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
kwarg_metadata = {"authorization": "09876"}
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
)
]
) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary(
"/service.Test/DoThing",
DoThingRequest(THING_TO_DO),
DoThingResponse,
timeout=kwarg_timeout,
metadata=kwarg_metadata,
)
assert response.names == [THING_TO_DO]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("overrides_gen",),
[
(lambda: dict(timeout=10),),
(lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
(lambda: dict(metadata={"authorization": str(uuid.uuid4())}),),
(lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
],
)
async def test_service_call_high_level_with_overrides(mocker, overrides_gen):
overrides = overrides_gen()
request_spy = mocker.spy(grpclib.client.Channel, "request")
name = str(uuid.uuid4())
defaults = dict(
timeout=99,
deadline=grpclib.metadata.Deadline.from_timeout(99),
metadata={"authorization": name},
)
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_received(
deadline=grpclib.metadata.Deadline.from_timeout(
overrides.get("timeout", 99)
),
metadata=overrides.get("metadata", defaults.get("metadata")),
)
)
]
) as channel:
client = ThingServiceClient(channel, **defaults)
await _test_client(client, name=name, **overrides)
assert request_spy.call_count == 1
# for python <3.8 request_spy.call_args.kwargs do not work
_, request_spy_call_kwargs = request_spy.call_args_list[0]
# ensure all overrides were successful
for key, value in overrides.items():
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == value
# ensure default values were retained
for key in set(defaults.keys()) - set(overrides.keys()):
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == defaults[key]
@pytest.mark.asyncio
async def test_async_gen_for_unary_stream_request():
thing_name = "my milkshakes"
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
expected_versions = [5, 4, 3, 2, 1]
async for response in client.get_thing_versions(
GetThingRequest(name=thing_name)
):
assert response.name == thing_name
assert response.version == expected_versions.pop()
@pytest.mark.asyncio
async def test_async_gen_for_stream_stream_request():
some_things = ["cake", "cricket", "coral reef"]
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
expected_things = (*some_things, *more_things)
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
# immediately and we'll use it to send more_things later, after recieving some
# results
request_chan = AsyncChannel()
send_initial_requests = asyncio.ensure_future(
request_chan.send_from(GetThingRequest(name) for name in some_things)
)
response_index = 0
async for response in client.get_different_things(request_chan):
assert response.name == expected_things[response_index]
assert response.version == response_index + 1
response_index += 1
if more_things:
# Send some more requests as we receive responses to be sure coordination of
# send/receive events doesn't matter
await request_chan.send(GetThingRequest(more_things.pop(0)))
elif not send_initial_requests.done():
# Make sure the sending task it completed
await send_initial_requests
else:
# No more things to send make sure channel is closed
request_chan.close()
assert response_index == len(
expected_things
), "Didn't receive all expected responses"
@pytest.mark.asyncio
async def test_stream_unary_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [DoThingRequest(name) for name in things]
response = await client.do_many_things(requests)
assert len(response.names) == 0
@pytest.mark.asyncio
async def test_stream_stream_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [GetThingRequest(name) for name in things]
responses = [
response async for response in client.get_different_things(requests)
]
assert len(responses) == 0