22
22
from torchft .device_mesh import ft_init_device_mesh
23
23
from torchft .local_sgd import DiLoCo , LocalSGD
24
24
from torchft .manager import Manager
25
- from torchft .manager_integ_test import BarrierInjector , FailureInjector , MyModel , Runner
25
+ from torchft .manager_integ_test import (
26
+ EventInjector ,
27
+ EventInjectorEvent ,
28
+ MyModel ,
29
+ Runner ,
30
+ )
26
31
from torchft .process_group import ProcessGroupBabyNCCL , ProcessGroupGloo
27
32
28
33
logger : logging .Logger = logging .getLogger (__name__ )
@@ -119,7 +124,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
119
124
if manager .current_step () >= 4 :
120
125
break
121
126
122
- runner .failure_injector .check (rank , manager .current_step ())
127
+ runner .event_injector .check (rank , manager .current_step ())
123
128
124
129
# return state_dict so we can check consistency
125
130
return state_dict ()
@@ -252,15 +257,14 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
252
257
** diloco_args ,
253
258
) as diloco :
254
259
while True :
260
+ runner .event_injector .check (rank , manager .current_step ())
261
+
255
262
manager_curr_step = manager .current_step ()
256
263
if manager_curr_step not in all_state_dicts :
257
264
all_state_dicts [manager_curr_step ] = copy .deepcopy (
258
265
manager ._manager_state_dict ()
259
266
)
260
267
261
- if runner .barrier_injector is not None :
262
- runner .barrier_injector .check (manager_curr_step )
263
-
264
268
batch_size = 1
265
269
inputs = m .get_rand_inputs (batch_size , device = device )
266
270
labels = m .get_rand_labels (batch_size , device = device )
@@ -276,8 +280,6 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
276
280
if manager .current_step () >= 4 :
277
281
break
278
282
279
- runner .failure_injector .check (rank , manager .current_step ())
280
-
281
283
# return state_dict so we can check consistency
282
284
return all_state_dicts
283
285
return {}
@@ -324,20 +326,18 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
324
326
num_replicas = 2
325
327
futures = []
326
328
327
- failure_injectors = [
328
- FailureInjector (),
329
- FailureInjector ().fail_at (0 , 2 ),
329
+ event_injectors = [
330
+ EventInjector (),
331
+ EventInjector ().fail_at (0 , 2 ),
330
332
]
331
333
332
334
with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
333
- for replica_id , failure_injector in zip (
334
- range (num_replicas ), failure_injectors
335
- ):
335
+ for replica_id , event_injector in zip (range (num_replicas ), event_injectors ):
336
336
runner = Runner (
337
337
replica_id = replica_id ,
338
338
num_replicas = num_replicas ,
339
339
lighthouse_address = lighthouse .address (),
340
- failure_injector = failure_injector ,
340
+ event_injector = event_injector ,
341
341
train_loop = local_sgd_train_loop ,
342
342
use_cuda = use_cuda ,
343
343
manager_args = {
@@ -364,7 +364,7 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
364
364
state_dict [0 ]["model" ], state_dicts [0 ][0 ]["model" ], check_device = False
365
365
)
366
366
367
- self .assertEqual (failure_injectors [1 ].count , 1 )
367
+ self .assertEqual (event_injectors [1 ].count [ EventInjectorEvent . Failure ] , 1 )
368
368
369
369
@parameterized .expand (
370
370
[
@@ -387,12 +387,12 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
387
387
388
388
with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
389
389
for replica_id in range (num_replicas ):
390
- failure_injector = FailureInjector ()
390
+ event_injector = EventInjector ()
391
391
runner = Runner (
392
392
replica_id = replica_id ,
393
393
num_replicas = num_replicas ,
394
394
lighthouse_address = lighthouse .address (),
395
- failure_injector = failure_injector ,
395
+ event_injector = event_injector ,
396
396
train_loop = diloco_train_loop ,
397
397
use_cuda = use_cuda ,
398
398
train_loop_args = {
@@ -446,24 +446,22 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
446
446
num_replicas = 2
447
447
futures = []
448
448
449
- failure_injectors = [
450
- FailureInjector (),
451
- FailureInjector ().fail_at (0 , 2 ),
449
+ event_injectors = [
450
+ EventInjector (),
451
+ EventInjector ().fail_at (0 , 2 ),
452
452
]
453
453
454
454
torch .manual_seed (42 )
455
455
# Initialize the model so we can pass in the state_dict
456
456
m : nn .Module = MultiMyModel (2 , 3 , 1 )
457
457
458
458
with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
459
- for replica_id , failure_injector in zip (
460
- range (num_replicas ), failure_injectors
461
- ):
459
+ for replica_id , event_injector in zip (range (num_replicas ), event_injectors ):
462
460
runner = Runner (
463
461
replica_id = replica_id ,
464
462
num_replicas = num_replicas ,
465
463
lighthouse_address = lighthouse .address (),
466
- failure_injector = failure_injector ,
464
+ event_injector = event_injector ,
467
465
train_loop = diloco_train_loop ,
468
466
train_loop_args = {
469
467
"model_state_dict" : m .state_dict (),
@@ -504,7 +502,7 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
504
502
# Outer optimizer and global model should be the same
505
503
assert_equal_global_state (rep1 , rep0 )
506
504
507
- self .assertEqual (failure_injectors [1 ].count , 1 )
505
+ self .assertEqual (event_injectors [1 ].count [ EventInjectorEvent . Failure ] , 1 )
508
506
509
507
# pyre-fixme[56]: Pyre was not able to infer the type of argument
510
508
@skipIf (sys .platform == "darwin" , "not reliable on mac" )
@@ -526,24 +524,22 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
526
524
num_replicas = 2
527
525
futures = []
528
526
529
- failure_injectors = [
530
- FailureInjector (),
531
- FailureInjector ().fail_at (0 , 2 ),
527
+ event_injectors = [
528
+ EventInjector (),
529
+ EventInjector ().fail_at (0 , 2 ),
532
530
]
533
531
534
532
torch .manual_seed (42 )
535
533
# Initialize the model so we can pass in the state_dict
536
534
m : nn .Module = MultiMyModel (2 , 3 , 2 )
537
535
538
536
with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
539
- for replica_id , failure_injector in zip (
540
- range (num_replicas ), failure_injectors
541
- ):
537
+ for replica_id , event_injector in zip (range (num_replicas ), event_injectors ):
542
538
runner = Runner (
543
539
replica_id = replica_id ,
544
540
num_replicas = num_replicas ,
545
541
lighthouse_address = lighthouse .address (),
546
- failure_injector = failure_injector ,
542
+ event_injector = event_injector ,
547
543
train_loop = diloco_train_loop ,
548
544
train_loop_args = {
549
545
"model_state_dict" : m .state_dict (),
@@ -584,11 +580,11 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
584
580
rep0 [step ]["user" ]["local_step" ], rep1 [step ]["user" ]["local_step" ]
585
581
)
586
582
587
- self .assertEqual (failure_injectors [1 ].count , 1 )
583
+ self .assertEqual (event_injectors [1 ].count [ EventInjectorEvent . Failure ] , 1 )
588
584
589
585
CONFIG : list [tuple [bool , int , int ]] = [
590
586
(use_cuda , n_fragments , fragment_sync_delay )
591
- for use_cuda in [True , False ]
587
+ for use_cuda in [False ]
592
588
for n_fragments in [1 , 2 ]
593
589
for fragment_sync_delay in [0 , 1 ]
594
590
]
@@ -613,26 +609,25 @@ def test_streaming_diloco_upscale(
613
609
614
610
barrier = threading .Barrier (num_replicas )
615
611
616
- barrier_injectors = [
612
+ event_injectors = [
617
613
# Make this replica join after other replicas have made 2 steps
618
- BarrierInjector ().barrier_at (0 , barrier ),
619
- BarrierInjector ().barrier_at (2 , barrier ),
620
- BarrierInjector ().barrier_at (2 , barrier ),
614
+ EventInjector ().barrier_at (0 , 0 , barrier ),
615
+ EventInjector ().barrier_at (0 , 2 , barrier ),
616
+ EventInjector ().barrier_at (0 , 2 , barrier ),
621
617
]
622
618
623
619
torch .manual_seed (42 )
624
620
# Initialize the model so we can pass in the state_dict
625
621
m : nn .Module = MultiMyModel (2 , 3 , n_fragments )
626
622
627
- for replica_id , barrier_injector in zip (range (num_replicas ), barrier_injectors ):
623
+ for replica_id , event_injector in zip (range (num_replicas ), event_injectors ):
628
624
executor = ThreadPoolExecutor (max_workers = 1 )
629
625
executors .append (executor )
630
626
runner = Runner (
631
627
replica_id = replica_id ,
632
628
num_replicas = num_replicas ,
633
629
lighthouse_address = lighthouse .address (),
634
- failure_injector = FailureInjector (),
635
- barrier_injector = barrier_injector ,
630
+ event_injector = event_injector ,
636
631
train_loop = diloco_train_loop ,
637
632
train_loop_args = {
638
633
"model_state_dict" : m .state_dict (),
@@ -672,5 +667,5 @@ def test_streaming_diloco_upscale(
672
667
rep1 [step ]["user" ]["local_step" ], rep2 [step ]["user" ]["local_step" ]
673
668
)
674
669
675
- for barrier_injector in barrier_injectors :
676
- self .assertEqual (barrier_injector .count , 1 )
670
+ for event_injector in event_injectors :
671
+ self .assertEqual (event_injectors [ 1 ] .count [ EventInjectorEvent . Barrier ] , 1 )
0 commit comments