@@ -264,6 +264,35 @@ def test_retries_failed_request_retries_successfully():
264
264
assert response .headers ['status-code' ] == '200'
265
265
266
266
267
+ def test_retries_failed_request_retries_successfully_at_request_level ():
268
+ mock_transport = mock .Mock (spec = httpx .BaseTransport )
269
+ mock_transport .handle_request .side_effect = (
270
+ _httpx_response (429 ),
271
+ _httpx_response (200 ),
272
+ )
273
+
274
+ client = api_client .BaseApiClient (
275
+ vertexai = True ,
276
+ project = 'test_project' ,
277
+ location = 'global' ,
278
+ http_options = _transport_options (
279
+ transport = mock_transport ,
280
+ ),
281
+ )
282
+
283
+ with _patch_auth_default ():
284
+ response = client .request (
285
+ http_method = 'GET' ,
286
+ path = 'path' ,
287
+ request_dict = {},
288
+ http_options = types .HttpOptions (
289
+ retry_options = _RETRY_OPTIONS
290
+ ), # At request level.
291
+ )
292
+ mock_transport .handle_request .assert_called ()
293
+ assert response .headers ['status-code' ] == '200'
294
+
295
+
267
296
def test_retries_failed_request_retries_unsuccessfully ():
268
297
mock_transport = mock .Mock (spec = httpx .BaseTransport )
269
298
mock_transport .handle_request .side_effect = (
@@ -290,6 +319,36 @@ def test_retries_failed_request_retries_unsuccessfully():
290
319
mock_transport .handle_request .assert_called ()
291
320
292
321
322
+ def test_retries_failed_request_retries_unsuccessfully_at_request_level ():
323
+ mock_transport = mock .Mock (spec = httpx .BaseTransport )
324
+ mock_transport .handle_request .side_effect = (
325
+ _httpx_response (429 ),
326
+ _httpx_response (504 ),
327
+ )
328
+
329
+ client = api_client .BaseApiClient (
330
+ vertexai = True ,
331
+ project = 'test_project' ,
332
+ location = 'global' ,
333
+ http_options = _transport_options (
334
+ transport = mock_transport ,
335
+ ),
336
+ )
337
+
338
+ with _patch_auth_default ():
339
+ try :
340
+ client .request (
341
+ http_method = 'GET' ,
342
+ path = 'path' ,
343
+ request_dict = {},
344
+ http_options = {'retry_options' : _RETRY_OPTIONS }, # At request level.
345
+ )
346
+ assert False , 'Expected APIError to be raised.'
347
+ except errors .APIError as e :
348
+ assert e .code == 504
349
+ mock_transport .handle_request .assert_called ()
350
+
351
+
293
352
# Async httpx
294
353
295
354
@@ -401,6 +460,40 @@ async def run():
401
460
asyncio .run (run ())
402
461
403
462
463
+ def test_async_retries_failed_request_retries_successfully_at_request_level ():
464
+ api_client .has_aiohttp = False
465
+
466
+ async def run ():
467
+ mock_transport = mock .Mock (spec = httpx .AsyncBaseTransport )
468
+ mock_transport .handle_async_request .side_effect = (
469
+ _httpx_response (429 ),
470
+ _httpx_response (200 ),
471
+ )
472
+
473
+ client = api_client .BaseApiClient (
474
+ vertexai = True ,
475
+ project = 'test_project' ,
476
+ location = 'global' ,
477
+ http_options = _transport_options (
478
+ async_transport = mock_transport ,
479
+ ),
480
+ )
481
+
482
+ with _patch_auth_default ():
483
+ response = await client .async_request (
484
+ http_method = 'GET' ,
485
+ path = 'path' ,
486
+ request_dict = {},
487
+ http_options = types .HttpOptions (
488
+ retry_options = _RETRY_OPTIONS
489
+ ), # At request level.
490
+ )
491
+ mock_transport .handle_async_request .assert_called ()
492
+ assert response .headers ['status-code' ] == '200'
493
+
494
+ asyncio .run (run ())
495
+
496
+
404
497
def test_async_retries_failed_request_retries_unsuccessfully ():
405
498
api_client .has_aiohttp = False
406
499
@@ -434,6 +527,41 @@ async def run():
434
527
asyncio .run (run ())
435
528
436
529
530
+ def test_async_retries_failed_request_retries_unsuccessfully_at_request_level ():
531
+ api_client .has_aiohttp = False
532
+
533
+ async def run ():
534
+ mock_transport = mock .Mock (spec = httpx .AsyncBaseTransport )
535
+ mock_transport .handle_async_request .side_effect = (
536
+ _httpx_response (429 ),
537
+ _httpx_response (504 ),
538
+ )
539
+
540
+ client = api_client .BaseApiClient (
541
+ vertexai = True ,
542
+ project = 'test_project' ,
543
+ location = 'global' ,
544
+ http_options = _transport_options (
545
+ async_transport = mock_transport ,
546
+ ),
547
+ )
548
+
549
+ with _patch_auth_default ():
550
+ try :
551
+ await client .async_request (
552
+ http_method = 'GET' ,
553
+ path = 'path' ,
554
+ request_dict = {},
555
+ http_options = types .HttpOptions (retry_options = _RETRY_OPTIONS ),
556
+ )
557
+ assert False , 'Expected APIError to be raised.'
558
+ except errors .APIError as e :
559
+ assert e .code == 504
560
+ mock_transport .handle_async_request .assert_called ()
561
+
562
+ asyncio .run (run ())
563
+
564
+
437
565
# Async aiohttp
438
566
439
567
@@ -553,6 +681,39 @@ async def run():
553
681
asyncio .run (run ())
554
682
555
683
684
+ @mock .patch .object (aiohttp .ClientSession , 'request' , autospec = True )
685
+ def test_aiohttp_retries_failed_request_retries_successfully_at_request_level (
686
+ mock_request ,
687
+ ):
688
+ api_client .has_aiohttp = True
689
+
690
+ async def run ():
691
+ mock_request .side_effect = (
692
+ _aiohttp_async_response (429 ),
693
+ _aiohttp_async_response (200 ),
694
+ )
695
+
696
+ client = api_client .BaseApiClient (
697
+ vertexai = True ,
698
+ project = 'test_project' ,
699
+ location = 'global' ,
700
+ )
701
+
702
+ with _patch_auth_default ():
703
+ response = await client .async_request (
704
+ http_method = 'GET' ,
705
+ path = 'path' ,
706
+ request_dict = {},
707
+ http_options = types .HttpOptions (
708
+ retry_options = _RETRY_OPTIONS
709
+ ), # At request level.
710
+ )
711
+ mock_request .assert_called ()
712
+ assert response .headers ['status-code' ] == '200'
713
+
714
+ asyncio .run (run ())
715
+
716
+
556
717
@mock .patch .object (aiohttp .ClientSession , 'request' , autospec = True )
557
718
def test_aiohttp_retries_failed_request_retries_unsuccessfully (mock_request ):
558
719
api_client .has_aiohttp = True
@@ -583,3 +744,37 @@ async def run():
583
744
mock_request .assert_called ()
584
745
585
746
asyncio .run (run ())
747
+
748
+
749
+ @mock .patch .object (aiohttp .ClientSession , 'request' , autospec = True )
750
+ def test_aiohttp_retries_failed_request_retries_unsuccessfully_at_request_level (
751
+ mock_request ,
752
+ ):
753
+ api_client .has_aiohttp = True
754
+
755
+ async def run ():
756
+ mock_request .side_effect = (
757
+ _aiohttp_async_response (429 ),
758
+ _aiohttp_async_response (504 ),
759
+ )
760
+
761
+ client = api_client .BaseApiClient (
762
+ vertexai = True ,
763
+ project = 'test_project' ,
764
+ location = 'global' ,
765
+ )
766
+
767
+ with _patch_auth_default ():
768
+ try :
769
+ await client .async_request (
770
+ http_method = 'GET' ,
771
+ path = 'path' ,
772
+ request_dict = {},
773
+ http_options = {'retry_options' : _RETRY_OPTIONS }, # At request level.
774
+ )
775
+ assert False , 'Expected APIError to be raised.'
776
+ except errors .APIError as e :
777
+ assert e .code == 504
778
+ mock_request .assert_called ()
779
+
780
+ asyncio .run (run ())
0 commit comments