@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import re
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -11,8 +12,10 @@
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.tensor import DTensor, Replicate
 
 from torchft._torchft import LighthouseServer
+from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
@@ -64,6 +67,29 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         stack.callback(lambda: manager.shutdown(wait=False))
 
         m: nn.Module = MyModel().to(device)
+
+        # # Apply FSDP
+        # mesh = init_device_mesh("cuda", (runner.world_size,), mesh_dim_names=("dp",))
+        # for module in m.modules():
+        #     if isinstance(module, nn.Linear):
+        #         fully_shard(module, mesh=mesh)
+        # fully_shard(m, mesh=mesh)
+
+        # LOCALSGD
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+        # import fbvscode  # debug-only import; not needed to run this test
+
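+        # Build a torchft-managed device mesh for this replica; the manager backs the
+        # fault-tolerant replicate dimension.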
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(1,),
+            mesh_dim_names=("none",),
+            replicate_dim=runner.world_size,
+            manager=manager,
+        )
+        print(f"{ft_device_mesh=}")
+
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
         criterion = nn.CrossEntropyLoss()
 
@@ -156,6 +182,27 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             **runner.manager_args,
         )
         stack.callback(manager.shutdown)
+        # initialize default group for device mesh
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                init_method="tcp://localhost:0", rank=rank, world_size=runner.world_size
+            )
+
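+        # Fault-tolerant 2D mesh: dim 0 ("replicate", size world_size) spans the replicas,
+        # dim 1 ("none") is a size-1 placeholder.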
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=manager,
+        )
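+        # Wrap the Linear layers' parameters as DTensors placed on the fault-tolerant mesh.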
+        for layer in m.layers:
+            if isinstance(layer, nn.Linear):
+                for name, param in list(layer.named_parameters(recurse=False)):
+                    # register_parameter so the module actually holds the DTensor version
+                    layer.register_parameter(
+                        name, nn.Parameter(DTensor.from_local(param, device_mesh=ft_device_mesh))
+                    )
 
         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
@@ -170,15 +217,16 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    print(
-                        f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
-                    )
+                    # print(
+                    #     f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
+                    # )
                     all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
                 batch_size = 1
-                inputs = m.get_rand_inputs(batch_size).to(device)
-                labels = m.get_rand_labels(batch_size).to(device)
+                inputs = m.get_rand_inputs(batch_size, device=device)
+                labels = m.get_rand_labels(batch_size, device=device)
 
                 out = m(inputs)
+                # print(f"{device=} {inputs=} {out=} {labels=}]")
                 loss = criterion(out, labels)
 
                 inner_optimizer.zero_grad()
@@ -261,7 +309,7 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
 
     @parameterized.expand(
         [
-            # (True,),
+            (True,),
             (False,),
         ]
    )