manager: gracefully handle errors from configure+checkpoint #182

Merged: 1 commit, May 1, 2025
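
In short: `_async_quorum` previously ran `ProcessGroup.configure` and the checkpoint send/recv with no error handling, so any failure killed the background quorum thread. This change wraps both phases in try/except, logs the exception, records it via `report_error`, and returns early so the training thread can observe the failure through `errored()` and `should_commit()`. A minimal sketch of that pattern, using a hypothetical `SketchManager` stand-in (only the `report_error`/`errored` flow mirrors the real `Manager` API changed here):

import logging
from typing import Callable, Optional

logger = logging.getLogger("torchft.sketch")


class SketchManager:
    """Hypothetical stand-in; not the real torchft Manager."""

    def __init__(self) -> None:
        self._error: Optional[Exception] = None

    def report_error(self, e: Exception) -> None:
        # Stash the exception so the training thread can poll it later.
        self._error = e

    def errored(self) -> Optional[Exception]:
        return self._error

    def _async_quorum(self, configure: Callable[[], None]) -> None:
        try:
            configure()  # stands in for self._pg.configure(...)
        except Exception as e:
            # Log, record, and return early instead of crashing the thread.
            logger.exception(f"got exception in pg configure: {e}")
            self.report_error(e)
            return


def failing_configure() -> None:
    raise RuntimeError("configure failure")


manager = SketchManager()
manager._async_quorum(failing_configure)
assert isinstance(manager.errored(), RuntimeError)  # surfaced, not raised
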
139 changes: 76 additions & 63 deletions torchft/manager.py
@@ -508,12 +508,16 @@ def _async_quorum(
 
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
-            # TODO: handle configure errors
-            with torch.profiler.record_function("torchft::manager::_pg.configure"):
-                self._pg.configure(
-                    store_prefixed_addr, replica_rank, replica_world_size
-                )
-            self._quorum_id = quorum_id
+            try:
+                with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                    self._pg.configure(
+                        store_prefixed_addr, replica_rank, replica_world_size
+                    )
+                self._quorum_id = quorum_id
+            except Exception as e:
+                self._logger.exception(f"got exception in pg configure: {e}")
+                self.report_error(e)
+                return
 
         if allow_heal:
             # run recovery on the recovery stream if available
@@ -523,62 +527,67 @@
                 if recovery_stream is not None
                 else nullcontext()
             ):
-                if quorum.recover_dst_ranks:
-                    self._logger.info(
-                        f"peers need recovery from us {quorum.recover_dst_ranks}"
-                    )
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::send_checkpoint"
-                    ):
-                        self._checkpoint_transport.send_checkpoint(
-                            dst_ranks=quorum.recover_dst_ranks,
-                            step=max_step,
-                            state_dict=self._manager_state_dict(),
-                            timeout=self._timeout,
-                        )
+                try:
+                    if quorum.recover_dst_ranks:
+                        self._logger.info(
+                            f"peers need recovery from us {quorum.recover_dst_ranks}"
+                        )
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::send_checkpoint"
+                        ):
+                            self._checkpoint_transport.send_checkpoint(
+                                dst_ranks=quorum.recover_dst_ranks,
+                                step=max_step,
+                                state_dict=self._manager_state_dict(),
+                                timeout=self._timeout,
+                            )
 
-                # See manager.rs for healing conditions
-                if heal:
-                    self._healing = True
-                    self._logger.info(
-                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                    )
-                    primary_client = ManagerClient(
-                        recover_src_manager_address,
-                        connect_timeout=self._connect_timeout,
-                    )
-                    checkpoint_metadata = primary_client._checkpoint_metadata(
-                        self._rank, timeout=self._timeout
-                    )
-                    recover_src_rank = quorum.recover_src_rank
-                    assert (
-                        recover_src_rank is not None
-                    ), "must have a recover rank when healing"
+                    # See manager.rs for healing conditions
+                    if heal:
+                        self._healing = True
+                        self._logger.info(
+                            f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
+                        )
+                        primary_client = ManagerClient(
+                            recover_src_manager_address,
+                            connect_timeout=self._connect_timeout,
+                        )
+                        checkpoint_metadata = primary_client._checkpoint_metadata(
+                            self._rank, timeout=self._timeout
+                        )
+                        recover_src_rank = quorum.recover_src_rank
+                        assert (
+                            recover_src_rank is not None
+                        ), "must have a recover rank when healing"
 
-                    self._logger.info(
-                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                    )
+                        self._logger.info(
+                            f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                        )
 
-                    # we apply the user state dict only when safe from the main thread
-                    # save it for now
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::recv_checkpoint"
-                    ):
-                        self._pending_state_dict = (
-                            self._checkpoint_transport.recv_checkpoint(
-                                src_rank=recover_src_rank,
-                                metadata=checkpoint_metadata,
-                                step=max_step,
-                                timeout=self._timeout,
-                            )
-                        )
+                        # we apply the user state dict only when safe from the main thread
+                        # save it for now
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::recv_checkpoint"
+                        ):
+                            self._pending_state_dict = (
+                                self._checkpoint_transport.recv_checkpoint(
+                                    src_rank=recover_src_rank,
+                                    metadata=checkpoint_metadata,
+                                    step=max_step,
+                                    timeout=self._timeout,
+                                )
+                            )
 
-                    # pyre-fixme[6]: got object
-                    self.load_state_dict(self._pending_state_dict["torchft"])
+                        # pyre-fixme[6]: got object
+                        self.load_state_dict(self._pending_state_dict["torchft"])
 
-                    # This isn't strictly needed as loading the state_dict above should
-                    # restore the correct step but it makes writing tests simpler.
-                    self._step = max_step
+                        # This isn't strictly needed as loading the state_dict above should
+                        # restore the correct step but it makes writing tests simpler.
+                        self._step = max_step
+                except Exception as e:
+                    self._logger.exception(f"got exception in recovery: {e}")
+                    self.report_error(e)
+                    return
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
@@ -587,15 +596,19 @@ def _apply_pending_state_dict(self) -> None:
         assert self._quorum_future is not None, "must call step before should_commit"
         self._quorum_future.result()
 
-        self._logger.info("applying pending state dict")
+        pending_state_dict = self._pending_state_dict
 
-        assert self._pending_state_dict is not None, "checkpoint was not staged"
-        assert (
-            self._load_state_dict is not None
-        ), "user load_state_dict is not initialized."
-        self._load_state_dict(self._pending_state_dict["user"])
-        self._pending_state_dict = None
-        self._logger.info("Loaded state dict.")
+        if pending_state_dict is None:
+            assert self.errored(), "checkpoint was not staged and no error occured"
+        else:
+            self._logger.info("applying pending state dict")
+
+            assert (
+                self._load_state_dict is not None
+            ), "user load_state_dict is not initialized."
+            self._load_state_dict(pending_state_dict["user"])
+            self._pending_state_dict = None
+            self._logger.info("Loaded state dict.")
 
     @torch.profiler.record_function("torchft::manager::should_commit")
     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
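
On the consuming side, a recorded error makes `should_commit()` return False for that step, and `_apply_pending_state_dict` now tolerates a missing staged checkpoint as long as an error was reported. A hedged sketch of how a training loop might react (the loop shape, `model`, `optimizer`, and `batch` are illustrative, not part of this PR):

def train_step(manager, model, optimizer, batch) -> None:
    # start_quorum/should_commit/errored follow the Manager API used above;
    # everything else here is a hypothetical harness.
    manager.start_quorum()
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    if manager.should_commit():
        # All replicas agreed and no error was reported; apply the update.
        optimizer.step()
    elif (err := manager.errored()) is not None:
        # configure/recovery failed in the background: drop this step,
        # keep the process alive, and let the next quorum retry.
        print(f"skipping step after manager error: {err}")
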
74 changes: 74 additions & 0 deletions torchft/manager_test.py
@@ -14,6 +14,7 @@
 from torch.distributed import TCPStore
 
 from torchft._torchft import QuorumResult
+from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup, _DummyWork
 
@@ -648,6 +649,79 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
         manager.start_quorum()
         self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=True)
+        client_mock().should_commit = MagicMock(return_value=False)
+
+        transport = MagicMock(spec=CheckpointTransport)
+        transport.send_checkpoint.side_effect = RuntimeError("send failure")
+        transport.recv_checkpoint.side_effect = RuntimeError("recv failure")
+        manager._checkpoint_transport = transport
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+        quorum.heal = True
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        manager.wait_quorum()
+        self.assertFalse(manager.should_commit())
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "recv failure"):
+            raise error
+
+        quorum.recover_dst_ranks = [0]
+        manager.start_quorum()
+        manager.wait_quorum()
+        self.assertFalse(manager.should_commit())
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "send failure"):
+            raise error
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=True)
+        client_mock().should_commit = MagicMock(return_value=False)
+
+        # pyre-ignore[16]: mock
+        manager._pg.configure.side_effect = RuntimeError("configure failure")
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        manager.wait_quorum()
+        self.assertFalse(manager.should_commit())
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "configure failure"):
+            raise error
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_max_retries(self, client_mock: MagicMock) -> None:
         # Create a manager with max_retries=2
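
Both new tests use the same idiom: give a `MagicMock` a `side_effect` to force the failure path, then re-raise the exception returned by `errored()` inside `assertRaisesRegex` to check its type and message. A standalone sketch of that idiom (the `Transport` class here is hypothetical, not torchft code):

import unittest
from typing import Optional
from unittest.mock import MagicMock


class Transport:
    """Hypothetical interface used only for this sketch."""

    def send_checkpoint(self) -> None:
        raise NotImplementedError


class SideEffectIdiomTest(unittest.TestCase):
    def test_injected_send_failure(self) -> None:
        transport = MagicMock(spec=Transport)
        transport.send_checkpoint.side_effect = RuntimeError("send failure")

        # The code under test would catch and store the exception; emulate
        # that capture here, mirroring Manager.report_error/errored.
        error: Optional[Exception] = None
        try:
            transport.send_checkpoint()
        except Exception as e:
            error = e

        # Re-raising the stored error lets assertRaisesRegex verify both
        # the exception type and the message.
        self.assertIsNotNone(error)
        with self.assertRaisesRegex(RuntimeError, "send failure"):
            raise error


if __name__ == "__main__":
    unittest.main()
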