@@ -444,7 +444,8 @@ def test_pg_errored(self, client_mock: MagicMock) -> None:
444
444
manager ._pg .errored .return_value = injected_failure
445
445
446
446
self .assertFalse (manager .should_commit ())
447
- self .assertEqual (manager ._errored , injected_failure )
447
+ assert manager ._errored is not None
448
+ self .assertEqual (manager ._errored .original_exception , injected_failure )
448
449
# pyre-ignore[16]: _pg is mocked
449
450
self .assertEqual (manager ._pg .errored .call_count , 1 )
450
451
@@ -526,7 +527,9 @@ def test_manager_report_error(self, client_mock: MagicMock) -> None:
526
527
self .assertIsNone (manager .errored ())
527
528
e = RuntimeError ("some error" )
528
529
manager .report_error (e )
529
- self .assertIs (manager .errored (), e )
530
+ error = manager .errored ()
531
+ assert error is not None
532
+ self .assertIs (error .original_exception , e )
530
533
531
534
@patch ("torchft.manager.ManagerClient" , autospec = True )
532
535
def test_manager_wrap_future (self , client_mock : MagicMock ) -> None :
@@ -540,7 +543,9 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
540
543
541
544
e = RuntimeError ("injected failure" )
542
545
fut .set_exception (e )
543
- self .assertIs (manager .errored (), e )
546
+ error = manager .errored ()
547
+ assert error is not None
548
+ self .assertIs (error .original_exception , e )
544
549
self .assertEqual (wrapped_fut .value (), 2 )
545
550
546
551
self .assertEqual (manager ._pending_work , [wrapped_fut ])
@@ -555,11 +560,11 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
555
560
wrapped_fut = manager .wrap_future (fut , 2 )
556
561
wrapped_fut .wait ()
557
562
error = manager .errored ()
558
- self . assertIsNotNone ( error )
563
+ assert error is not None
559
564
with self .assertRaisesRegex (
560
565
TimeoutError , "future did not complete within.*0.01"
561
566
):
562
- raise error
567
+ raise error . original_exception
563
568
564
569
@patch ("torchft.manager.ManagerClient" , autospec = True )
565
570
def test_manager_numerics (self , client_mock : MagicMock ) -> None :
@@ -678,19 +683,19 @@ def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
678
683
self .assertFalse (manager .should_commit ())
679
684
680
685
error = manager .errored ()
681
- self . assertIsNotNone ( error )
686
+ assert error is not None
682
687
with self .assertRaisesRegex (RuntimeError , "recv failure" ):
683
- raise error
688
+ raise error . original_exception
684
689
685
690
quorum .recover_dst_ranks = [0 ]
686
691
manager .start_quorum ()
687
692
manager .wait_quorum ()
688
693
self .assertFalse (manager .should_commit ())
689
694
690
695
error = manager .errored ()
691
- self . assertIsNotNone ( error )
696
+ assert error is not None
692
697
with self .assertRaisesRegex (RuntimeError , "send failure" ):
693
- raise error
698
+ raise error . original_exception
694
699
695
700
@patch ("torchft.manager.ManagerClient" , autospec = True )
696
701
def test_quorum_configure_errors (self , client_mock : MagicMock ) -> None :
@@ -718,9 +723,9 @@ def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
718
723
self .assertFalse (manager .should_commit ())
719
724
720
725
error = manager .errored ()
721
- self . assertIsNotNone ( error )
726
+ assert error is not None
722
727
with self .assertRaisesRegex (RuntimeError , "configure failure" ):
723
- raise error
728
+ raise error . original_exception
724
729
725
730
@patch ("torchft.manager.ManagerClient" , autospec = True )
726
731
def test_max_retries (self , client_mock : MagicMock ) -> None :
0 commit comments