Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client streaming #83

Merged
merged 13 commits into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Finish implementation and testing of client
Including stream_unary and stream_stream call methods.

Also
- improve organisation of relevant tests
- fix some generated type annotations
- Add AsyncChannel utility because it's useful
  • Loading branch information
nat-n committed Jun 14, 2020
commit 4b6f55dce58d82f8db14ada7e08063c14eab9b94
2 changes: 1 addition & 1 deletion betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def __bytes__(self) -> bytes:
serialize_empty = False
if isinstance(value, Message) and value._serialized_on_wire:
# Empty messages can still be sent on the wire if they were
# set (or received empty).
# set (or recieved empty).
serialize_empty = True

if value == self._get_field_default(field_name) and not (
Expand Down
68 changes: 51 additions & 17 deletions betterproto/grpc/grpclib_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC
import asyncio
import grpclib.const
from typing import (
AsyncGenerator,
Any,
AsyncIterator,
Collection,
Iterator,
Expand All @@ -16,17 +17,18 @@

if TYPE_CHECKING:
from grpclib._protocols import IProtoMessage
from grpclib.client import Channel
from grpclib.client import Channel, Stream
from grpclib.metadata import Deadline


_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]]


class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.
Base class for async gRPC clients.
boukeversteegh marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
Expand Down Expand Up @@ -86,7 +88,7 @@ async def _unary_stream(
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> AsyncGenerator[T, None]:
) -> AsyncIterator[T]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route,
Expand All @@ -102,34 +104,66 @@ async def _unary_stream(
async def _stream_unary(
    self,
    route: str,
    request_iterator: _MessageSource,
    request_type: Type[ST],
    response_type: Type[T],
    *,
    timeout: Optional[float] = None,
    deadline: Optional["Deadline"] = None,
    metadata: Optional[_MetadataLike] = None,
) -> T:
    """
    Make a stream request and return the single response.

    :param route: gRPC route of the method being called.
    :param request_iterator: Sync or async iterator of request messages.
    :param request_type: Class of the request messages.
    :param response_type: Class of the response message.
    :param timeout: Optional per-call timeout in seconds.
    :param deadline: Optional absolute deadline for the call.
    :param metadata: Optional metadata to send with the request.
    :return: The response message received from the server.
    """
    async with self.channel.request(
        route,
        grpclib.const.Cardinality.STREAM_UNARY,
        request_type,
        response_type,
        **self.__resolve_request_kwargs(timeout, deadline, metadata),
    ) as stream:
        # Forward every request message, then half-close the request side
        # before waiting for the unary response.
        await self._send_messages(stream, request_iterator)
        response = await stream.recv_message()
        assert response is not None
        return response

async def _stream_stream(
    self,
    route: str,
    request_iterator: _MessageSource,
    request_type: Type[ST],
    response_type: Type[T],
    *,
    timeout: Optional[float] = None,
    deadline: Optional["Deadline"] = None,
    metadata: Optional[_MetadataLike] = None,
) -> AsyncIterator[T]:
    """
    Make a stream request and return an AsyncIterator to iterate over response
    messages.

    :param route: gRPC route of the method being called.
    :param request_iterator: Sync or async iterator of request messages.
    :param request_type: Class of the request messages.
    :param response_type: Class of the response messages.
    :param timeout: Optional per-call timeout in seconds.
    :param deadline: Optional absolute deadline for the call.
    :param metadata: Optional metadata to send with the request.
    """
    async with self.channel.request(
        route,
        grpclib.const.Cardinality.STREAM_STREAM,
        request_type,
        response_type,
        **self.__resolve_request_kwargs(timeout, deadline, metadata),
    ) as stream:
        await stream.send_request()
        # Send requests concurrently so responses can be consumed while
        # the request iterator is still producing items.
        sending_task = asyncio.ensure_future(
            self._send_messages(stream, request_iterator)
        )
        try:
            async for response in stream:
                yield response
        except:  # noqa: E722 — intentionally catch everything (incl. cancellation)
            # Stop the sender before propagating so it doesn't outlive the stream.
            sending_task.cancel()
            raise

@staticmethod
async def _send_messages(stream, messages: _MessageSource):
    """
    Send every message from a sync or async iterator over the stream, then
    half-close the request side of the stream.

    :param stream: The grpclib stream to send messages over.
    :param messages: Sync or async iterator of messages to send.
    """
    if hasattr(messages, "__aiter__"):
        # Async source: consume lazily so items produced later are still sent.
        async for message in messages:
            await stream.send_message(message)
    else:
        for message in messages:
            await stream.send_message(message)
    # Signal that the client is done sending (half-close).
    await stream.end()
Empty file.
204 changes: 204 additions & 0 deletions betterproto/grpc/util/async_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import asyncio
from typing import (
AsyncIterable,
AsyncIterator,
Iterable,
Optional,
TypeVar,
Union,
)

T = TypeVar("T")


class ChannelClosed(Exception):
    """
    An exception raised on an attempt to send through a closed channel
    """

    pass


class ChannelDone(Exception):
    """
    An exception raised on an attempt to receive from a channel that is both
    closed and empty.
    """

    pass


class AsyncChannel(AsyncIterable[T]):
    """
    A buffered async channel for sending items between coroutines with FIFO semantics.

    This makes decoupled bidirectional streaming gRPC requests easy if used like:

    .. code-block:: python
        client = GeneratedStub(grpclib_chan)
        # The channel can be initialised with items to send immediately
        request_chan = AsyncChannel([RequestObject(...), RequestObject(...)])
        async for response in client.rpc_call(request_chan):
            # The response iterator will remain active until the connection is closed
            ...
            # More items can be sent at any time
            await request_chan.send(RequestObject(...))
            ...
            # The channel must be closed to complete the gRPC connection
            request_chan.close()

    Items can be sent through the channel by either:
    - providing an iterable to the constructor
    - providing an iterable to the send_from method
    - passing them to the send method one at a time

    Items can be received from the channel by either:
    - iterating over the channel with a for loop to get all items
    - calling the recieve method to get one item at a time

    If the channel is empty then receivers will wait until either an item appears
    or the channel is closed.

    Once the channel is closed then subsequent attempts to send through the
    channel will fail with a ChannelClosed exception.

    When the channel is closed and empty then it is done, and further attempts to
    receive from it will fail with a ChannelDone exception.

    If multiple coroutines receive from the channel concurrently, each item sent
    will be received by only one of the receivers.

    :param source:
        An optional iterable of items that should be sent through the channel
        immediately.
    :param buffer_limit:
        Limit the number of items that can be buffered in the channel. A value
        less than 1 implies no limit. If the channel is full then attempts to
        send more items will result in the sender waiting until an item is
        received from the channel.
    :param close:
        If set to True then the channel will automatically close after exhausting
        source or immediately if no source is provided.
    """

    def __init__(
        self,
        source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
        *,
        buffer_limit: int = 0,
        close: bool = False,
    ):
        # String annotation: asyncio.Queue only supports subscripting at runtime
        # on Python 3.9+.
        self._queue: "asyncio.Queue[Union[T, object]]" = asyncio.Queue(buffer_limit)
        self._closed = False
        # Background task feeding the queue from `source`, if one was given.
        self._sending_task = (
            asyncio.ensure_future(self.send_from(source, close)) if source else None
        )
        self._waiting_recievers: int = 0
        # Track whether flush has been invoked so it can only happen once
        self._flushed = False

    def __aiter__(self) -> AsyncIterator[T]:
        return self

    async def __anext__(self) -> T:
        # BUG FIX: `done` is a method — the original `if self.done:` tested the
        # bound method object (always truthy), stopping iteration immediately.
        if self.done():
            raise StopAsyncIteration
        self._waiting_recievers += 1
        try:
            result = await self._queue.get()
            # Only acknowledge the item once get() has actually succeeded;
            # calling task_done() after a failed get() raises ValueError.
            self._queue.task_done()
            if result is self.__flush:
                raise StopAsyncIteration
            # BUG FIX: the original never returned the item.
            return result
        finally:
            self._waiting_recievers -= 1

    def closed(self) -> bool:
        """
        Returns True if this channel is closed and no longer accepting new items
        """
        return self._closed

    def done(self) -> bool:
        """
        Check if this channel is done.

        :return: True if this channel is closed and has been drained of items in
        which case any further attempts to receive an item from this channel will
        raise a ChannelDone exception.
        """
        # After close the channel is not yet done until there is at least one
        # waiting receiver per enqueued item.
        return self._closed and self._queue.qsize() <= self._waiting_recievers

    async def send_from(
        self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
    ):
        """
        Iterates the given [Async]Iterable and sends all the resulting items.
        If close is set to True then subsequent send calls will be rejected with a
        ChannelClosed exception.
        :param source: an iterable of items to send
        :param close:
            if True then the channel will be closed after the source has been exhausted

        """
        if self._closed:
            raise ChannelClosed("Cannot send through a closed channel")
        if isinstance(source, AsyncIterable):
            async for item in source:
                await self._queue.put(item)
        else:
            for item in source:
                await self._queue.put(item)
        if close:
            # BUG FIX: close() is synchronous — the original `await self.close()`
            # raised "TypeError: object NoneType can't be used in 'await'".
            self.close()

    async def send(self, item: T):
        """
        Send a single item over this channel.
        :param item: The item to send
        """
        if self._closed:
            raise ChannelClosed("Cannot send through a closed channel")
        await self._queue.put(item)

    async def recieve(self) -> Optional[T]:
        """
        Returns the next item from this channel when it becomes available,
        or None if the channel is closed before another item is sent.

        (Name kept as ``recieve`` for backwards compatibility with existing callers.)
        :return: An item from the channel
        """
        # BUG FIX: call done() — the bound method is always truthy.
        if self.done():
            raise ChannelDone("Cannot recieve from a closed channel")
        self._waiting_recievers += 1
        try:
            result = await self._queue.get()
            # Acknowledge only after a successful get() (see __anext__).
            self._queue.task_done()
            if result is self.__flush:
                return None
            return result
        finally:
            self._waiting_recievers -= 1

    def close(self):
        """
        Close this channel to new items
        """
        if self._sending_task is not None:
            # Stop the background feeder. Cancelling an already-finished task is
            # a harmless no-op; when send_from itself calls close() there are no
            # further suspension points, so the task still completes normally.
            self._sending_task.cancel()
        self._closed = True
        # Wake any receivers blocked on an empty queue so they don't deadlock.
        asyncio.ensure_future(self._flush_queue())

    async def _flush_queue(self):
        """
        To be called after the channel is closed. Pushes a number of self.__flush
        objects to the queue to ensure no waiting consumers get deadlocked.
        """
        if not self._flushed:
            self._flushed = True
            deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
            for _ in range(deadlocked_recievers):
                await self._queue.put(self.__flush)

    # A special signal object for flushing the queue when the channel is closed
    __flush = object()
9 changes: 5 additions & 4 deletions betterproto/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,12 @@ def generate_code(request, response):
}
)

if method.server_streaming:
output["typing_imports"].add("AsyncGenerator")

if method.client_streaming:
output["typing_imports"].add("Iterator")
output["typing_imports"].add("AsyncIterable")
output["typing_imports"].add("Iterable")
output["typing_imports"].add("Union")
if method.server_streaming:
output["typing_imports"].add("AsyncIterator")

output["services"].append(data)

Expand Down
Loading