@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import re
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -11,8 +12,10 @@
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.tensor import DTensor, Replicate
 
 from torchft._torchft import LighthouseServer
+from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
@@ -64,6 +67,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         stack.callback(lambda: manager.shutdown(wait=False))
 
         m: nn.Module = MyModel().to(device)
+
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
         criterion = nn.CrossEntropyLoss()
 
@@ -156,6 +160,28 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             **runner.manager_args,
         )
         stack.callback(manager.shutdown)
+        # initialize the default process group so the device mesh can be built
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                init_method="tcp://localhost:0",
+                rank=rank,
+                world_size=runner.world_size,
+            )
+
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        # smoke check: each Linear parameter can be wrapped as a
+        # replicated DTensor on the fault tolerant mesh
+        for layer in m.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    DTensor.from_local(param, device_mesh=ft_device_mesh)
 
         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
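
For intuition, here is a minimal self-contained sketch of what the new block does, using stock PyTorch (init_device_mesh) in place of torchft's ft_init_device_mesh and Manager; the single-process gloo group, the port, and the (1, 1) mesh shape are assumptions for illustration only, not part of this change:

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# one-rank default group so the mesh and DTensors can be built in one process
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# mirrors mesh_shape=(world_size, 1), mesh_dim_names=("replicate", "none") above
mesh = init_device_mesh("cpu", (1, 1), mesh_dim_names=("replicate", "none"))

layer = nn.Linear(4, 4)
for param in layer.parameters():
    # with no explicit placements, from_local replicates on every mesh dim
    dt = DTensor.from_local(param, device_mesh=mesh)
    assert all(p.is_replicate() for p in dt.placements)

dist.destroy_process_group()
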
@@ -170,13 +196,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    print(
-                        f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
-                    )
                     all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
                 batch_size = 1
-                inputs = m.get_rand_inputs(batch_size).to(device)
-                labels = m.get_rand_labels(batch_size).to(device)
+                inputs = m.get_rand_inputs(batch_size, device=device)
+                labels = m.get_rand_labels(batch_size, device=device)
 
                 out = m(inputs)
                 loss = criterion(out, labels)
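
The last hunk switches the random-batch helpers to take a device argument, so tensors are created directly on the target device rather than allocated on CPU and copied over. MyModel lives in torchft.manager_integ_test and is not shown here; a hypothetical stand-in with such helpers (the class name and dimensions are made up for illustration) could look like:

import torch
from torch import nn

class TinyModel(nn.Module):
    # hypothetical stand-in for MyModel; dims are assumptions
    def __init__(self, in_dim: int = 3, num_classes: int = 4) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.num_classes = num_classes
        self.layers = nn.ModuleList([nn.Linear(in_dim, num_classes)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

    def get_rand_inputs(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        # allocate directly on the target device; no CPU round trip
        return torch.rand(batch_size, self.in_dim, device=device)

    def get_rand_labels(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        return torch.randint(self.num_classes, (batch_size,), device=device)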