diff --git a/torchft/process_group.py b/torchft/process_group.py index e689288..d1d2cbe 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -19,7 +19,6 @@ import logging import queue import threading -from abc import ABC from datetime import timedelta from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union @@ -507,6 +506,10 @@ def __init__(self, manager: "Manager") -> None: self._manager = manager def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + # Ensure we have a valid quorum and are configured before trying to do + # any work. + self._manager.wait_quorum() + if self._manager.errored() is not None: return _DummyWork(tensors) diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 417c32e..d24f838 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -368,6 +368,7 @@ def test_managed_process_group(self) -> None: self.assertEqual(manager.report_error.call_count, 0) self.assertEqual(manager.wrap_future.call_count, 1) + self.assertEqual(manager.wait_quorum.call_count, 1) class DeviceMeshTest(TestCase):