Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client streaming #83

Merged
merged 13 commits into from
Jun 24, 2020
Merged
Prev Previous commit
Next Next commit
Fix bugs and remove footgun feature in AsyncChannel
  • Loading branch information
nat-n committed Jun 15, 2020
commit 50bb67bf5dca04ded331adbcdcedab3aed7d7de1
24 changes: 9 additions & 15 deletions betterproto/grpc/util/async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]):
"""

def __init__(
self,
source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
*,
buffer_limit: int = 0,
close: bool = False,
self, *, buffer_limit: int = 0, close: bool = False,
):
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
self._closed = False
self._sending_task = (
asyncio.ensure_future(self.send_from(source, close)) if source else None
)
self._waiting_recievers: int = 0
# Track whether flush has been invoked so it can only happen once
self._flushed = False
Expand All @@ -100,13 +93,14 @@ def __aiter__(self) -> AsyncIterator[T]:
return self

async def __anext__(self) -> T:
if self.done:
if self.done():
raise StopAsyncIteration
self._waiting_recievers += 1
try:
result = await self._queue.get()
if result is self.__flush:
raise StopAsyncIteration
return result
finally:
self._waiting_recievers -= 1
self._queue.task_done()
Expand All @@ -131,7 +125,7 @@ def done(self) -> bool:

async def send_from(
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
):
) -> "AsyncChannel[T]":
"""
Iterates the given [Async]Iterable and sends all the resulting items.
If close is set to True then subsequent send calls will be rejected with a
Expand All @@ -151,24 +145,26 @@ async def send_from(
await self._queue.put(item)
if close:
# Complete the closing process
await self.close()
self.close()
return self

async def send(self, item: T):
async def send(self, item: T) -> "AsyncChannel[T]":
"""
Send a single item over this channel.
:param item: The item to send
"""
if self._closed:
raise ChannelClosed("Cannot send through a closed channel")
await self._queue.put(item)
return self

async def recieve(self) -> Optional[T]:
nat-n marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the next item from this channel when it becomes available,
or None if the channel is closed before another item is sent.
:return: An item from the channel
"""
if self.done:
if self.done():
raise ChannelDone("Cannot recieve from a closed channel")
self._waiting_recievers += 1
try:
Expand All @@ -184,8 +180,6 @@ def close(self):
"""
Close this channel to new items
"""
if self._sending_task is not None:
self._sending_task.cancel()
self._closed = True
asyncio.ensure_future(self._flush_queue())

Expand Down
22 changes: 13 additions & 9 deletions betterproto/tests/grpc/test_grpclib_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from betterproto.tests.output_betterproto.service.service import (
DoThingResponse,
DoThingRequest,
Expand Down Expand Up @@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request():
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
# immediately and we'll use it to send more_things later, after recieving some
# results
request_chan = AsyncChannel(GetThingRequest(name) for name in some_things)
request_chan = AsyncChannel()
send_initial_requests = asyncio.ensure_future(
request_chan.send_from(GetThingRequest(name) for name in some_things)
)
response_index = 0
async for response in client.get_different_things(request_chan):
boukeversteegh marked this conversation as resolved.
Show resolved Hide resolved
assert response.name == expected_things[response_index]
Expand All @@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request():
if more_things:
# Send some more requests as we recieve reponses to be sure coordination of
# send/recieve events doesn't matter
nat-n marked this conversation as resolved.
Show resolved Hide resolved
another_response = await request_chan.send(
GetThingRequest(more_things.pop(0))
)
if another_response is not None:
assert another_response.name == expected_things[response_index]
assert another_response.version == response_index
response_index += 1
await request_chan.send(GetThingRequest(more_things.pop(0)))
elif not send_initial_requests.done():
# Make sure the sending task it completed
await send_initial_requests
else:
# No more things to send make sure channel is closed
await request_chan.close()
request_chan.close()
assert response_index == len(
expected_things
), "Didn't recieve all exptected responses"