Skip to content

Commit eaac3e6

Browse files
authored
fix: update paging implementation to handle unconventional pagination (#750)
* fix: update paging implementation to handle unconventional pagination * fix: typing errors, mypy cli update * fix: mypy cli flag * fix: delete __init__.py, remove -p mypy flag * fix: clearing up statements, tests, minor bug in filter usage * fix: wrong generated type hints
1 parent 4077b45 commit eaac3e6

File tree

5 files changed

+189
-19
lines changed

5 files changed

+189
-19
lines changed

gapic/schema/wrappers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -866,13 +866,22 @@ def paged_result_field(self) -> Optional[Field]:
866866
"""Return the response pagination field if the method is paginated."""
867867
# If the request field lacks any of the expected pagination fields,
868868
# then the method is not paginated.
869-
for page_field in ((self.input, int, 'page_size'),
870-
(self.input, str, 'page_token'),
869+
870+
# The request must have page_token and next_page_token as they keep track of pages
871+
for source, source_type, name in ((self.input, str, 'page_token'),
871872
(self.output, str, 'next_page_token')):
872-
field = page_field[0].fields.get(page_field[2], None)
873-
if not field or field.type != page_field[1]:
873+
field = source.fields.get(name, None)
874+
if not field or field.type != source_type:
874875
return None
875876

877+
# The request must have max_results or page_size
878+
page_fields = (self.input.fields.get('max_results', None),
879+
self.input.fields.get('page_size', None))
880+
page_field_size = next(
881+
(field for field in page_fields if field), None)
882+
if not page_field_size or page_field_size.type != int:
883+
return None
884+
876885
# Return the first repeated field.
877886
for field in self.output.fields.values():
878887
if field.repeated:

gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
{# This lives within the loop in order to ensure that this template
77
is empty if there are no paged methods.
88
-#}
9-
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple
9+
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional
1010

1111
{% filter sort_lines -%}
1212
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
@@ -68,14 +68,25 @@ class {{ method.name }}Pager:
6868
self._response = self._method(self._request, metadata=self._metadata)
6969
yield self._response
7070

71+
{% if method.paged_result_field.map %}
72+
def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.type.fields.get('value').ident }}]]:
73+
for page in self.pages:
74+
yield from page.{{ method.paged_result_field.name}}.items()
75+
76+
def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]:
77+
return self._response.items.get(key)
78+
{% else %}
7179
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:
7280
for page in self.pages:
7381
yield from page.{{ method.paged_result_field.name }}
82+
{% endif %}
7483

7584
def __repr__(self) -> str:
7685
return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)
7786

7887

88+
{# TODO(yon-mg): remove on rest async transport impl #}
89+
{% if 'grpc' in opts.transport %}
7990
class {{ method.name }}AsyncPager:
8091
"""A pager for iterating through ``{{ method.name|snake_case }}`` requests.
8192

@@ -138,5 +149,6 @@ class {{ method.name }}AsyncPager:
138149
def __repr__(self) -> str:
139150
return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)
140151

152+
{% endif %}
141153
{% endfor %}
142154
{% endblock %}

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
184184
# TODO(yon-mg): handle nested fields corerctly rather than using only top level fields
185185
# not required for GCE
186186
query_params = {
187-
{% filter sort_lines -%}
188-
{%- for field in method.query_params %}
187+
{%- for field in method.query_params | sort%}
189188
'{{ field|camel_case }}': request.{{ field }},
190189
{%- endfor %}
191-
{% endfilter -%}
192190
}
193191
# TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here
194192
# discards default values

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ def test_{{ method.name|snake_case }}_raw_page_lro():
10201020
assert response.raw_page is response
10211021
{% endif %} {#- method.paged_result_field #}
10221022

1023-
{% endfor -%} {#- method in methods #}
1023+
{% endfor -%} {#- method in methods for grpc #}
10241024

10251025
{% for method in service.methods.values() if 'rest' in opts.transport -%}
10261026
def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}):
@@ -1162,7 +1162,126 @@ def test_{{ method.name|snake_case }}_rest_flattened_error():
11621162
)
11631163

11641164

1165-
{% endfor -%}
1165+
{% if method.paged_result_field %}
1166+
def test_{{ method.name|snake_case }}_pager():
1167+
client = {{ service.client_name }}(
1168+
credentials=credentials.AnonymousCredentials(),
1169+
)
1170+
1171+
# Mock the http request call within the method and fake a response.
1172+
with mock.patch.object(Session, 'request') as req:
1173+
# Set the response as a series of pages
1174+
{% if method.paged_result_field.map%}
1175+
response = (
1176+
{{ method.output.ident }}(
1177+
{{ method.paged_result_field.name }}={
1178+
'a':{{ method.paged_result_field.type.fields.get('value').ident }}(),
1179+
'b':{{ method.paged_result_field.type.fields.get('value').ident }}(),
1180+
'c':{{ method.paged_result_field.type.fields.get('value').ident }}(),
1181+
},
1182+
next_page_token='abc',
1183+
),
1184+
{{ method.output.ident }}(
1185+
{{ method.paged_result_field.name }}={},
1186+
next_page_token='def',
1187+
),
1188+
{{ method.output.ident }}(
1189+
{{ method.paged_result_field.name }}={
1190+
'g':{{ method.paged_result_field.type.fields.get('value').ident }}(),
1191+
},
1192+
next_page_token='ghi',
1193+
),
1194+
{{ method.output.ident }}(
1195+
{{ method.paged_result_field.name }}={
1196+
'h':{{ method.paged_result_field.type.fields.get('value').ident }}(),
1197+
'i':{{ method.paged_result_field.type.fields.get('value').ident }}(),
1198+
},
1199+
),
1200+
)
1201+
{% else %}
1202+
response = (
1203+
{{ method.output.ident }}(
1204+
{{ method.paged_result_field.name }}=[
1205+
{{ method.paged_result_field.type.ident }}(),
1206+
{{ method.paged_result_field.type.ident }}(),
1207+
{{ method.paged_result_field.type.ident }}(),
1208+
],
1209+
next_page_token='abc',
1210+
),
1211+
{{ method.output.ident }}(
1212+
{{ method.paged_result_field.name }}=[],
1213+
next_page_token='def',
1214+
),
1215+
{{ method.output.ident }}(
1216+
{{ method.paged_result_field.name }}=[
1217+
{{ method.paged_result_field.type.ident }}(),
1218+
],
1219+
next_page_token='ghi',
1220+
),
1221+
{{ method.output.ident }}(
1222+
{{ method.paged_result_field.name }}=[
1223+
{{ method.paged_result_field.type.ident }}(),
1224+
{{ method.paged_result_field.type.ident }}(),
1225+
],
1226+
),
1227+
)
1228+
{% endif %}
1229+
# Two responses for two calls
1230+
response = response + response
1231+
1232+
# Wrap the values into proper Response objs
1233+
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
1234+
return_values = tuple(Response() for i in response)
1235+
for return_val, response_val in zip(return_values, response):
1236+
return_val._content = response_val.encode('UTF-8')
1237+
return_val.status_code = 200
1238+
req.side_effect = return_values
1239+
1240+
metadata = ()
1241+
{% if method.field_headers -%}
1242+
metadata = tuple(metadata) + (
1243+
gapic_v1.routing_header.to_grpc_metadata((
1244+
{%- for field_header in method.field_headers %}
1245+
{%- if not method.client_streaming %}
1246+
('{{ field_header }}', ''),
1247+
{%- endif %}
1248+
{%- endfor %}
1249+
)),
1250+
)
1251+
{% endif -%}
1252+
pager = client.{{ method.name|snake_case }}(request={})
1253+
1254+
assert pager._metadata == metadata
1255+
1256+
{% if method.paged_result_field.map %}
1257+
assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }})
1258+
assert pager.get('h') is None
1259+
{% endif %}
1260+
1261+
results = list(pager)
1262+
assert len(results) == 6
1263+
{% if method.paged_result_field.map %}
1264+
assert all(
1265+
isinstance(i, tuple)
1266+
for i in results)
1267+
for result in results:
1268+
assert isinstance(result, tuple)
1269+
assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }})
1270+
1271+
assert pager.get('a') is None
1272+
assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }})
1273+
{% else %}
1274+
assert all(isinstance(i, {{ method.paged_result_field.type.ident }})
1275+
for i in results)
1276+
{% endif %}
1277+
1278+
pages = list(client.{{ method.name|snake_case }}(request={}).pages)
1279+
for page_, token in zip(pages, ['abc','def','ghi', '']):
1280+
assert page_.raw_page.next_page_token == token
1281+
1282+
1283+
{% endif %} {# paged methods #}
1284+
{% endfor -%} {#- method in methods for rest #}
11661285
def test_credentials_transport_error():
11671286
# It is an error to provide credentials and a transport instance.
11681287
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(

tests/unit/schema/wrappers/test_method.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,38 @@ def test_method_client_output_empty():
6666

6767
def test_method_client_output_paged():
6868
paged = make_field(name='foos', message=make_message('Foo'), repeated=True)
69+
parent = make_field(name='parent', type=9) # str
70+
page_size = make_field(name='page_size', type=5) # int
71+
page_token = make_field(name='page_token', type=9) # str
72+
6973
input_msg = make_message(name='ListFoosRequest', fields=(
70-
make_field(name='parent', type=9), # str
71-
make_field(name='page_size', type=5), # int
72-
make_field(name='page_token', type=9), # str
74+
parent,
75+
page_size,
76+
page_token,
7377
))
7478
output_msg = make_message(name='ListFoosResponse', fields=(
7579
paged,
7680
make_field(name='next_page_token', type=9), # str
7781
))
78-
method = make_method('ListFoos',
79-
input_message=input_msg,
80-
output_message=output_msg,
81-
)
82+
method = make_method(
83+
'ListFoos',
84+
input_message=input_msg,
85+
output_message=output_msg,
86+
)
87+
assert method.paged_result_field == paged
88+
assert method.client_output.ident.name == 'ListFoosPager'
89+
90+
max_results = make_field(name='max_results', type=5) # int
91+
input_msg = make_message(name='ListFoosRequest', fields=(
92+
parent,
93+
max_results,
94+
page_token,
95+
))
96+
method = make_method(
97+
'ListFoos',
98+
input_message=input_msg,
99+
output_message=output_msg,
100+
)
82101
assert method.paged_result_field == paged
83102
assert method.client_output.ident.name == 'ListFoosPager'
84103

@@ -123,6 +142,19 @@ def test_method_paged_result_field_no_page_field():
123142
)
124143
assert method.paged_result_field is None
125144

145+
method = make_method(
146+
name='Foo',
147+
input_message=make_message(
148+
name='FooRequest',
149+
fields=(make_field(name='page_token', type=9),) # str
150+
),
151+
output_message=make_message(
152+
name='FooResponse',
153+
fields=(make_field(name='next_page_token', type=9),) # str
154+
)
155+
)
156+
assert method.paged_result_field is None
157+
126158

127159
def test_method_paged_result_ref_types():
128160
input_msg = make_message(
@@ -139,7 +171,7 @@ def test_method_paged_result_ref_types():
139171
name='ListMolluscsResponse',
140172
fields=(
141173
make_field(name='molluscs', message=mollusc_msg, repeated=True),
142-
make_field(name='next_page_token', type=9)
174+
make_field(name='next_page_token', type=9) # str
143175
),
144176
module='mollusc'
145177
)
@@ -207,7 +239,7 @@ def test_flattened_ref_types():
207239

208240

209241
def test_method_paged_result_primitive():
210-
paged = make_field(name='squids', type=9, repeated=True)
242+
paged = make_field(name='squids', type=9, repeated=True) # str
211243
input_msg = make_message(
212244
name='ListSquidsRequest',
213245
fields=(

0 commit comments

Comments
 (0)