Skip to content

Commit c0f7f54

Browse files
committed
Add traceback to reported error
1 parent 93c230b commit c0f7f54

File tree

3 files changed

+28
-15
lines changed

3 files changed

+28
-15
lines changed

torchft/checkpointing/transport_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from torchft.checkpointing.transport import CheckpointTransport
1313

14-
TIMEOUT_REGEX = r"(Timed out|timed out|timeout|time out)"
14+
TIMEOUT_REGEX = r".*(Timed out|timed out|timeout|time out).*"
1515

1616

1717
def assertStateDictEqual(

torchft/manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import logging
3030
import os
3131
import socket
32+
import traceback
3233
import uuid
3334
from concurrent.futures import ThreadPoolExecutor
3435
from contextlib import nullcontext
@@ -71,6 +72,13 @@ class WorldSizeMode(Enum):
7172
FIXED_WITH_SPARES = 1
7273

7374

75+
class ExceptionWithTraceback(Exception):
76+
def __init__(self, e: Exception) -> None:
77+
self.original_exception = e
78+
self.stack_trace: str = traceback.format_exc()
79+
super().__init__(f"{e}\n{self.stack_trace}")
80+
81+
7482
class Manager:
7583
"""
7684
Manager manages the full fault tolerant training loop.
@@ -235,7 +243,7 @@ def __init__(
235243

236244
self._step = 0
237245
self._quorum_id = -1
238-
self._errored: Optional[Exception] = None
246+
self._errored: Optional[ExceptionWithTraceback] = None
239247
self._healing = False
240248
self._pending_work: List[torch.futures.Future[object]] = []
241249
self._batches_committed = 0
@@ -332,9 +340,9 @@ def report_error(self, e: Exception) -> None:
332340
This should be called when an error occurs that leads to a corrupted
333341
gradient that needs to be discarded.
334342
"""
335-
self._errored = e
343+
self._errored = ExceptionWithTraceback(e)
336344

337-
def errored(self) -> Optional[Exception]:
345+
def errored(self) -> Optional[ExceptionWithTraceback]:
338346
"""
339347
Get whether an error has occurred.
340348

torchft/manager_test.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,8 @@ def test_pg_errored(self, client_mock: MagicMock) -> None:
444444
manager._pg.errored.return_value = injected_failure
445445

446446
self.assertFalse(manager.should_commit())
447-
self.assertEqual(manager._errored, injected_failure)
447+
assert manager._errored is not None
448+
self.assertEqual(manager._errored.original_exception, injected_failure)
448449
# pyre-ignore[16]: _pg is mocked
449450
self.assertEqual(manager._pg.errored.call_count, 1)
450451

@@ -526,7 +527,9 @@ def test_manager_report_error(self, client_mock: MagicMock) -> None:
526527
self.assertIsNone(manager.errored())
527528
e = RuntimeError("some error")
528529
manager.report_error(e)
529-
self.assertIs(manager.errored(), e)
530+
error = manager.errored()
531+
assert error is not None
532+
self.assertIs(error.original_exception, e)
530533

531534
@patch("torchft.manager.ManagerClient", autospec=True)
532535
def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
@@ -540,7 +543,9 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
540543

541544
e = RuntimeError("injected failure")
542545
fut.set_exception(e)
543-
self.assertIs(manager.errored(), e)
546+
error = manager.errored()
547+
assert error is not None
548+
self.assertIs(error.original_exception, e)
544549
self.assertEqual(wrapped_fut.value(), 2)
545550

546551
self.assertEqual(manager._pending_work, [wrapped_fut])
@@ -555,11 +560,11 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
555560
wrapped_fut = manager.wrap_future(fut, 2)
556561
wrapped_fut.wait()
557562
error = manager.errored()
558-
self.assertIsNotNone(error)
563+
assert error is not None
559564
with self.assertRaisesRegex(
560565
TimeoutError, "future did not complete within.*0.01"
561566
):
562-
raise error
567+
raise error.original_exception
563568

564569
@patch("torchft.manager.ManagerClient", autospec=True)
565570
def test_manager_numerics(self, client_mock: MagicMock) -> None:
@@ -678,19 +683,19 @@ def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
678683
self.assertFalse(manager.should_commit())
679684

680685
error = manager.errored()
681-
self.assertIsNotNone(error)
686+
assert error is not None
682687
with self.assertRaisesRegex(RuntimeError, "recv failure"):
683-
raise error
688+
raise error.original_exception
684689

685690
quorum.recover_dst_ranks = [0]
686691
manager.start_quorum()
687692
manager.wait_quorum()
688693
self.assertFalse(manager.should_commit())
689694

690695
error = manager.errored()
691-
self.assertIsNotNone(error)
696+
assert error is not None
692697
with self.assertRaisesRegex(RuntimeError, "send failure"):
693-
raise error
698+
raise error.original_exception
694699

695700
@patch("torchft.manager.ManagerClient", autospec=True)
696701
def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
@@ -718,9 +723,9 @@ def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
718723
self.assertFalse(manager.should_commit())
719724

720725
error = manager.errored()
721-
self.assertIsNotNone(error)
726+
assert error is not None
722727
with self.assertRaisesRegex(RuntimeError, "configure failure"):
723-
raise error
728+
raise error.original_exception
724729

725730
@patch("torchft.manager.ManagerClient", autospec=True)
726731
def test_max_retries(self, client_mock: MagicMock) -> None:

0 commit comments

Comments
 (0)