1
0
Fork 0
python-aristaproto/tests/grpc/test_grpclib_client.py

299 lines
10 KiB
Python
Raw Normal View History

import asyncio
import sys
import uuid
import grpclib
import grpclib.client
import grpclib.metadata
import grpclib.server
import pytest
from grpclib.testing import ChannelFor
from aristaproto.grpc.util.async_channel import AsyncChannel
from tests.output_aristaproto.service import (
DoThingRequest,
DoThingResponse,
GetThingRequest,
TestStub as ThingServiceClient,
)
from .thing_service import ThingService
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(DoThingRequest(name=name), **kwargs)
assert response.names == [name]
def _assert_request_meta_received(deadline, metadata):
def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1
), "The provided deadline should be received serverside"
assert (
stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be received serverside"
return server_side_test
@pytest.fixture
def handler_trailer_only_unauthenticated():
async def handler(stream: grpclib.server.Stream):
await stream.recv_message()
await stream.send_initial_metadata()
await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
return handler
@pytest.mark.asyncio
async def test_simple_service_call():
async with ChannelFor([ThingService()]) as channel:
await _test_client(ThingServiceClient(channel))
@pytest.mark.asyncio
async def test_trailer_only_error_unary_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_thing",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
async def test_trailer_only_error_stream_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_many_things",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things(
do_thing_request_iterator=[DoThingRequest(name="something")]
)
await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
spy = mocker.spy(client, "_unary_unary")
await _test_client(client)
comments = spy.call_args_list[-1].args[1].comments
await _test_client(client)
assert spy.call_args_list[-1].args[1].comments is not comments
@pytest.mark.asyncio
async def test_service_call_with_upfront_request_params():
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
await _test_client(
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
)
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
await _test_client(
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
)
@pytest.mark.asyncio
async def test_service_call_lower_level_with_overrides():
THING_TO_DO = "get milk"
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary(
"/service.Test/DoThing",
DoThingRequest(THING_TO_DO),
DoThingResponse,
deadline=kwarg_deadline,
metadata=kwarg_metadata,
)
assert response.names == [THING_TO_DO]
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
kwarg_timeout = 9000
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
kwarg_metadata = {"authorization": "09876"}
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
)
]
) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary(
"/service.Test/DoThing",
DoThingRequest(THING_TO_DO),
DoThingResponse,
timeout=kwarg_timeout,
metadata=kwarg_metadata,
)
assert response.names == [THING_TO_DO]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("overrides_gen",),
[
(lambda: dict(timeout=10),),
(lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
(lambda: dict(metadata={"authorization": str(uuid.uuid4())}),),
(lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
],
)
async def test_service_call_high_level_with_overrides(mocker, overrides_gen):
overrides = overrides_gen()
request_spy = mocker.spy(grpclib.client.Channel, "request")
name = str(uuid.uuid4())
defaults = dict(
timeout=99,
deadline=grpclib.metadata.Deadline.from_timeout(99),
metadata={"authorization": name},
)
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_received(
deadline=grpclib.metadata.Deadline.from_timeout(
overrides.get("timeout", 99)
),
metadata=overrides.get("metadata", defaults.get("metadata")),
)
)
]
) as channel:
client = ThingServiceClient(channel, **defaults)
await _test_client(client, name=name, **overrides)
assert request_spy.call_count == 1
# for python <3.8 request_spy.call_args.kwargs do not work
_, request_spy_call_kwargs = request_spy.call_args_list[0]
# ensure all overrides were successful
for key, value in overrides.items():
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == value
# ensure default values were retained
for key in set(defaults.keys()) - set(overrides.keys()):
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == defaults[key]
@pytest.mark.asyncio
async def test_async_gen_for_unary_stream_request():
thing_name = "my milkshakes"
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
expected_versions = [5, 4, 3, 2, 1]
async for response in client.get_thing_versions(
GetThingRequest(name=thing_name)
):
assert response.name == thing_name
assert response.version == expected_versions.pop()
@pytest.mark.asyncio
async def test_async_gen_for_stream_stream_request():
some_things = ["cake", "cricket", "coral reef"]
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
expected_things = (*some_things, *more_things)
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
# immediately and we'll use it to send more_things later, after recieving some
# results
request_chan = AsyncChannel()
send_initial_requests = asyncio.ensure_future(
request_chan.send_from(GetThingRequest(name) for name in some_things)
)
response_index = 0
async for response in client.get_different_things(request_chan):
assert response.name == expected_things[response_index]
assert response.version == response_index + 1
response_index += 1
if more_things:
# Send some more requests as we receive responses to be sure coordination of
# send/receive events doesn't matter
await request_chan.send(GetThingRequest(more_things.pop(0)))
elif not send_initial_requests.done():
# Make sure the sending task it completed
await send_initial_requests
else:
# No more things to send make sure channel is closed
request_chan.close()
assert response_index == len(
expected_things
), "Didn't receive all expected responses"
@pytest.mark.asyncio
async def test_stream_unary_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [DoThingRequest(name) for name in things]
response = await client.do_many_things(requests)
assert len(response.names) == 0
@pytest.mark.asyncio
async def test_stream_stream_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [GetThingRequest(name) for name in things]
responses = [
response async for response in client.get_different_things(requests)
]
assert len(responses) == 0