
Commit 1f93550

avoid stream synchronization in manager
Summary:
- use a recovery event to synchronize on instead of the recovery stream
- fix calling `work.wait()` in manager
- avoid calling `quorum.wait` inside of a callback
1 parent 2546bbe commit 1f93550
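The first summary bullet is the heart of the change: rather than synchronizing the whole recovery stream, the manager records a CUDA event at the end of the recovery work and synchronizes on that event later. A minimal sketch of the pattern, using illustrative names (`recovery_stream`, `recovery_event`) rather than the Manager's actual attributes:

```python
import torch

# Minimal sketch of event-based synchronization; assumes a CUDA device.
if torch.cuda.is_available():
    recovery_stream = torch.cuda.Stream()

    with torch.cuda.stream(recovery_stream):
        # ... recovery work (e.g. state transfer) would be enqueued here ...
        # Record an event marking the point where that work finishes.
        recovery_event = torch.cuda.current_stream().record_event()

    # Later (should_commit in the diff below), block the host only until the
    # recorded event completes. Unlike recovery_stream.synchronize(), this does
    # not also wait for work enqueued to the stream after the event was recorded.
    recovery_event.synchronize()
```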

File tree: 1 file changed, +16 −5 lines

torchft/manager.py

Lines changed: 16 additions & 5 deletions
@@ -219,6 +219,9 @@ def __init__(
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )
 
+        # Used to synchronize recovery operation
+        self._recovery_event: Optional[torch.cuda.Event] = None
+
         if self._group_rank == 0:
             if port is None:
                 port = int(os.environ.get(MANAGER_PORT_ENV, 0))
@@ -323,6 +326,7 @@ def allreduce(
             return fut
 
         self.wait_quorum()
+        num_participants: int = self.num_participants()
 
         if not self.is_participating():
             tensor.zero_()
@@ -337,6 +341,7 @@ def allreduce(
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
+                work.wait()
                 fut = work.get_future()
 
             stream: Optional[torch.cuda.Stream] = (
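On the `work.wait()` line added in this hunk: for NCCL-style backends, `Work.wait()` normally does not block the host; it makes the current CUDA stream wait on the collective, so kernels launched afterwards on that stream are ordered after the allreduce. A hedged sketch using the public `torch.distributed` API rather than torchft's ProcessGroup wrapper:

```python
import torch
import torch.distributed as dist

def allreduce_with_future(tensor: torch.Tensor) -> torch.futures.Future:
    # Illustrative only: uses dist.all_reduce instead of the manager's self._pg.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    # For NCCL-style backends this wait is stream-level, not host-level: it
    # orders the caller's current CUDA stream after the collective.
    work.wait()
    # The future is still returned so an averaging callback can be attached.
    return work.get_future()
```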
@@ -349,13 +354,13 @@ def allreduce(
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
             ) -> torch.Tensor:
-                nonlocal tensor, stream
+                nonlocal tensor, stream, num_participants
 
                 # change the stream to avoid making the callback stream
                 # dependent on process group stream running the allreduce
                 with torch.cuda.stream(stream) if stream is not None else nullcontext():
                     fut.value()
-                    tensor /= self.num_participants()
+                    tensor /= num_participants
 
                 return tensor
 
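The `num_participants` change here pairs with the earlier hunk that computes it right after `self.wait_quorum()`: the divisor is read once on the main thread and captured by the callback, so the callback never calls back into the manager (and thus never into `quorum.wait`) while running on the allreduce stream. A simplified sketch of the capture pattern, with hypothetical names:

```python
import torch

def attach_average_callback(
    fut: torch.futures.Future, tensor: torch.Tensor, num_participants: int
) -> torch.futures.Future:
    # Hypothetical helper: num_participants is computed by the caller (after
    # the quorum is known) and captured by value, so nothing inside the
    # callback has to query the manager or wait on the quorum.
    def callback(fut: torch.futures.Future) -> torch.Tensor:
        nonlocal tensor  # /= rebinds the name, so declare it nonlocal
        fut.value()
        tensor /= num_participants
        return tensor

    return fut.then(callback)
```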
@@ -644,7 +649,12 @@ def _async_quorum(
                 except Exception as e:
                     self._logger.exception(f"got exception in recovery: {e}")
                     self.report_error(e)
-                    return
+
+                self._recovery_event = (
+                    torch.cuda.current_stream().record_event()
+                    if recovery_stream is not None
+                    else None
+                )
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
@@ -704,8 +714,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         with torch.profiler.record_function(
             "torchft::manager::should_commmit::recovery_stream::synchronize"
         ):
-            if self._recovery_stream is not None:
-                self._recovery_stream.synchronize()
+            if self._recovery_event is not None:
+                self._recovery_event.synchronize()
+                self._recovery_event = None
 
         with torch.profiler.record_function(
             "torchft::manager::should_commit::current_stream::synchronize"
