
Commit 5cb28a3

manager: gracefully handle errors from configure+checkpoint

1 parent: abc660b

2 files changed: +133 -55 lines changed

torchft/manager.py

Lines changed: 64 additions & 55 deletions
@@ -508,12 +508,16 @@ def _async_quorum(
 
         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
-        # TODO: handle configure errors
-        with torch.profiler.record_function("torchft::manager::_pg.configure"):
-            self._pg.configure(
-                store_prefixed_addr, replica_rank, replica_world_size
-            )
-        self._quorum_id = quorum_id
+        try:
+            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                self._pg.configure(
+                    store_prefixed_addr, replica_rank, replica_world_size
+                )
+            self._quorum_id = quorum_id
+        except Exception as e:
+            self._logger.exception(f"got exception in pg configure: {e}")
+            self.report_error(e)
+            return
 
         if allow_heal:
             # run recovery on the recovery stream if available
@@ -523,62 +527,67 @@ def _async_quorum(
                 if recovery_stream is not None
                 else nullcontext()
             ):
-                if quorum.recover_dst_ranks:
-                    self._logger.info(
-                        f"peers need recovery from us {quorum.recover_dst_ranks}"
-                    )
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::send_checkpoint"
-                    ):
-                        self._checkpoint_transport.send_checkpoint(
-                            dst_ranks=quorum.recover_dst_ranks,
-                            step=max_step,
-                            state_dict=self._manager_state_dict(),
-                            timeout=self._timeout,
-                        )
-
-                # See manager.rs for healing conditions
-                if heal:
-                    self._healing = True
-                    self._logger.info(
-                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                    )
-                    primary_client = ManagerClient(
-                        recover_src_manager_address,
-                        connect_timeout=self._connect_timeout,
-                    )
-                    checkpoint_metadata = primary_client._checkpoint_metadata(
-                        self._rank, timeout=self._timeout
-                    )
-                    recover_src_rank = quorum.recover_src_rank
-                    assert (
-                        recover_src_rank is not None
-                    ), "must have a recover rank when healing"
-
-                    self._logger.info(
-                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                    )
-
-                    # we apply the user state dict only when safe from the main thread
-                    # save it for now
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::recv_checkpoint"
-                    ):
-                        self._pending_state_dict = (
-                            self._checkpoint_transport.recv_checkpoint(
-                                src_rank=recover_src_rank,
-                                metadata=checkpoint_metadata,
-                                step=max_step,
-                                timeout=self._timeout,
-                            )
-                        )
-
-                    # pyre-fixme[6]: got object
-                    self.load_state_dict(self._pending_state_dict["torchft"])
-
-                    # This isn't strictly needed as loading the state_dict above should
-                    # restore the correct step but it makes writing tests simpler.
-                    self._step = max_step
+                try:
+                    if quorum.recover_dst_ranks:
+                        self._logger.info(
+                            f"peers need recovery from us {quorum.recover_dst_ranks}"
+                        )
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::send_checkpoint"
+                        ):
+                            self._checkpoint_transport.send_checkpoint(
+                                dst_ranks=quorum.recover_dst_ranks,
+                                step=max_step,
+                                state_dict=self._manager_state_dict(),
+                                timeout=self._timeout,
+                            )
+
+                    # See manager.rs for healing conditions
+                    if heal:
+                        self._healing = True
+                        self._logger.info(
+                            f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
+                        )
+                        primary_client = ManagerClient(
+                            recover_src_manager_address,
+                            connect_timeout=self._connect_timeout,
+                        )
+                        checkpoint_metadata = primary_client._checkpoint_metadata(
+                            self._rank, timeout=self._timeout
+                        )
+                        recover_src_rank = quorum.recover_src_rank
+                        assert (
+                            recover_src_rank is not None
+                        ), "must have a recover rank when healing"
+
+                        self._logger.info(
+                            f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                        )
+
+                        # we apply the user state dict only when safe from the main thread
+                        # save it for now
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::recv_checkpoint"
+                        ):
+                            self._pending_state_dict = (
+                                self._checkpoint_transport.recv_checkpoint(
+                                    src_rank=recover_src_rank,
+                                    metadata=checkpoint_metadata,
+                                    step=max_step,
+                                    timeout=self._timeout,
+                                )
+                            )
+
+                        # pyre-fixme[6]: got object
+                        self.load_state_dict(self._pending_state_dict["torchft"])
+
+                        # This isn't strictly needed as loading the state_dict above should
+                        # restore the correct step but it makes writing tests simpler.
+                        self._step = max_step
+                except Exception as e:
+                    self._logger.exception(f"got exception in recovery: {e}")
+                    self.report_error(e)
+                    return
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
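
With this change, exceptions raised during process group configuration or checkpoint send/recv no longer escape the background quorum thread: they are logged, forwarded through report_error(), and _async_quorum returns early. A minimal sketch of how calling code can observe a surfaced failure, following the start_quorum/wait_quorum/errored pattern used in the new tests (the error-handling policy shown is an illustrative assumption, not part of this commit):

    # Sketch only: assumes a Manager constructed with use_async_quorum=True.
    manager.start_quorum()
    manager.wait_quorum()

    # Errors reported via report_error() from the quorum thread surface here
    # rather than killing that thread.
    error = manager.errored()
    if error is not None:
        # Illustrative policy: re-raise so an outer supervisor can restart
        # this replica; a real loop might instead skip the step and retry.
        raise error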

torchft/manager_test.py

Lines changed: 69 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torch.distributed import TCPStore
 
 from torchft._torchft import QuorumResult
+from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup, _DummyWork
 
@@ -648,6 +649,74 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
         manager.start_quorum()
         self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=True)
+
+        transport = MagicMock(spec=CheckpointTransport)
+        transport.send_checkpoint.side_effect = RuntimeError("send failure")
+        transport.recv_checkpoint.side_effect = RuntimeError("recv failure")
+        manager._checkpoint_transport = transport
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+        quorum.heal = True
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        manager.wait_quorum()
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "recv failure"):
+            raise error
+
+        quorum.recover_dst_ranks = [0]
+        manager.start_quorum()
+        manager.wait_quorum()
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "send failure"):
+            raise error
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=True)
+
+        # pyre-ignore[16]: mock
+        manager._pg.configure.side_effect = RuntimeError("configure failure")
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        manager.wait_quorum()
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "configure failure"):
+            raise error
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_max_retries(self, client_mock: MagicMock) -> None:
         # Create a manager with max_retries=2
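
The two new tests drive these paths by injecting RuntimeError side effects into the mocked process group and checkpoint transport, then asserting that the failure is reported through errored() rather than swallowed. A quick way to run the manager test module locally, sketched here assuming the repo's standard unittest-compatible layout (pytest with a -k filter on the new test names would work equally well):

    # Sketch: discover and run torchft/manager_test.py, which now includes
    # test_quorum_checkpoint_errors and test_quorum_configure_errors.
    import unittest

    suite = unittest.defaultTestLoader.discover(
        start_dir="torchft", pattern="manager_test.py"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)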
