Skip to content

Commit 9532844

Browse files
Merge pull request #83 from nat-n/client-streaming
Client streaming
2 parents 5fb4b4b + 0c5d1ff commit 9532844

19 files changed

+888
-318
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ dist
1212
**/*.egg-info
1313
output
1414
.idea
15+
.DS_Store
1516
.tox

betterproto/__init__.py

Lines changed: 9 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,30 @@
55
import struct
66
import sys
77
from abc import ABC
8-
from base64 import b64encode, b64decode
8+
from base64 import b64decode, b64encode
99
from datetime import datetime, timedelta, timezone
10+
import stringcase
1011
from typing import (
1112
Any,
1213
AsyncGenerator,
1314
Callable,
1415
Collection,
1516
Dict,
1617
Generator,
18+
Iterator,
1719
List,
1820
Mapping,
1921
Optional,
2022
Set,
23+
SupportsBytes,
2124
Tuple,
2225
Type,
23-
TypeVar,
2426
Union,
2527
get_type_hints,
26-
TYPE_CHECKING,
2728
)
28-
29-
30-
import grpclib.const
31-
import stringcase
32-
29+
from ._types import ST, T
3330
from .casing import safe_snake_case
34-
35-
if TYPE_CHECKING:
36-
from grpclib._protocols import IProtoMessage
37-
from grpclib.client import Channel
38-
from grpclib.metadata import Deadline
31+
from .grpc.grpclib_client import ServiceStub
3932

4033
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
4134
# Apply backport of datetime.fromisoformat from 3.7
@@ -429,10 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
429422
)
430423

431424

432-
# Bound type variable to allow methods to return `self` of subclasses
433-
T = TypeVar("T", bound="Message")
434-
435-
436425
class ProtoClassMetadata:
437426
oneof_group_by_field: Dict[str, str]
438427
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
@@ -451,7 +440,7 @@ class ProtoClassMetadata:
451440

452441
def __init__(self, cls: Type["Message"]):
453442
by_field = {}
454-
by_group = {}
443+
by_group: Dict[str, Set] = {}
455444
by_field_name = {}
456445
by_field_number = {}
457446

@@ -604,7 +593,7 @@ def __bytes__(self) -> bytes:
604593
serialize_empty = False
605594
if isinstance(value, Message) and value._serialized_on_wire:
606595
# Empty messages can still be sent on the wire if they were
607-
# set (or received empty).
596+
# set (or recieved empty).
608597
serialize_empty = True
609598

610599
if value == self._get_field_default(field_name) and not (
@@ -791,7 +780,7 @@ def FromString(cls: Type[T], data: bytes) -> T:
791780

792781
def to_dict(
793782
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
794-
) -> dict:
783+
) -> Dict[str, Any]:
795784
"""
796785
Returns a dict representation of this message instance which can be
797786
used to serialize to e.g. JSON. Defaults to camel casing for
@@ -1024,83 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type:
10241013
TYPE_STRING: StringValue,
10251014
TYPE_BYTES: BytesValue,
10261015
}[proto_type]
1027-
1028-
1029-
_Value = Union[str, bytes]
1030-
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
1031-
1032-
1033-
class ServiceStub(ABC):
1034-
"""
1035-
Base class for async gRPC service stubs.
1036-
"""
1037-
1038-
def __init__(
1039-
self,
1040-
channel: "Channel",
1041-
*,
1042-
timeout: Optional[float] = None,
1043-
deadline: Optional["Deadline"] = None,
1044-
metadata: Optional[_MetadataLike] = None,
1045-
) -> None:
1046-
self.channel = channel
1047-
self.timeout = timeout
1048-
self.deadline = deadline
1049-
self.metadata = metadata
1050-
1051-
def __resolve_request_kwargs(
1052-
self,
1053-
timeout: Optional[float],
1054-
deadline: Optional["Deadline"],
1055-
metadata: Optional[_MetadataLike],
1056-
):
1057-
return {
1058-
"timeout": self.timeout if timeout is None else timeout,
1059-
"deadline": self.deadline if deadline is None else deadline,
1060-
"metadata": self.metadata if metadata is None else metadata,
1061-
}
1062-
1063-
async def _unary_unary(
1064-
self,
1065-
route: str,
1066-
request: "IProtoMessage",
1067-
response_type: Type[T],
1068-
*,
1069-
timeout: Optional[float] = None,
1070-
deadline: Optional["Deadline"] = None,
1071-
metadata: Optional[_MetadataLike] = None,
1072-
) -> T:
1073-
"""Make a unary request and return the response."""
1074-
async with self.channel.request(
1075-
route,
1076-
grpclib.const.Cardinality.UNARY_UNARY,
1077-
type(request),
1078-
response_type,
1079-
**self.__resolve_request_kwargs(timeout, deadline, metadata),
1080-
) as stream:
1081-
await stream.send_message(request, end=True)
1082-
response = await stream.recv_message()
1083-
assert response is not None
1084-
return response
1085-
1086-
async def _unary_stream(
1087-
self,
1088-
route: str,
1089-
request: "IProtoMessage",
1090-
response_type: Type[T],
1091-
*,
1092-
timeout: Optional[float] = None,
1093-
deadline: Optional["Deadline"] = None,
1094-
metadata: Optional[_MetadataLike] = None,
1095-
) -> AsyncGenerator[T, None]:
1096-
"""Make a unary request and return the stream response iterator."""
1097-
async with self.channel.request(
1098-
route,
1099-
grpclib.const.Cardinality.UNARY_STREAM,
1100-
type(request),
1101-
response_type,
1102-
**self.__resolve_request_kwargs(timeout, deadline, metadata),
1103-
) as stream:
1104-
await stream.send_message(request, end=True)
1105-
async for message in stream:
1106-
yield message

betterproto/_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import TYPE_CHECKING, TypeVar
2+
3+
if TYPE_CHECKING:
4+
from . import Message
5+
from grpclib._protocols import IProtoMessage
6+
7+
# Bound type variable to allow methods to return `self` of subclasses
8+
T = TypeVar("T", bound="Message")
9+
ST = TypeVar("ST", bound="IProtoMessage")

betterproto/grpc/__init__.py

Whitespace-only changes.

betterproto/grpc/grpclib_client.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from abc import ABC
2+
import asyncio
3+
import grpclib.const
4+
from typing import (
5+
Any,
6+
AsyncIterable,
7+
AsyncIterator,
8+
Collection,
9+
Iterable,
10+
Mapping,
11+
Optional,
12+
Tuple,
13+
TYPE_CHECKING,
14+
Type,
15+
Union,
16+
)
17+
from .._types import ST, T
18+
19+
if TYPE_CHECKING:
20+
from grpclib._protocols import IProtoMessage
21+
from grpclib.client import Channel, Stream
22+
from grpclib.metadata import Deadline
23+
24+
25+
_Value = Union[str, bytes]
26+
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
27+
_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
28+
29+
30+
class ServiceStub(ABC):
31+
"""
32+
Base class for async gRPC clients.
33+
"""
34+
35+
def __init__(
36+
self,
37+
channel: "Channel",
38+
*,
39+
timeout: Optional[float] = None,
40+
deadline: Optional["Deadline"] = None,
41+
metadata: Optional[_MetadataLike] = None,
42+
) -> None:
43+
self.channel = channel
44+
self.timeout = timeout
45+
self.deadline = deadline
46+
self.metadata = metadata
47+
48+
def __resolve_request_kwargs(
49+
self,
50+
timeout: Optional[float],
51+
deadline: Optional["Deadline"],
52+
metadata: Optional[_MetadataLike],
53+
):
54+
return {
55+
"timeout": self.timeout if timeout is None else timeout,
56+
"deadline": self.deadline if deadline is None else deadline,
57+
"metadata": self.metadata if metadata is None else metadata,
58+
}
59+
60+
async def _unary_unary(
61+
self,
62+
route: str,
63+
request: "IProtoMessage",
64+
response_type: Type[T],
65+
*,
66+
timeout: Optional[float] = None,
67+
deadline: Optional["Deadline"] = None,
68+
metadata: Optional[_MetadataLike] = None,
69+
) -> T:
70+
"""Make a unary request and return the response."""
71+
async with self.channel.request(
72+
route,
73+
grpclib.const.Cardinality.UNARY_UNARY,
74+
type(request),
75+
response_type,
76+
**self.__resolve_request_kwargs(timeout, deadline, metadata),
77+
) as stream:
78+
await stream.send_message(request, end=True)
79+
response = await stream.recv_message()
80+
assert response is not None
81+
return response
82+
83+
async def _unary_stream(
84+
self,
85+
route: str,
86+
request: "IProtoMessage",
87+
response_type: Type[T],
88+
*,
89+
timeout: Optional[float] = None,
90+
deadline: Optional["Deadline"] = None,
91+
metadata: Optional[_MetadataLike] = None,
92+
) -> AsyncIterator[T]:
93+
"""Make a unary request and return the stream response iterator."""
94+
async with self.channel.request(
95+
route,
96+
grpclib.const.Cardinality.UNARY_STREAM,
97+
type(request),
98+
response_type,
99+
**self.__resolve_request_kwargs(timeout, deadline, metadata),
100+
) as stream:
101+
await stream.send_message(request, end=True)
102+
async for message in stream:
103+
yield message
104+
105+
async def _stream_unary(
106+
self,
107+
route: str,
108+
request_iterator: _MessageSource,
109+
request_type: Type[ST],
110+
response_type: Type[T],
111+
*,
112+
timeout: Optional[float] = None,
113+
deadline: Optional["Deadline"] = None,
114+
metadata: Optional[_MetadataLike] = None,
115+
) -> T:
116+
"""Make a stream request and return the response."""
117+
async with self.channel.request(
118+
route,
119+
grpclib.const.Cardinality.STREAM_UNARY,
120+
request_type,
121+
response_type,
122+
**self.__resolve_request_kwargs(timeout, deadline, metadata),
123+
) as stream:
124+
await self._send_messages(stream, request_iterator)
125+
response = await stream.recv_message()
126+
assert response is not None
127+
return response
128+
129+
async def _stream_stream(
130+
self,
131+
route: str,
132+
request_iterator: _MessageSource,
133+
request_type: Type[ST],
134+
response_type: Type[T],
135+
*,
136+
timeout: Optional[float] = None,
137+
deadline: Optional["Deadline"] = None,
138+
metadata: Optional[_MetadataLike] = None,
139+
) -> AsyncIterator[T]:
140+
"""
141+
Make a stream request and return an AsyncIterator to iterate over response
142+
messages.
143+
"""
144+
async with self.channel.request(
145+
route,
146+
grpclib.const.Cardinality.STREAM_STREAM,
147+
request_type,
148+
response_type,
149+
**self.__resolve_request_kwargs(timeout, deadline, metadata),
150+
) as stream:
151+
await stream.send_request()
152+
sending_task = asyncio.ensure_future(
153+
self._send_messages(stream, request_iterator)
154+
)
155+
try:
156+
async for response in stream:
157+
yield response
158+
except:
159+
sending_task.cancel()
160+
raise
161+
162+
@staticmethod
163+
async def _send_messages(stream, messages: _MessageSource):
164+
if isinstance(messages, AsyncIterable):
165+
async for message in messages:
166+
await stream.send_message(message)
167+
else:
168+
for message in messages:
169+
await stream.send_message(message)
170+
await stream.end()

betterproto/grpc/util/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)