@@ -405,6 +405,56 @@ def test_iteration_w_raw_raising_unavailable_after_token(self):
405405 self .assertEqual (request .resume_token , RESUME_TOKEN )
406406 self .assertNoSpans ()
407407
408+ def test_iteration_w_raw_raising_unavailable_during_restart (self ):
409+ from google .api_core .exceptions import ServiceUnavailable
410+
411+ FIRST = (self ._make_item (0 ), self ._make_item (1 , resume_token = RESUME_TOKEN ))
412+ LAST = (self ._make_item (2 ),)
413+ before = _MockIterator (
414+ * FIRST , fail_after = True , error = ServiceUnavailable ("testing" )
415+ )
416+ after = _MockIterator (* LAST )
417+ request = mock .Mock (test = "test" , spec = ["test" , "resume_token" ])
418+ # The second call (the first retry) raises ServiceUnavailable immediately.
419+ # The third call (the second retry) succeeds.
420+ restart = mock .Mock (
421+ spec = [],
422+ side_effect = [before , ServiceUnavailable ("retry failed" ), after ],
423+ )
424+ database = _Database ()
425+ database .spanner_api = build_spanner_api ()
426+ session = _Session (database )
427+ derived = _build_snapshot_derived (session )
428+ resumable = self ._call_fut (derived , restart , request , session = session )
429+ self .assertEqual (list (resumable ), list (FIRST + LAST ))
430+ self .assertEqual (len (restart .mock_calls ), 3 )
431+ self .assertEqual (request .resume_token , RESUME_TOKEN )
432+ self .assertNoSpans ()
433+
434+ def test_iteration_w_raw_raising_resumable_internal_error_during_restart (self ):
435+ FIRST = (self ._make_item (0 ), self ._make_item (1 , resume_token = RESUME_TOKEN ))
436+ LAST = (self ._make_item (2 ),)
437+ before = _MockIterator (
438+ * FIRST ,
439+ fail_after = True ,
440+ error = INTERNAL_SERVER_ERROR_UNEXPECTED_EOS ,
441+ )
442+ after = _MockIterator (* LAST )
443+ request = mock .Mock (test = "test" , spec = ["test" , "resume_token" ])
444+ restart = mock .Mock (
445+ spec = [],
446+ side_effect = [before , INTERNAL_SERVER_ERROR_UNEXPECTED_EOS , after ],
447+ )
448+ database = _Database ()
449+ database .spanner_api = build_spanner_api ()
450+ session = _Session (database )
451+ derived = _build_snapshot_derived (session )
452+ resumable = self ._call_fut (derived , restart , request , session = session )
453+ self .assertEqual (list (resumable ), list (FIRST + LAST ))
454+ self .assertEqual (len (restart .mock_calls ), 3 )
455+ self .assertEqual (request .resume_token , RESUME_TOKEN )
456+ self .assertNoSpans ()
457+
408458 def test_iteration_w_raw_w_multiuse (self ):
409459 from google .cloud .spanner_v1 import (
410460 ReadRequest ,
0 commit comments