Adding upstream version 1.2+20240521.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
6b2864e4b9
commit
8512f66c5a
229 changed files with 19561 additions and 0 deletions
0
tests/grpc/__init__.py
Normal file
0
tests/grpc/__init__.py
Normal file
298
tests/grpc/test_grpclib_client.py
Normal file
298
tests/grpc/test_grpclib_client.py
Normal file
|
@ -0,0 +1,298 @@
|
|||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import grpclib
|
||||
import grpclib.client
|
||||
import grpclib.metadata
|
||||
import grpclib.server
|
||||
import pytest
|
||||
from grpclib.testing import ChannelFor
|
||||
|
||||
from aristaproto.grpc.util.async_channel import AsyncChannel
|
||||
from tests.output_aristaproto.service import (
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
GetThingRequest,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
|
||||
from .thing_service import ThingService
|
||||
|
||||
|
||||
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
|
||||
response = await client.do_thing(DoThingRequest(name=name), **kwargs)
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
def _assert_request_meta_received(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be received serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be received serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler_trailer_only_unauthenticated():
|
||||
async def handler(stream: grpclib.server.Stream):
|
||||
await stream.recv_message()
|
||||
await stream.send_initial_metadata()
|
||||
await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailer_only_error_unary_unary(
|
||||
mocker, handler_trailer_only_unauthenticated
|
||||
):
|
||||
service = ThingService()
|
||||
mocker.patch.object(
|
||||
service,
|
||||
"do_thing",
|
||||
side_effect=handler_trailer_only_unauthenticated,
|
||||
autospec=True,
|
||||
)
|
||||
async with ChannelFor([service]) as channel:
|
||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
|
||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailer_only_error_stream_unary(
|
||||
mocker, handler_trailer_only_unauthenticated
|
||||
):
|
||||
service = ThingService()
|
||||
mocker.patch.object(
|
||||
service,
|
||||
"do_many_things",
|
||||
side_effect=handler_trailer_only_unauthenticated,
|
||||
autospec=True,
|
||||
)
|
||||
async with ChannelFor([service]) as channel:
|
||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||
await ThingServiceClient(channel).do_many_things(
|
||||
do_thing_request_iterator=[DoThingRequest(name="something")]
|
||||
)
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
|
||||
)
|
||||
async def test_service_call_mutable_defaults(mocker):
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
spy = mocker.spy(client, "_unary_unary")
|
||||
await _test_client(client)
|
||||
comments = spy.call_args_list[-1].args[1].comments
|
||||
await _test_client(client)
|
||||
assert spy.call_args_list[-1].args[1].comments is not comments
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
THING_TO_DO = "get milk"
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("overrides_gen",),
|
||||
[
|
||||
(lambda: dict(timeout=10),),
|
||||
(lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
|
||||
(lambda: dict(metadata={"authorization": str(uuid.uuid4())}),),
|
||||
(lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
|
||||
],
|
||||
)
|
||||
async def test_service_call_high_level_with_overrides(mocker, overrides_gen):
|
||||
overrides = overrides_gen()
|
||||
request_spy = mocker.spy(grpclib.client.Channel, "request")
|
||||
name = str(uuid.uuid4())
|
||||
defaults = dict(
|
||||
timeout=99,
|
||||
deadline=grpclib.metadata.Deadline.from_timeout(99),
|
||||
metadata={"authorization": name},
|
||||
)
|
||||
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_received(
|
||||
deadline=grpclib.metadata.Deadline.from_timeout(
|
||||
overrides.get("timeout", 99)
|
||||
),
|
||||
metadata=overrides.get("metadata", defaults.get("metadata")),
|
||||
)
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, **defaults)
|
||||
await _test_client(client, name=name, **overrides)
|
||||
assert request_spy.call_count == 1
|
||||
|
||||
# for python <3.8 request_spy.call_args.kwargs do not work
|
||||
_, request_spy_call_kwargs = request_spy.call_args_list[0]
|
||||
|
||||
# ensure all overrides were successful
|
||||
for key, value in overrides.items():
|
||||
assert key in request_spy_call_kwargs
|
||||
assert request_spy_call_kwargs[key] == value
|
||||
|
||||
# ensure default values were retained
|
||||
for key in set(defaults.keys()) - set(overrides.keys()):
|
||||
assert key in request_spy_call_kwargs
|
||||
assert request_spy_call_kwargs[key] == defaults[key]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_unary_stream_request():
|
||||
thing_name = "my milkshakes"
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
expected_versions = [5, 4, 3, 2, 1]
|
||||
async for response in client.get_thing_versions(
|
||||
GetThingRequest(name=thing_name)
|
||||
):
|
||||
assert response.name == thing_name
|
||||
assert response.version == expected_versions.pop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_stream_stream_request():
|
||||
some_things = ["cake", "cricket", "coral reef"]
|
||||
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
|
||||
expected_things = (*some_things, *more_things)
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||
# immediately and we'll use it to send more_things later, after recieving some
|
||||
# results
|
||||
request_chan = AsyncChannel()
|
||||
send_initial_requests = asyncio.ensure_future(
|
||||
request_chan.send_from(GetThingRequest(name) for name in some_things)
|
||||
)
|
||||
response_index = 0
|
||||
async for response in client.get_different_things(request_chan):
|
||||
assert response.name == expected_things[response_index]
|
||||
assert response.version == response_index + 1
|
||||
response_index += 1
|
||||
if more_things:
|
||||
# Send some more requests as we receive responses to be sure coordination of
|
||||
# send/receive events doesn't matter
|
||||
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||
elif not send_initial_requests.done():
|
||||
# Make sure the sending task it completed
|
||||
await send_initial_requests
|
||||
else:
|
||||
# No more things to send make sure channel is closed
|
||||
request_chan.close()
|
||||
assert response_index == len(
|
||||
expected_things
|
||||
), "Didn't receive all expected responses"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_unary_with_empty_iterable():
|
||||
things = [] # empty
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
requests = [DoThingRequest(name) for name in things]
|
||||
response = await client.do_many_things(requests)
|
||||
assert len(response.names) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_stream_with_empty_iterable():
|
||||
things = [] # empty
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
requests = [GetThingRequest(name) for name in things]
|
||||
responses = [
|
||||
response async for response in client.get_different_things(requests)
|
||||
]
|
||||
assert len(responses) == 0
|
99
tests/grpc/test_stream_stream.py
Normal file
99
tests/grpc/test_stream_stream.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator
|
||||
|
||||
import pytest
|
||||
|
||||
import aristaproto
|
||||
from aristaproto.grpc.util.async_channel import AsyncChannel
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message(aristaproto.Message):
|
||||
body: str = aristaproto.string_field(1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_responses():
|
||||
return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
|
||||
|
||||
|
||||
class ClientStub:
|
||||
async def connect(self, requests: AsyncIterator):
|
||||
await asyncio.sleep(0.1)
|
||||
async for request in requests:
|
||||
await asyncio.sleep(0.1)
|
||||
yield request
|
||||
await asyncio.sleep(0.1)
|
||||
yield Message("Done")
|
||||
|
||||
|
||||
async def to_list(generator: AsyncIterator):
|
||||
return [value async for value in generator]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# channel = Channel(host='127.0.0.1', port=50051)
|
||||
# return ClientStub(channel)
|
||||
return ClientStub()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_before_connect_and_close_automatically(
|
||||
client, expected_responses
|
||||
):
|
||||
requests = AsyncChannel()
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||
)
|
||||
responses = client.connect(requests)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_after_connect_and_close_automatically(
|
||||
client, expected_responses
|
||||
):
|
||||
requests = AsyncChannel()
|
||||
responses = client.connect(requests)
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||
)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_close_manually_immediately(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
responses = client.connect(requests)
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
|
||||
)
|
||||
requests.close()
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_individually_and_close_before_connect(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
await requests.send(Message(body="Hello world 1"))
|
||||
await requests.send(Message(body="Hello world 2"))
|
||||
requests.close()
|
||||
responses = client.connect(requests)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_individually_and_close_after_connect(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
await requests.send(Message(body="Hello world 1"))
|
||||
await requests.send(Message(body="Hello world 2"))
|
||||
responses = client.connect(requests)
|
||||
requests.close()
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
85
tests/grpc/thing_service.py
Normal file
85
tests/grpc/thing_service.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
from typing import Dict
|
||||
|
||||
import grpclib
|
||||
import grpclib.server
|
||||
|
||||
from tests.output_aristaproto.service import (
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
)
|
||||
|
||||
|
||||
class ThingService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def do_thing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse([request.name]))
|
||||
|
||||
async def do_many_things(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
thing_names = [request.name async for request in stream]
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse(thing_names))
|
||||
|
||||
async def get_thing_versions(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for version_num in range(1, 6):
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=version_num)
|
||||
)
|
||||
|
||||
async def get_different_things(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
# Respond to each input item immediately
|
||||
response_num = 0
|
||||
async for request in stream:
|
||||
response_num += 1
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=response_num)
|
||||
)
|
||||
|
||||
def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.do_thing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||
self.do_many_things,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||
self.get_thing_versions,
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||
self.get_different_things,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue