1616from torchft .local_sgd import DiLoCo , LocalSGD
1717from torchft .manager import Manager
1818from torchft .manager_integ_test import FailureInjector , MyModel , Runner
19- from torchft .process_group import ProcessGroupGloo , ProcessGroupNCCL
19+ from torchft .process_group import ProcessGroupBabyNCCL , ProcessGroupGloo
2020
2121logger : logging .Logger = logging .getLogger (__name__ )
2222
@@ -41,7 +41,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:
4141
4242 print (f"worker { runner .replica_id = } { rank = } { runner .world_size = } starting" )
4343
44- pg = ProcessGroupGloo ()
44+ if device .type == "cuda" :
45+ pg = ProcessGroupBabyNCCL ()
46+ else :
47+ pg = ProcessGroupGloo ()
4548 manager = Manager (
4649 pg = pg ,
4750 min_replica_size = 2 ,
@@ -110,7 +113,12 @@ def diloco_train_loop(
110113 # pyre-ignore[53]
111114 def load_state_dict (state_dict : Dict [str , Dict [str , object ]]) -> None :
112115 m .load_state_dict (state_dict ["model" ])
116+ m .to (device )
113117 diloco .original_parameters = state_dict ["original_params" ]
118+ for name in diloco .original_parameters .keys ():
119+ diloco .original_parameters [name ] = diloco .original_parameters [name ].to (
120+ device
121+ )
114122 inner_optimizer .load_state_dict (state_dict ["inner_optim" ])
115123 outer_optimizer .load_state_dict (state_dict ["outer_optim" ])
116124
@@ -124,7 +132,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
124132
125133 print (f"worker { runner .replica_id = } { rank = } { runner .world_size = } starting" )
126134
127- pg = ProcessGroupGloo ()
135+ if device .type == "cuda" :
136+ pg = ProcessGroupBabyNCCL ()
137+ else :
138+ pg = ProcessGroupGloo ()
128139 manager = Manager (
129140 pg = pg ,
130141 min_replica_size = 2 ,
@@ -138,6 +149,8 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
138149 world_size = runner .world_size ,
139150 lighthouse_addr = runner .lighthouse_address ,
140151 port = 19530 + runner .replica_id ,
152+ connect_timeout = timedelta (seconds = 10 ),
153+ quorum_timeout = timedelta (seconds = 10 ),
141154 timeout = timedelta (seconds = 10 ),
142155 # pyre-fixme[6]: Incompatible parameter type
143156 ** runner .manager_args ,
@@ -155,6 +168,12 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
155168 sync_every = 2 ,
156169 ) as diloco :
157170 while True :
171+ manager_curr_step = manager .current_step ()
172+ if manager_curr_step not in all_state_dicts :
173+ print (
174+ f"{ manager_curr_step = } { diloco ._local_step = } { runner .replica_id = } { state_dict ()= } "
175+ )
176+ all_state_dicts [manager_curr_step ] = copy .deepcopy (state_dict ())
158177 batch_size = 1
159178 inputs = m .get_rand_inputs (batch_size ).to (device )
160179 labels = m .get_rand_labels (batch_size ).to (device )
@@ -164,7 +183,6 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
164183
165184 inner_optimizer .zero_grad ()
166185 loss .backward ()
167- all_state_dicts [str (manager .current_step ())] = state_dict ()
168186 inner_optimizer .step ()
169187
170188 # after 4 model updates then break
@@ -181,10 +199,15 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
181199class LocalSGDIntegTest (TestCase ):
182200 @parameterized .expand (
183201 [
202+ (True ,),
184203 (False ,),
185204 ]
186205 )
187206 def test_local_sgd_recovery (self , use_cuda : bool ) -> None :
207+ # Skip the test if use_cuda is True and there are not enough GPUs
208+ if use_cuda and torch .cuda .device_count () < 2 :
209+ self .skipTest ("Not enough GPUs for CUDA test" )
210+
188211 lighthouse = LighthouseServer (
189212 bind = "[::]:0" ,
190213 min_replicas = 2 ,
@@ -236,10 +259,15 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
236259
237260 @parameterized .expand (
238261 [
262+ (True ,),
239263 (False ,),
240264 ]
241265 )
242266 def test_diloco_healthy (self , use_cuda : bool ) -> None :
267+ # Skip the test if use_cuda is True and there are not enough GPUs
268+ if use_cuda and torch .cuda .device_count () < 2 :
269+ self .skipTest ("Not enough GPUs for CUDA test" )
270+
243271 lighthouse = LighthouseServer (bind = "[::]:0" , min_replicas = 2 )
244272 num_replicas = 2
245273 futures = []
@@ -289,7 +317,17 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
289317 check_device = False ,
290318 )
291319
292- def test_diloco_recovery (self ) -> None :
320+ @parameterized .expand (
321+ [
322+ (True ,),
323+ (False ,),
324+ ]
325+ )
326+ def test_diloco_recovery (self , use_cuda : bool ) -> None :
327+ # Skip the test if use_cuda is True and there are not enough GPUs
328+ if use_cuda and torch .cuda .device_count () < 2 :
329+ self .skipTest ("Not enough GPUs for CUDA test" )
330+
293331 lighthouse = LighthouseServer (
294332 bind = "[::]:0" ,
295333 min_replicas = 2 ,
0 commit comments