@@ -1,15 +1,13 @@
 import copy
 import logging
 import os
-from contextlib import ExitStack
 from datetime import timedelta
-from typing import Any, cast, Dict, List
+from typing import Any, Dict
 
 import torch
 from torch import nn
-from torch.distributed.tensor import DTensor
+from torch.distributed.tensor import DeviceMesh, DTensor
 
-from torchft.device_mesh import ft_init_device_mesh, ManagedDeviceMesh
 from torchft.local_sgd import DiLoCo
 from torchft.manager import Manager
 from torchft.manager_integ_test import MyModel, Runner
@@ -113,7 +111,7 @@ def __init__(
 
         self.manager: Manager = self.setup_manager()
 
-        self.ft_device_mesh: None | ManagedDeviceMesh = None
+        self.device_mesh: None | DeviceMesh = None
         self.setup_distributed()
 
         self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
@@ -197,12 +195,9 @@ def setup_distributed(self) -> None:
         os.environ["WORLD_SIZE"] = str(self.runner.world_size)
         os.environ["RANK"] = str(self.rank)
 
-        self.ft_device_mesh = ft_init_device_mesh(
-            device_type=self.device.type,
-            mesh_shape=(self.runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=self.manager,
+        self.device_mesh = DeviceMesh(
+            self.device.type,
+            torch.arange(self.runner.world_size),
         )
 
         # Convert model parameters to DTensor
@@ -211,7 +206,7 @@ def setup_distributed(self) -> None:
             for param in layer.parameters():
                 param = DTensor.from_local(
                     param,
-                    device_mesh=self.ft_device_mesh,
+                    device_mesh=self.device_mesh,
                 )
 
     def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
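
Not part of the change itself: a minimal standalone sketch of the pattern the diff switches to, i.e. building a plain DeviceMesh over the replica ranks and wrapping parameters with DTensor.from_local. The single-process gloo setup, the nn.Linear model, and the environment-variable defaults below are illustrative assumptions, not taken from the original test harness.

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor import DeviceMesh, DTensor

# Illustrative single-process setup so init_process_group can run locally.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")
dist.init_process_group("gloo")

world_size = dist.get_world_size()
# One-dimensional mesh listing every replica rank, mirroring
# DeviceMesh(self.device.type, torch.arange(self.runner.world_size)) in the diff.
device_mesh = DeviceMesh("cpu", torch.arange(world_size))

model = nn.Linear(4, 2)  # stand-in for the test's model layers
for param in model.parameters():
    # Wrap the local tensor as a DTensor on the replica mesh; with no
    # placements given, from_local treats it as replicated across the mesh.
    dtensor_param = DTensor.from_local(param, device_mesh=device_mesh)
    print(type(dtensor_param), dtensor_param.shape)

dist.destroy_process_group()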