Skip to content

Commit 17c7cb9

Browse files
authored
simplify injecting events during training (#224)
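Summary: Create a single `EventInjector` that can inject events of various types (failures and barriers), replacing the separate `FailureInjector` and `BarrierInjector`. This simplifies the code and removes some duplication. Test Plan: `pytest -vs ./torchft/local_sgd_integ_test.py`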
1 parent 2a32b98 commit 17c7cb9

File tree

2 files changed

+116
-106
lines changed

2 files changed

+116
-106
lines changed

torchft/local_sgd_integ_test.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
from torchft.device_mesh import ft_init_device_mesh
2323
from torchft.local_sgd import DiLoCo, LocalSGD
2424
from torchft.manager import Manager
25-
from torchft.manager_integ_test import BarrierInjector, FailureInjector, MyModel, Runner
25+
from torchft.manager_integ_test import (
26+
EventInjector,
27+
EventInjectorEvent,
28+
MyModel,
29+
Runner,
30+
)
2631
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
2732

2833
logger: logging.Logger = logging.getLogger(__name__)
@@ -119,7 +124,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
119124
if manager.current_step() >= 4:
120125
break
121126

122-
runner.failure_injector.check(rank, manager.current_step())
127+
runner.event_injector.check(rank, manager.current_step())
123128

124129
# return state_dict so we can check consistency
125130
return state_dict()
@@ -252,15 +257,14 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
252257
**diloco_args,
253258
) as diloco:
254259
while True:
260+
runner.event_injector.check(rank, manager.current_step())
261+
255262
manager_curr_step = manager.current_step()
256263
if manager_curr_step not in all_state_dicts:
257264
all_state_dicts[manager_curr_step] = copy.deepcopy(
258265
manager._manager_state_dict()
259266
)
260267

261-
if runner.barrier_injector is not None:
262-
runner.barrier_injector.check(manager_curr_step)
263-
264268
batch_size = 1
265269
inputs = m.get_rand_inputs(batch_size, device=device)
266270
labels = m.get_rand_labels(batch_size, device=device)
@@ -276,8 +280,6 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
276280
if manager.current_step() >= 4:
277281
break
278282

279-
runner.failure_injector.check(rank, manager.current_step())
280-
281283
# return state_dict so we can check consistency
282284
return all_state_dicts
283285
return {}
@@ -324,20 +326,18 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
324326
num_replicas = 2
325327
futures = []
326328

327-
failure_injectors = [
328-
FailureInjector(),
329-
FailureInjector().fail_at(0, 2),
329+
event_injectors = [
330+
EventInjector(),
331+
EventInjector().fail_at(0, 2),
330332
]
331333

332334
with ThreadPoolExecutor(max_workers=num_replicas) as executor:
333-
for replica_id, failure_injector in zip(
334-
range(num_replicas), failure_injectors
335-
):
335+
for replica_id, event_injector in zip(range(num_replicas), event_injectors):
336336
runner = Runner(
337337
replica_id=replica_id,
338338
num_replicas=num_replicas,
339339
lighthouse_address=lighthouse.address(),
340-
failure_injector=failure_injector,
340+
event_injector=event_injector,
341341
train_loop=local_sgd_train_loop,
342342
use_cuda=use_cuda,
343343
manager_args={
@@ -364,7 +364,7 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
364364
state_dict[0]["model"], state_dicts[0][0]["model"], check_device=False
365365
)
366366

367-
self.assertEqual(failure_injectors[1].count, 1)
367+
self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
368368

369369
@parameterized.expand(
370370
[
@@ -387,12 +387,12 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
387387

388388
with ThreadPoolExecutor(max_workers=num_replicas) as executor:
389389
for replica_id in range(num_replicas):
390-
failure_injector = FailureInjector()
390+
event_injector = EventInjector()
391391
runner = Runner(
392392
replica_id=replica_id,
393393
num_replicas=num_replicas,
394394
lighthouse_address=lighthouse.address(),
395-
failure_injector=failure_injector,
395+
event_injector=event_injector,
396396
train_loop=diloco_train_loop,
397397
use_cuda=use_cuda,
398398
train_loop_args={
@@ -446,24 +446,22 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
446446
num_replicas = 2
447447
futures = []
448448

449-
failure_injectors = [
450-
FailureInjector(),
451-
FailureInjector().fail_at(0, 2),
449+
event_injectors = [
450+
EventInjector(),
451+
EventInjector().fail_at(0, 2),
452452
]
453453

454454
torch.manual_seed(42)
455455
# Initialize the model so we can pass in the state_dict
456456
m: nn.Module = MultiMyModel(2, 3, 1)
457457

458458
with ThreadPoolExecutor(max_workers=num_replicas) as executor:
459-
for replica_id, failure_injector in zip(
460-
range(num_replicas), failure_injectors
461-
):
459+
for replica_id, event_injector in zip(range(num_replicas), event_injectors):
462460
runner = Runner(
463461
replica_id=replica_id,
464462
num_replicas=num_replicas,
465463
lighthouse_address=lighthouse.address(),
466-
failure_injector=failure_injector,
464+
event_injector=event_injector,
467465
train_loop=diloco_train_loop,
468466
train_loop_args={
469467
"model_state_dict": m.state_dict(),
@@ -504,7 +502,7 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
504502
# Outer optimizer and global model should be the same
505503
assert_equal_global_state(rep1, rep0)
506504

507-
self.assertEqual(failure_injectors[1].count, 1)
505+
self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
508506

509507
# pyre-fixme[56]: Pyre was not able to infer the type of argument
510508
@skipIf(sys.platform == "darwin", "not reliable on mac")
@@ -526,24 +524,22 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
526524
num_replicas = 2
527525
futures = []
528526

529-
failure_injectors = [
530-
FailureInjector(),
531-
FailureInjector().fail_at(0, 2),
527+
event_injectors = [
528+
EventInjector(),
529+
EventInjector().fail_at(0, 2),
532530
]
533531

534532
torch.manual_seed(42)
535533
# Initialize the model so we can pass in the state_dict
536534
m: nn.Module = MultiMyModel(2, 3, 2)
537535

538536
with ThreadPoolExecutor(max_workers=num_replicas) as executor:
539-
for replica_id, failure_injector in zip(
540-
range(num_replicas), failure_injectors
541-
):
537+
for replica_id, event_injector in zip(range(num_replicas), event_injectors):
542538
runner = Runner(
543539
replica_id=replica_id,
544540
num_replicas=num_replicas,
545541
lighthouse_address=lighthouse.address(),
546-
failure_injector=failure_injector,
542+
event_injector=event_injector,
547543
train_loop=diloco_train_loop,
548544
train_loop_args={
549545
"model_state_dict": m.state_dict(),
@@ -584,11 +580,11 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
584580
rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
585581
)
586582

587-
self.assertEqual(failure_injectors[1].count, 1)
583+
self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
588584

589585
CONFIG: list[tuple[bool, int, int]] = [
590586
(use_cuda, n_fragments, fragment_sync_delay)
591-
for use_cuda in [True, False]
587+
for use_cuda in [False]
592588
for n_fragments in [1, 2]
593589
for fragment_sync_delay in [0, 1]
594590
]
@@ -613,26 +609,25 @@ def test_streaming_diloco_upscale(
613609

614610
barrier = threading.Barrier(num_replicas)
615611

616-
barrier_injectors = [
612+
event_injectors = [
617613
# Make this replica join after other replicas have made 2 steps
618-
BarrierInjector().barrier_at(0, barrier),
619-
BarrierInjector().barrier_at(2, barrier),
620-
BarrierInjector().barrier_at(2, barrier),
614+
EventInjector().barrier_at(0, 0, barrier),
615+
EventInjector().barrier_at(0, 2, barrier),
616+
EventInjector().barrier_at(0, 2, barrier),
621617
]
622618

623619
torch.manual_seed(42)
624620
# Initialize the model so we can pass in the state_dict
625621
m: nn.Module = MultiMyModel(2, 3, n_fragments)
626622

627-
for replica_id, barrier_injector in zip(range(num_replicas), barrier_injectors):
623+
for replica_id, event_injector in zip(range(num_replicas), event_injectors):
628624
executor = ThreadPoolExecutor(max_workers=1)
629625
executors.append(executor)
630626
runner = Runner(
631627
replica_id=replica_id,
632628
num_replicas=num_replicas,
633629
lighthouse_address=lighthouse.address(),
634-
failure_injector=FailureInjector(),
635-
barrier_injector=barrier_injector,
630+
event_injector=event_injector,
636631
train_loop=diloco_train_loop,
637632
train_loop_args={
638633
"model_state_dict": m.state_dict(),
@@ -672,5 +667,5 @@ def test_streaming_diloco_upscale(
672667
rep1[step]["user"]["local_step"], rep2[step]["user"]["local_step"]
673668
)
674669

675-
for barrier_injector in barrier_injectors:
676-
self.assertEqual(barrier_injector.count, 1)
670+
for event_injector in event_injectors:
671+
self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)

0 commit comments

Comments
 (0)