@@ -219,6 +219,9 @@ def __init__(
219
219
torch .cuda .Stream () if torch .cuda .is_available () else None
220
220
)
221
221
222
+ # Used to synchronize recovery operation
223
+ self ._recovery_event : Optional [torch .cuda .Event ] = None
224
+
222
225
if self ._group_rank == 0 :
223
226
if port is None :
224
227
port = int (os .environ .get (MANAGER_PORT_ENV , 0 ))
@@ -323,6 +326,7 @@ def allreduce(
323
326
return fut
324
327
325
328
self .wait_quorum ()
329
+ num_participants : int = self .num_participants ()
326
330
327
331
if not self .is_participating ():
328
332
tensor .zero_ ()
@@ -337,6 +341,7 @@ def allreduce(
337
341
)
338
342
else :
339
343
work = self ._pg .allreduce ([tensor ], ReduceOp .SUM )
344
+ work .wait ()
340
345
fut = work .get_future ()
341
346
342
347
stream : Optional [torch .cuda .Stream ] = (
@@ -349,13 +354,13 @@ def allreduce(
349
354
def callback (
350
355
fut : torch .futures .Future [List [torch .Tensor ]],
351
356
) -> torch .Tensor :
352
- nonlocal tensor , stream
357
+ nonlocal tensor , stream , num_participants
353
358
354
359
# change the stream to avoid making the callback stream
355
360
# dependent on process group stream running the allreduce
356
361
with torch .cuda .stream (stream ) if stream is not None else nullcontext ():
357
362
fut .value ()
358
- tensor /= self . num_participants ()
363
+ tensor /= num_participants
359
364
360
365
return tensor
361
366
@@ -644,7 +649,12 @@ def _async_quorum(
644
649
except Exception as e :
645
650
self ._logger .exception (f"got exception in recovery: { e } " )
646
651
self .report_error (e )
647
- return
652
+
653
+ self ._recovery_event = (
654
+ torch .cuda .current_stream ().record_event ()
655
+ if recovery_stream is not None
656
+ else None
657
+ )
648
658
649
659
def _apply_pending_state_dict (self ) -> None :
650
660
assert self ._healing , "must be in healing state"
@@ -704,8 +714,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
704
714
with torch .profiler .record_function (
705
715
"torchft::manager::should_commmit::recovery_stream::synchronize"
706
716
):
707
- if self ._recovery_stream is not None :
708
- self ._recovery_stream .synchronize ()
717
+ if self ._recovery_event is not None :
718
+ self ._recovery_event .synchronize ()
719
+ self ._recovery_event = None
709
720
710
721
with torch .profiler .record_function (
711
722
"torchft::manager::should_commit::current_stream::synchronize"
0 commit comments