Commit a5396a8

committed
test allreduce failures for diloco
Summary:
- test when allreduce fails but no new nodes join
- added another event of type `AllreduceFailure`
- This new event required modifying some manager code to inject the failure
1 parent 17c7cb9 commit a5396a8
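
To make the mechanism in the summary concrete, here is a small self-contained sketch of the pattern this commit uses. This is plain Python, not torchft code; the names FakeManager and FakeInjector are illustrative stand-ins for Manager and EventInjector. The injector holds a reference to a manager-like object, arms a one-shot flag at a chosen step, and the next "allreduce" observes the flag and fails exactly once.

    from typing import Optional


    class FakeManager:
        """Stand-in for the real Manager: owns a one-shot failure flag."""

        def __init__(self) -> None:
            self._should_fail_allreduce = False

        def fail_next_allreduce(self) -> None:
            # Arm the one-shot failure, analogous to TEST_fail_allreduce().
            self._should_fail_allreduce = True

        def allreduce(self, value: int) -> int:
            if self._should_fail_allreduce:
                # Reset before raising so only a single allreduce fails.
                self._should_fail_allreduce = False
                raise RuntimeError("injected allreduce failure")
            return value


    class FakeInjector:
        """Stand-in for EventInjector: arms the manager's flag at a given step."""

        def __init__(self, fail_at_step: int) -> None:
            self._fail_at_step = fail_at_step
            self._manager: Optional[FakeManager] = None

        def set_manager(self, manager: FakeManager) -> None:
            self._manager = manager

        def check(self, step: int) -> None:
            if step == self._fail_at_step and self._manager is not None:
                self._manager.fail_next_allreduce()


    manager = FakeManager()
    injector = FakeInjector(fail_at_step=1)
    injector.set_manager(manager)

    for step in range(3):
        injector.check(step)
        try:
            manager.allreduce(step)
            print(f"step {step}: allreduce ok")
        except RuntimeError as err:
            print(f"step {step}: {err}")  # only step 1 fails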

File tree

3 files changed: +112 -0 lines changed

- torchft/local_sgd_integ_test.py
- torchft/manager.py
- torchft/manager_integ_test.py

torchft/local_sgd_integ_test.py

Lines changed: 77 additions & 0 deletions
@@ -210,6 +210,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
         # pyre-fixme[6]: Incompatible parameter type
         **runner.manager_args,
     )
+    runner.event_injector.set_manager(manager)
     stack.callback(manager.shutdown)
     # initialize default group for device mesh to work
     if not torch.distributed.is_initialized():
@@ -669,3 +670,79 @@ def test_streaming_diloco_upscale(
 
         for event_injector in event_injectors:
             self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(CONFIG)
+    def test_streaming_diloco_commit_failure(
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+    ) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+        executors = []
+
+        event_injectors = [
+            EventInjector().fail_allreduce_at(0, 1),
+            EventInjector().fail_allreduce_at(0, 1),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+        for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+            executor = ThreadPoolExecutor(max_workers=1)
+            executors.append(executor)
+            runner = Runner(
+                replica_id=replica_id,
+                num_replicas=num_replicas,
+                lighthouse_address=lighthouse.address(),
+                event_injector=event_injector,
+                train_loop=diloco_train_loop,
+                train_loop_args={
+                    "model_state_dict": m.state_dict(),
+                    "n_fragments": n_fragments,
+                    "diloco_args": {
+                        "fragment_sync_delay": fragment_sync_delay,
+                        "sync_every": 4,
+                    },
+                },
+            )
+            futures.append(executor.submit(runner.run_replica))
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
+            try:
+                state_dicts.append(fut.result()[0])
+            except Exception as e:
+                print(e)
+                raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1 = state_dicts
+
+        assert_equal_global_state(rep0, rep1)
+
+        for step in rep0.keys():
+            print(step, rep0[step]["user"]["local_step"])
+            self.assertEqual(
+                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
+            )
+
+        for event_injector in event_injectors:
+            self.assertEqual(
+                event_injector.count[EventInjectorEvent.AllreduceFailure], 1
+            )

torchft/manager.py

Lines changed: 13 additions & 0 deletions
@@ -268,6 +268,15 @@ def __init__(
         self._participating_replica_rank: Optional[int] = None
         self._participating_replica_world_size: int = 0
 
+        # used to artificially fail the next allreduce by tests
+        self._TEST_should_fail_allreduce = False
+
+    def TEST_fail_allreduce(self) -> None:
+        """
+        Fails the next allreduce. This is used for testing.
+        """
+        self._TEST_should_fail_allreduce = True
+
     def register_state_dict_fn(
         self,
         key: str,
@@ -356,6 +365,10 @@ def callback(
         ) -> torch.Tensor:
             nonlocal tensor, stream, num_participants
 
+            if self._TEST_should_fail_allreduce:
+                self._TEST_should_fail_allreduce = False
+                raise
+
             # change the stream to avoid making the callback stream
             # dependent on process group stream running the allreduce
             with torch.cuda.stream(stream) if stream is not None else nullcontext():
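
One detail worth noting: the bare `raise` in the new branch runs with no exception being handled, so Python raises `RuntimeError: No active exception to re-raise`; that exception is what propagates out of the allreduce callback and makes the injected allreduce fail. A minimal illustration of that behavior:

    # Bare `raise` with no active exception raises RuntimeError.
    try:
        raise
    except RuntimeError as err:
        print(err)  # "No active exception to re-raise"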

torchft/manager_integ_test.py

Lines changed: 22 additions & 0 deletions
@@ -76,10 +76,13 @@ class InjectedFailure(Exception):
 
 
 class EventInjectorEvent(Enum):
+    # Crashes a rank
     Failure = auto()
     # Used to wait for a rank to reach a certain step before continuing.
     # Users need to make sure the size of the barrier is appropriately set.
     Barrier = auto()
+    # Fails the allreduce call made by a rank
+    AllreduceFailure = auto()
 
 
 class EventInjectorInfo:
@@ -90,10 +93,15 @@ def __init__(self, event: EventInjectorEvent, data: object) -> None:
 
 class EventInjector:
     def __init__(self) -> None:
+        self._manager: Optional[Manager] = None
         self._lock = threading.Lock()
         self._events: Dict[Tuple[int, int], EventInjectorInfo] = {}
        self.count: dict[EventInjectorEvent, int] = defaultdict(int)
 
+    def set_manager(self, manager: Manager) -> None:
+        with self._lock:
+            self._manager = manager
+
     def fail_at(self, rank: int, step: int) -> "EventInjector":
         with self._lock:
             assert (rank, step) not in self._events
@@ -102,6 +110,14 @@ def fail_at(self, rank: int, step: int) -> "EventInjector":
             )
         return self
 
+    def fail_allreduce_at(self, rank: int, step: int) -> "EventInjector":
+        with self._lock:
+            assert (rank, step) not in self._events
+            self._events[(rank, step)] = EventInjectorInfo(
+                EventInjectorEvent.AllreduceFailure, None
+            )
+        return self
+
     def barrier_at(
         self, rank: int, step: int, barrier: threading.Barrier
     ) -> "EventInjector":
@@ -124,6 +140,12 @@ def check(self, rank: int, step: int) -> None:
             print(f"injecting failure {rank=} {step=}")
             raise InjectedFailure(f"injected failure {rank=} {step=}")
 
+        if event_info.event == EventInjectorEvent.AllreduceFailure:
+            print(f"injecting allreduce failure {rank=} {step=}")
+            assert self._manager is not None
+            self._manager.TEST_fail_allreduce()
+            return
+
         if event_info.event == EventInjectorEvent.Barrier:
             print(f"waiting for barrier {rank=} {step=}")
             cast(threading.Barrier, event_info.data).wait()
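
Taken together, the flow added by this commit is: the integration test builds each replica with EventInjector().fail_allreduce_at(0, 1); the test harness hands the replica's Manager to the injector via event_injector.set_manager(manager) (the one-line change in local_sgd_integ_test.py above); the per-step check(rank, step) call then finds the AllreduceFailure event and calls manager.TEST_fail_allreduce(), so the next allreduce issued through that manager raises; finally the test asserts that each injector recorded exactly one EventInjectorEvent.AllreduceFailure and that both replicas still end up with equal global state.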
