100 lines
2.7 KiB
Python
100 lines
2.7 KiB
Python
|
import asyncio
|
||
|
from dataclasses import dataclass
|
||
|
from typing import AsyncIterator
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
import aristaproto
|
||
|
from aristaproto.grpc.util.async_channel import AsyncChannel
|
||
|
|
||
|
|
||
|
@dataclass
class Message(aristaproto.Message):
    """Minimal protobuf-style message used as the request/response payload."""

    # Single string payload, declared as field number 1.
    body: str = aristaproto.string_field(1)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def expected_responses():
    """Return the messages every test expects back from the stub, in order."""
    bodies = ("Hello world 1", "Hello world 2", "Done")
    return [Message(text) for text in bodies]
|
||
|
|
||
|
|
||
|
class ClientStub:
    """In-process stand-in for a streaming client that echoes its requests."""

    async def connect(self, requests: AsyncIterator):
        """Yield each request back, then a final ``Message("Done")``.

        A short sleep precedes the first read, every echoed request, and the
        closing message.
        """
        pause = 0.1

        await asyncio.sleep(pause)
        async for incoming in requests:
            await asyncio.sleep(pause)
            yield incoming

        # The request stream is exhausted: emit the terminating message.
        await asyncio.sleep(pause)
        yield Message("Done")
|
||
|
|
||
|
|
||
|
async def to_list(generator: AsyncIterator):
    """Drain an async iterator and return its items as a list, in order."""
    collected = []
    async for item in generator:
        collected.append(item)
    return collected
|
||
|
|
||
|
|
||
|
@pytest.fixture
def client():
    """Provide the ``ClientStub`` used by the tests.

    The stub is constructed with no arguments; no channel is passed in.
    """
    stub = ClientStub()
    return stub
|
||
|
|
||
|
|
||
|
@pytest.mark.asyncio
async def test_send_from_before_connect_and_close_automatically(
    client, expected_responses
):
    """Queue both messages via ``send_from(..., close=True)``, then connect."""
    outgoing = [Message(body="Hello world 1"), Message(body="Hello world 2")]
    channel = AsyncChannel()

    # Everything is sent (with close=True) before the client sees the channel.
    await channel.send_from(outgoing, close=True)
    stream = client.connect(channel)

    received = await to_list(stream)
    assert received == expected_responses
|
||
|
|
||
|
|
||
|
@pytest.mark.asyncio
async def test_send_from_after_connect_and_close_automatically(
    client, expected_responses
):
    """Connect first, then queue both messages via ``send_from(..., close=True)``."""
    outgoing = [Message(body="Hello world 1"), Message(body="Hello world 2")]
    channel = AsyncChannel()

    # The response stream is created before anything is sent.
    stream = client.connect(channel)
    await channel.send_from(outgoing, close=True)

    received = await to_list(stream)
    assert received == expected_responses
|
||
|
|
||
|
|
||
|
@pytest.mark.asyncio
async def test_send_from_close_manually_immediately(client, expected_responses):
    """Send with ``close=False`` and call ``close()`` explicitly right after."""
    outgoing = [Message(body="Hello world 1"), Message(body="Hello world 2")]
    channel = AsyncChannel()
    stream = client.connect(channel)

    # close=False is passed here, so the test closes the channel itself.
    await channel.send_from(outgoing, close=False)
    channel.close()

    received = await to_list(stream)
    assert received == expected_responses
|
||
|
|
||
|
|
||
|
@pytest.mark.asyncio
async def test_send_individually_and_close_before_connect(client, expected_responses):
    """Send each message with ``send()``, close the channel, then connect."""
    channel = AsyncChannel()
    for text in ("Hello world 1", "Hello world 2"):
        await channel.send(Message(body=text))

    # The channel is closed before the client is connected to it.
    channel.close()
    stream = client.connect(channel)

    received = await to_list(stream)
    assert received == expected_responses
|
||
|
|
||
|
|
||
|
@pytest.mark.asyncio
async def test_send_individually_and_close_after_connect(client, expected_responses):
    """Send each message with ``send()``, connect, and only then close."""
    channel = AsyncChannel()
    for text in ("Hello world 1", "Hello world 2"):
        await channel.send(Message(body=text))

    # The client is connected first; the channel is closed afterwards.
    stream = client.connect(channel)
    channel.close()

    received = await to_list(stream)
    assert received == expected_responses
|