Skip to content

Commit 417e655

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support retry configuration at request level
PiperOrigin-RevId: 787236850
1 parent ae2d790 commit 417e655

File tree

3 files changed

+231
-5
lines changed

3 files changed

+231
-5
lines changed

google/genai/_api_client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -965,8 +965,21 @@ def _request_once(
965965
def _request(
966966
self,
967967
http_request: HttpRequest,
968+
http_options: Optional[HttpOptionsOrDict] = None,
968969
stream: bool = False,
969970
) -> HttpResponse:
971+
if http_options:
972+
parameter_model = (
973+
HttpOptions(**http_options)
974+
if isinstance(http_options, dict)
975+
else http_options
976+
)
977+
# Support per request retry options.
978+
if parameter_model.retry_options:
979+
retry_kwargs = _retry_args(parameter_model.retry_options)
980+
retry = tenacity.Retrying(**retry_kwargs)
981+
return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
982+
970983
return self._retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
971984

972985
async def _async_request_once(
@@ -1111,8 +1124,20 @@ async def _async_request_once(
11111124
async def _async_request(
11121125
self,
11131126
http_request: HttpRequest,
1127+
http_options: Optional[HttpOptionsOrDict] = None,
11141128
stream: bool = False,
11151129
) -> HttpResponse:
1130+
if http_options:
1131+
parameter_model = (
1132+
HttpOptions(**http_options)
1133+
if isinstance(http_options, dict)
1134+
else http_options
1135+
)
1136+
# Support per request retry options.
1137+
if parameter_model.retry_options:
1138+
retry_kwargs = _retry_args(parameter_model.retry_options)
1139+
retry = tenacity.AsyncRetrying(**retry_kwargs)
1140+
return await retry(self._async_request_once, http_request, stream) # type: ignore[no-any-return]
11161141
return await self._async_retry( # type: ignore[no-any-return]
11171142
self._async_request_once, http_request, stream
11181143
)
@@ -1134,7 +1159,7 @@ def request(
11341159
http_request = self._build_request(
11351160
http_method, path, request_dict, http_options
11361161
)
1137-
response = self._request(http_request, stream=False)
1162+
response = self._request(http_request, http_options, stream=False)
11381163
response_body = (
11391164
response.response_stream[0] if response.response_stream else ''
11401165
)
@@ -1151,7 +1176,7 @@ def request_streamed(
11511176
http_method, path, request_dict, http_options
11521177
)
11531178

1154-
session_response = self._request(http_request, stream=True)
1179+
session_response = self._request(http_request, http_options, stream=True)
11551180
for chunk in session_response.segments():
11561181
yield SdkHttpResponse(
11571182
headers=session_response.headers, body=json.dumps(chunk)
@@ -1168,7 +1193,9 @@ async def async_request(
11681193
http_method, path, request_dict, http_options
11691194
)
11701195

1171-
result = await self._async_request(http_request=http_request, stream=False)
1196+
result = await self._async_request(
1197+
http_request=http_request, http_options=http_options, stream=False
1198+
)
11721199
response_body = result.response_stream[0] if result.response_stream else ''
11731200
return SdkHttpResponse(headers=result.headers, body=response_body)
11741201

google/genai/_replay_api_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,13 +479,14 @@ def _verify_response(self, response_model: BaseModel) -> None:
479479
def _request(
480480
self,
481481
http_request: HttpRequest,
482+
http_options: Optional[HttpOptionsOrDict] = None,
482483
stream: bool = False,
483484
) -> HttpResponse:
484485
self._initialize_replay_session_if_not_loaded()
485486
if self._should_call_api():
486487
_debug_print('api mode request: %s' % http_request)
487488
try:
488-
result = super()._request(http_request, stream)
489+
result = super()._request(http_request, http_options, stream)
489490
except errors.APIError as e:
490491
self._record_interaction(http_request, e)
491492
raise e
@@ -507,13 +508,16 @@ def _request(
507508
async def _async_request(
508509
self,
509510
http_request: HttpRequest,
511+
http_options: Optional[HttpOptionsOrDict] = None,
510512
stream: bool = False,
511513
) -> HttpResponse:
512514
self._initialize_replay_session_if_not_loaded()
513515
if self._should_call_api():
514516
_debug_print('api mode request: %s' % http_request)
515517
try:
516-
result = await super()._async_request(http_request, stream)
518+
result = await super()._async_request(
519+
http_request, http_options, stream
520+
)
517521
except errors.APIError as e:
518522
self._record_interaction(http_request, e)
519523
raise e

google/genai/tests/client/test_retries.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,35 @@ def test_retries_failed_request_retries_successfully():
264264
assert response.headers['status-code'] == '200'
265265

266266

267+
def test_retries_failed_request_retries_successfully_at_request_level():
268+
mock_transport = mock.Mock(spec=httpx.BaseTransport)
269+
mock_transport.handle_request.side_effect = (
270+
_httpx_response(429),
271+
_httpx_response(200),
272+
)
273+
274+
client = api_client.BaseApiClient(
275+
vertexai=True,
276+
project='test_project',
277+
location='global',
278+
http_options=_transport_options(
279+
transport=mock_transport,
280+
),
281+
)
282+
283+
with _patch_auth_default():
284+
response = client.request(
285+
http_method='GET',
286+
path='path',
287+
request_dict={},
288+
http_options=types.HttpOptions(
289+
retry_options=_RETRY_OPTIONS
290+
), # At request level.
291+
)
292+
mock_transport.handle_request.assert_called()
293+
assert response.headers['status-code'] == '200'
294+
295+
267296
def test_retries_failed_request_retries_unsuccessfully():
268297
mock_transport = mock.Mock(spec=httpx.BaseTransport)
269298
mock_transport.handle_request.side_effect = (
@@ -290,6 +319,36 @@ def test_retries_failed_request_retries_unsuccessfully():
290319
mock_transport.handle_request.assert_called()
291320

292321

322+
def test_retries_failed_request_retries_unsuccessfully_at_request_level():
323+
mock_transport = mock.Mock(spec=httpx.BaseTransport)
324+
mock_transport.handle_request.side_effect = (
325+
_httpx_response(429),
326+
_httpx_response(504),
327+
)
328+
329+
client = api_client.BaseApiClient(
330+
vertexai=True,
331+
project='test_project',
332+
location='global',
333+
http_options=_transport_options(
334+
transport=mock_transport,
335+
),
336+
)
337+
338+
with _patch_auth_default():
339+
try:
340+
client.request(
341+
http_method='GET',
342+
path='path',
343+
request_dict={},
344+
http_options={'retry_options': _RETRY_OPTIONS}, # At request level.
345+
)
346+
assert False, 'Expected APIError to be raised.'
347+
except errors.APIError as e:
348+
assert e.code == 504
349+
mock_transport.handle_request.assert_called()
350+
351+
293352
# Async httpx
294353

295354

@@ -401,6 +460,40 @@ async def run():
401460
asyncio.run(run())
402461

403462

463+
def test_async_retries_failed_request_retries_successfully_at_request_level():
464+
api_client.has_aiohttp = False
465+
466+
async def run():
467+
mock_transport = mock.Mock(spec=httpx.AsyncBaseTransport)
468+
mock_transport.handle_async_request.side_effect = (
469+
_httpx_response(429),
470+
_httpx_response(200),
471+
)
472+
473+
client = api_client.BaseApiClient(
474+
vertexai=True,
475+
project='test_project',
476+
location='global',
477+
http_options=_transport_options(
478+
async_transport=mock_transport,
479+
),
480+
)
481+
482+
with _patch_auth_default():
483+
response = await client.async_request(
484+
http_method='GET',
485+
path='path',
486+
request_dict={},
487+
http_options=types.HttpOptions(
488+
retry_options=_RETRY_OPTIONS
489+
), # At request level.
490+
)
491+
mock_transport.handle_async_request.assert_called()
492+
assert response.headers['status-code'] == '200'
493+
494+
asyncio.run(run())
495+
496+
404497
def test_async_retries_failed_request_retries_unsuccessfully():
405498
api_client.has_aiohttp = False
406499

@@ -434,6 +527,41 @@ async def run():
434527
asyncio.run(run())
435528

436529

530+
def test_async_retries_failed_request_retries_unsuccessfully_at_request_level():
531+
api_client.has_aiohttp = False
532+
533+
async def run():
534+
mock_transport = mock.Mock(spec=httpx.AsyncBaseTransport)
535+
mock_transport.handle_async_request.side_effect = (
536+
_httpx_response(429),
537+
_httpx_response(504),
538+
)
539+
540+
client = api_client.BaseApiClient(
541+
vertexai=True,
542+
project='test_project',
543+
location='global',
544+
http_options=_transport_options(
545+
async_transport=mock_transport,
546+
),
547+
)
548+
549+
with _patch_auth_default():
550+
try:
551+
await client.async_request(
552+
http_method='GET',
553+
path='path',
554+
request_dict={},
555+
http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS),
556+
)
557+
assert False, 'Expected APIError to be raised.'
558+
except errors.APIError as e:
559+
assert e.code == 504
560+
mock_transport.handle_async_request.assert_called()
561+
562+
asyncio.run(run())
563+
564+
437565
# Async aiohttp
438566

439567

@@ -553,6 +681,39 @@ async def run():
553681
asyncio.run(run())
554682

555683

684+
@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True)
685+
def test_aiohttp_retries_failed_request_retries_successfully_at_request_level(
686+
mock_request,
687+
):
688+
api_client.has_aiohttp = True
689+
690+
async def run():
691+
mock_request.side_effect = (
692+
_aiohttp_async_response(429),
693+
_aiohttp_async_response(200),
694+
)
695+
696+
client = api_client.BaseApiClient(
697+
vertexai=True,
698+
project='test_project',
699+
location='global',
700+
)
701+
702+
with _patch_auth_default():
703+
response = await client.async_request(
704+
http_method='GET',
705+
path='path',
706+
request_dict={},
707+
http_options=types.HttpOptions(
708+
retry_options=_RETRY_OPTIONS
709+
), # At request level.
710+
)
711+
mock_request.assert_called()
712+
assert response.headers['status-code'] == '200'
713+
714+
asyncio.run(run())
715+
716+
556717
@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True)
557718
def test_aiohttp_retries_failed_request_retries_unsuccessfully(mock_request):
558719
api_client.has_aiohttp = True
@@ -583,3 +744,37 @@ async def run():
583744
mock_request.assert_called()
584745

585746
asyncio.run(run())
747+
748+
749+
@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True)
750+
def test_aiohttp_retries_failed_request_retries_unsuccessfully_at_request_level(
751+
mock_request,
752+
):
753+
api_client.has_aiohttp = True
754+
755+
async def run():
756+
mock_request.side_effect = (
757+
_aiohttp_async_response(429),
758+
_aiohttp_async_response(504),
759+
)
760+
761+
client = api_client.BaseApiClient(
762+
vertexai=True,
763+
project='test_project',
764+
location='global',
765+
)
766+
767+
with _patch_auth_default():
768+
try:
769+
await client.async_request(
770+
http_method='GET',
771+
path='path',
772+
request_dict={},
773+
http_options={'retry_options': _RETRY_OPTIONS}, # At request level.
774+
)
775+
assert False, 'Expected APIError to be raised.'
776+
except errors.APIError as e:
777+
assert e.code == 504
778+
mock_request.assert_called()
779+
780+
asyncio.run(run())

0 commit comments

Comments
 (0)