Skip to content

Commit d5df42c

Browse files
fix function signature for streaming rpc (#279)
* fix function signature for streaming rpc generated code should take request iterator for client streaming, and return response iterator for server streaming
1 parent c104e76 commit d5df42c

File tree

4 files changed

+92
-3
lines changed

4 files changed

+92
-3
lines changed

packages/gapic-generator/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,6 @@ pylintrc.test
5959

6060
# Mypy
6161
.mypy_cache
62+
63+
# pyenv
64+
.python-version

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
{% block content %}
44
from collections import OrderedDict
5-
from typing import Dict, Sequence, Tuple, Type, Union
5+
from typing import Dict, Iterable, Iterator, Sequence, Tuple, Type, Union
66
import pkg_resources
77

88
import google.api_core.client_options as ClientOptions # type: ignore
@@ -11,7 +11,7 @@ from google.api_core import gapic_v1 # type: ignore
1111
from google.api_core import retry as retries # type: ignore
1212
from google.auth import credentials # type: ignore
1313
from google.oauth2 import service_account # type: ignore
14-
14+
1515
{% filter sort_lines -%}
1616
{% for method in service.methods.values() -%}
1717
{% for ref_type in method.ref_types_legacy -%}
@@ -126,18 +126,28 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
126126

127127
{% for method in service.methods.values() -%}
128128
def {{ method.name|snake_case }}(self,
129+
{%- if not method.client_streaming %}
129130
request: {{ method.input.ident }} = None,
130131
*,
131132
{% for field in method.flattened_fields.values() -%}
132133
{{ field.name }}: {{ field.ident }} = None,
133134
{% endfor -%}
135+
{%- else %}
136+
requests: Iterator[{{ method.input.ident }}] = None,
137+
*,
138+
{% endif -%}
134139
retry: retries.Retry = gapic_v1.method.DEFAULT,
135140
timeout: float = None,
136141
metadata: Sequence[Tuple[str, str]] = (),
142+
{%- if not method.server_streaming %}
137143
) -> {{ method.client_output.ident }}:
144+
{%- else %}
145+
) -> Iterable[{{ method.client_output.ident }}]:
146+
{%- endif %}
138147
r"""{{ method.meta.doc|rst(width=72, indent=8) }}
139148

140149
Args:
150+
{%- if not method.client_streaming %}
141151
request (:class:`{{ method.input.ident.sphinx }}`):
142152
The request object.{{ ' ' -}}
143153
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
@@ -148,6 +158,11 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
148158
on the ``request`` instance; if ``request`` is provided, this
149159
should not be set.
150160
{% endfor -%}
161+
{%- else %}
162+
requests (Iterator[`{{ method.input.ident.sphinx }}`]):
163+
The request object iterator.{{ ' ' -}}
164+
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
165+
{%- endif %}
151166
retry (google.api_core.retry.Retry): Designation of what errors, if any,
152167
should be retried.
153168
timeout (float): The timeout for this request.
@@ -156,10 +171,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
156171
{%- if not method.void %}
157172

158173
Returns:
174+
{%- if not method.server_streaming %}
159175
{{ method.client_output.ident.sphinx }}:
176+
{%- else %}
177+
Iterable[{{ method.client_output.ident.sphinx }}]:
178+
{%- endif %}
160179
{{ method.client_output.meta.doc|rst(width=72, indent=16) }}
161180
{%- endif %}
162181
"""
182+
{%- if not method.client_streaming %}
163183
# Create or coerce a protobuf request object.
164184
{% if method.flattened_fields -%}
165185
# Sanity check: If we got a request object, we should *not* have
@@ -176,6 +196,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
176196
if {{ field.name }} is not None:
177197
request.{{ key }} = {{ field.name }}
178198
{%- endfor %}
199+
{%- endif %}
179200

180201
# Wrap the RPC method; this adds retry and timeout information,
181202
# and friendly error handling.
@@ -213,7 +234,11 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
213234

214235
# Send the request.
215236
{% if not method.void %}response = {% endif %}rpc(
237+
{%- if not method.client_streaming %}
216238
request,
239+
{%- else %}
240+
requests,
241+
{%- endif %}
217242
retry=retry,
218243
timeout=timeout,
219244
metadata=metadata,

packages/gapic-generator/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
7979
# Everything is optional in proto3 as far as the runtime is concerned,
8080
# and we are mocking out the actual API, so just send an empty request.
8181
request = {{ method.input.ident }}()
82+
{% if method.client_streaming %}
83+
requests = [request]
84+
{% endif %}
8285

8386
# Mock the actual call within the gRPC stub, and fake the request.
8487
with mock.patch.object(
@@ -98,12 +101,20 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
98101
{%- endfor %}
99102
)
100103
{% endif -%}
104+
{% if method.client_streaming %}
105+
response = client.{{ method.name|snake_case }}(iter(requests))
106+
{% else %}
101107
response = client.{{ method.name|snake_case }}(request)
102-
108+
{% endif %}
109+
103110
# Establish that the underlying gRPC stub method was called.
104111
assert len(call.mock_calls) == 1
105112
_, args, _ = call.mock_calls[0]
113+
{% if method.client_streaming %}
114+
assert next(args[0]) == request
115+
{% else %}
106116
assert args[0] == request
117+
{% endif %}
107118

108119
# Establish that the response is the type that we expect.
109120
{% if method.void -%}

packages/gapic-generator/tests/system/test_grpc_streams.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google import showcase
16+
1517

1618
def test_unary_stream(echo):
1719
content = 'The hail in Wales falls mainly on the snails.'
@@ -24,3 +26,51 @@ def test_unary_stream(echo):
2426
for ground_truth, response in zip(content.split(' '), responses):
2527
assert response.content == ground_truth
2628
assert ground_truth == 'snails.'
29+
30+
# TODO. Check responses.trailing_metadata() content once gapic-showcase
31+
# server returns non-empty trailing metadata.
32+
assert len(responses.trailing_metadata()) == 0
33+
34+
35+
def test_stream_unary(echo):
36+
requests = []
37+
requests.append(showcase.EchoRequest(content="hello"))
38+
requests.append(showcase.EchoRequest(content="world!"))
39+
response = echo.collect(iter(requests))
40+
assert response.content == 'hello world!'
41+
42+
43+
def test_stream_unary_passing_dict(echo):
44+
requests = [{'content': 'hello'}, {'content': 'world!'}]
45+
response = echo.collect(iter(requests))
46+
assert response.content == 'hello world!'
47+
48+
49+
def test_stream_stream(echo):
50+
requests = []
51+
requests.append(showcase.EchoRequest(content="hello"))
52+
requests.append(showcase.EchoRequest(content="world!"))
53+
responses = echo.chat(iter(requests))
54+
55+
contents = []
56+
for response in responses:
57+
contents.append(response.content)
58+
assert contents == ['hello', 'world!']
59+
60+
# TODO. Check responses.trailing_metadata() content once gapic-showcase
61+
# server returns non-empty trailing metadata.
62+
assert len(responses.trailing_metadata()) == 0
63+
64+
65+
def test_stream_stream_passing_dict(echo):
66+
requests = [{'content': 'hello'}, {'content': 'world!'}]
67+
responses = echo.chat(iter(requests))
68+
69+
contents = []
70+
for response in responses:
71+
contents.append(response.content)
72+
assert contents == ['hello', 'world!']
73+
74+
# TODO. Check responses.trailing_metadata() content once gapic-showcase
75+
# server returns non-empty trailing metadata.
76+
assert len(responses.trailing_metadata()) == 0

0 commit comments

Comments
 (0)