Skip to content

Add traceback to manager reported error #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchft/checkpointing/transport_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from torchft.checkpointing.transport import CheckpointTransport

TIMEOUT_REGEX = r"(Timed out|timed out|timeout|time out)"
TIMEOUT_REGEX = r".*(Timed out|timed out|timeout|time out).*"


def assertStateDictEqual(
Expand Down
14 changes: 11 additions & 3 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import logging
import os
import socket
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
Expand Down Expand Up @@ -71,6 +72,13 @@ class WorldSizeMode(Enum):
FIXED_WITH_SPARES = 1


class ExceptionWithTraceback(Exception):
def __init__(self, e: Exception) -> None:
self.original_exception = e
self.stack_trace: str = traceback.format_exc()
super().__init__(f"{e}\n{self.stack_trace}")


class Manager:
"""
Manager manages the full fault tolerant training loop.
Expand Down Expand Up @@ -235,7 +243,7 @@ def __init__(

self._step = 0
self._quorum_id = -1
self._errored: Optional[Exception] = None
self._errored: Optional[ExceptionWithTraceback] = None
self._healing = False
self._pending_work: List[torch.futures.Future[object]] = []
self._batches_committed = 0
Expand Down Expand Up @@ -332,9 +340,9 @@ def report_error(self, e: Exception) -> None:
This should be called when an error occurs that leads to a corrupted
gradient that needs to be discarded.
"""
self._errored = e
self._errored = ExceptionWithTraceback(e)

def errored(self) -> Optional[Exception]:
def errored(self) -> Optional[ExceptionWithTraceback]:
"""
Get whether an error has occurred.

Expand Down
27 changes: 16 additions & 11 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,8 @@ def test_pg_errored(self, client_mock: MagicMock) -> None:
manager._pg.errored.return_value = injected_failure

self.assertFalse(manager.should_commit())
self.assertEqual(manager._errored, injected_failure)
assert manager._errored is not None
self.assertEqual(manager._errored.original_exception, injected_failure)
# pyre-ignore[16]: _pg is mocked
self.assertEqual(manager._pg.errored.call_count, 1)

Expand Down Expand Up @@ -526,7 +527,9 @@ def test_manager_report_error(self, client_mock: MagicMock) -> None:
self.assertIsNone(manager.errored())
e = RuntimeError("some error")
manager.report_error(e)
self.assertIs(manager.errored(), e)
error = manager.errored()
assert error is not None
self.assertIs(error.original_exception, e)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
Expand All @@ -540,7 +543,9 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:

e = RuntimeError("injected failure")
fut.set_exception(e)
self.assertIs(manager.errored(), e)
error = manager.errored()
assert error is not None
self.assertIs(error.original_exception, e)
self.assertEqual(wrapped_fut.value(), 2)

self.assertEqual(manager._pending_work, [wrapped_fut])
Expand All @@ -555,11 +560,11 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
wrapped_fut = manager.wrap_future(fut, 2)
wrapped_fut.wait()
error = manager.errored()
self.assertIsNotNone(error)
assert error is not None
with self.assertRaisesRegex(
TimeoutError, "future did not complete within.*0.01"
):
raise error
raise error.original_exception

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_numerics(self, client_mock: MagicMock) -> None:
Expand Down Expand Up @@ -678,19 +683,19 @@ def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
self.assertFalse(manager.should_commit())

error = manager.errored()
self.assertIsNotNone(error)
assert error is not None
with self.assertRaisesRegex(RuntimeError, "recv failure"):
raise error
raise error.original_exception

quorum.recover_dst_ranks = [0]
manager.start_quorum()
manager.wait_quorum()
self.assertFalse(manager.should_commit())

error = manager.errored()
self.assertIsNotNone(error)
assert error is not None
with self.assertRaisesRegex(RuntimeError, "send failure"):
raise error
raise error.original_exception

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
Expand Down Expand Up @@ -718,9 +723,9 @@ def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
self.assertFalse(manager.should_commit())

error = manager.errored()
self.assertIsNotNone(error)
assert error is not None
with self.assertRaisesRegex(RuntimeError, "configure failure"):
raise error
raise error.original_exception

@patch("torchft.manager.ManagerClient", autospec=True)
def test_max_retries(self, client_mock: MagicMock) -> None:
Expand Down