Skip to content

Commit 4cb7ac6

Browse files
tushar00jainfacebook-github-bot
authored andcommitted
handle exception waiting for work (#287)
Summary: work.wait() can throw so wrap that in a try/catch to handle it gracefully by reporting error to the manager, leading the should_commit to fail Reviewed By: d4l3k Differential Revision: D84880993
1 parent e4d99b5 commit 4cb7ac6

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

torchft/device_mesh.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
7070
self.parent = parent
7171
self.flatten_meshes: Dict[str, DeviceMesh] = {}
72+
self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
7273
self._device_type: str
7374
if mesh is not None:
7475
self._device_type = mesh.device_type

torchft/manager.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,14 +1253,19 @@ def _assert_same_stream(self) -> None:
12531253
def wait(self, timeout: Optional[timedelta] = None) -> bool:
12541254
self._assert_same_stream()
12551255

1256-
with get_stream_context(self._stream):
1257-
self._work.wait()
1258-
self._set_future_callback()
1256+
try:
1257+
with get_stream_context(self._stream):
1258+
self._work.wait()
1259+
self._set_future_callback()
12591260

1260-
with get_stream_context(self._stream):
1261-
self._managed_fut_tail.wait()
1261+
with get_stream_context(self._stream):
1262+
self._managed_fut_tail.wait()
12621263

1263-
return True
1264+
return True
1265+
except Exception as e:
1266+
self._manager._logger.exception(f"got exception waiting for work {e}")
1267+
self._manager.report_error(e)
1268+
return False
12641269

12651270
def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
12661271
self._assert_same_stream()

0 commit comments

Comments
 (0)