@@ -508,12 +508,16 @@ def _async_quorum(

         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
-        # TODO: handle configure errors
-        with torch.profiler.record_function("torchft::manager::_pg.configure"):
-            self._pg.configure(
-                store_prefixed_addr, replica_rank, replica_world_size
-            )
-        self._quorum_id = quorum_id
+        try:
+            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                self._pg.configure(
+                    store_prefixed_addr, replica_rank, replica_world_size
+                )
+            self._quorum_id = quorum_id
+        except Exception as e:
+            self._logger.exception(f"got exception in pg configure: {e}")
+            self.report_error(e)
+            return

         if allow_heal:
             # run recovery on the recovery stream if available
@@ -523,62 +527,67 @@ def _async_quorum(
                 if recovery_stream is not None
                 else nullcontext()
             ):
-                if quorum.recover_dst_ranks:
-                    self._logger.info(
-                        f"peers need recovery from us {quorum.recover_dst_ranks}"
-                    )
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::send_checkpoint"
-                    ):
-                        self._checkpoint_transport.send_checkpoint(
-                            dst_ranks=quorum.recover_dst_ranks,
-                            step=max_step,
-                            state_dict=self._manager_state_dict(),
-                            timeout=self._timeout,
+                try:
+                    if quorum.recover_dst_ranks:
+                        self._logger.info(
+                            f"peers need recovery from us {quorum.recover_dst_ranks}"
                         )
-
-                # See manager.rs for healing conditions
-                if heal:
-                    self._healing = True
-                    self._logger.info(
-                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                    )
-                    primary_client = ManagerClient(
-                        recover_src_manager_address,
-                        connect_timeout=self._connect_timeout,
-                    )
-                    checkpoint_metadata = primary_client._checkpoint_metadata(
-                        self._rank, timeout=self._timeout
-                    )
-                    recover_src_rank = quorum.recover_src_rank
-                    assert (
-                        recover_src_rank is not None
-                    ), "must have a recover rank when healing"
-
-                    self._logger.info(
-                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                    )
-
-                    # we apply the user state dict only when safe from the main thread
-                    # save it for now
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::recv_checkpoint"
-                    ):
-                        self._pending_state_dict = (
-                            self._checkpoint_transport.recv_checkpoint(
-                                src_rank=recover_src_rank,
-                                metadata=checkpoint_metadata,
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::send_checkpoint"
+                        ):
+                            self._checkpoint_transport.send_checkpoint(
+                                dst_ranks=quorum.recover_dst_ranks,
                                 step=max_step,
+                                state_dict=self._manager_state_dict(),
                                 timeout=self._timeout,
                             )
+
+                    # See manager.rs for healing conditions
+                    if heal:
+                        self._healing = True
+                        self._logger.info(
+                            f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
                         )
+                        primary_client = ManagerClient(
+                            recover_src_manager_address,
+                            connect_timeout=self._connect_timeout,
+                        )
+                        checkpoint_metadata = primary_client._checkpoint_metadata(
+                            self._rank, timeout=self._timeout
+                        )
+                        recover_src_rank = quorum.recover_src_rank
+                        assert (
+                            recover_src_rank is not None
+                        ), "must have a recover rank when healing"
+
+                        self._logger.info(
+                            f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                        )
+
+                        # we apply the user state dict only when safe from the main thread
+                        # save it for now
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::recv_checkpoint"
+                        ):
+                            self._pending_state_dict = (
+                                self._checkpoint_transport.recv_checkpoint(
+                                    src_rank=recover_src_rank,
+                                    metadata=checkpoint_metadata,
+                                    step=max_step,
+                                    timeout=self._timeout,
+                                )
+                            )

-                    # pyre-fixme[6]: got object
-                    self.load_state_dict(self._pending_state_dict["torchft"])
+                        # pyre-fixme[6]: got object
+                        self.load_state_dict(self._pending_state_dict["torchft"])

-                    # This isn't strictly needed as loading the state_dict above should
-                    # restore the correct step but it makes writing tests simpler.
-                    self._step = max_step
+                        # This isn't strictly needed as loading the state_dict above should
+                        # restore the correct step but it makes writing tests simpler.
+                        self._step = max_step
+                except Exception as e:
+                    self._logger.exception(f"got exception in recovery: {e}")
+                    self.report_error(e)
+                    return

     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
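Note on how the errors reported here are expected to surface: instead of letting the background quorum thread crash on a failed pg.configure or checkpoint recovery, the exception is recorded via report_error() and the quorum attempt returns early. Below is a minimal sketch (not part of this diff) of how a training loop would then observe the failure; the method names start_quorum() and should_commit() follow the public torchft Manager API as I understand it and should be treated as assumptions to verify against the version in use.

# Sketch only: how an error reported from _async_quorum via report_error()
# is expected to reach the training loop. start_quorum()/should_commit()
# are assumed Manager APIs; adjust to the actual torchft version.
from torchft import Manager

def train_step(manager: Manager, model, optimizer, batch) -> None:
    # Kicks off the asynchronous quorum/recovery work shown in this diff.
    manager.start_quorum()

    loss = model(batch).mean()
    loss.backward()  # collectives are routed through the manager's process group

    # If pg.configure or recovery failed above, the exception was recorded via
    # report_error() and should_commit() returns False, so the optimizer step
    # is skipped for this round instead of crashing the replica.
    if manager.should_commit():
        optimizer.step()
    optimizer.zero_grad()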