
Commit 8ef24c0

tushar00jain authored and meta-codesync[bot] committed
remove device mesh (#289)
Summary:
Pull Request resolved: #289

Remove device mesh since we don't really use it. Device mesh is undergoing a lot of changes, and using private APIs makes the subclass difficult to maintain. We will revisit device mesh integration with public APIs.

Reviewed By: d4l3k

Differential Revision: D86466239

fbshipit-source-id: 386e32ba9e1053fba62bf8d19d05fd0a42ca2853
1 parent e1315dc · commit 8ef24c0
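The commit swaps torchft's private ManagedDeviceMesh wrapper for PyTorch's public DeviceMesh. Below is a minimal, hedged sketch of the public construction the diff adopts; the single-process gloo setup, MASTER_ADDR/MASTER_PORT values, and "cpu" device type are illustrative placeholders, not torchft code:

# Minimal sketch, assuming a single-process gloo group for illustration.
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# A 1-D mesh spanning all ranks -- the same shape the diff below builds with
# DeviceMesh(self.device.type, torch.arange(self.runner.world_size)).
mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))

dist.destroy_process_group()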

File tree

9 files changed: +27 −582 lines


torchft/_test/diloco_trainer.py

Lines changed: 7 additions & 12 deletions
@@ -1,15 +1,13 @@
 import copy
 import logging
 import os
-from contextlib import ExitStack
 from datetime import timedelta
-from typing import Any, cast, Dict, List
+from typing import Any, Dict
 
 import torch
 from torch import nn
-from torch.distributed.tensor import DTensor
+from torch.distributed.tensor import DeviceMesh, DTensor
 
-from torchft.device_mesh import ft_init_device_mesh, ManagedDeviceMesh
 from torchft.local_sgd import DiLoCo
 from torchft.manager import Manager
 from torchft.manager_integ_test import MyModel, Runner
@@ -113,7 +111,7 @@ def __init__(
 
         self.manager: Manager = self.setup_manager()
 
-        self.ft_device_mesh: None | ManagedDeviceMesh = None
+        self.device_mesh: None | DeviceMesh = None
         self.setup_distributed()
 
         self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
@@ -197,12 +195,9 @@ def setup_distributed(self) -> None:
         os.environ["WORLD_SIZE"] = str(self.runner.world_size)
         os.environ["RANK"] = str(self.rank)
 
-        self.ft_device_mesh = ft_init_device_mesh(
-            device_type=self.device.type,
-            mesh_shape=(self.runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=self.manager,
+        self.device_mesh = DeviceMesh(
+            self.device.type,
+            torch.arange(self.runner.world_size),
         )
 
         # Convert model parameters to DTensor
@@ -211,7 +206,7 @@ def setup_distributed(self) -> None:
             for param in layer.parameters():
                 param = DTensor.from_local(
                     param,
-                    device_mesh=self.ft_device_mesh,
+                    device_mesh=self.device_mesh,
                 )
 
     def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
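The last hunk shows the follow-on pattern: each parameter is wrapped as a DTensor over the new 1-D mesh (the default placement replicates across the mesh dimension). A hedged standalone sketch, assuming the `mesh` from the earlier example and an illustrative toy model (nn.Linear stands in for the trainer's MyModel):

# Sketch only: mirrors the DTensor.from_local(param, device_mesh=...) call
# in the hunk above, using an illustrative toy module.
import torch
from torch import nn
from torch.distributed.tensor import DTensor

model = nn.Linear(4, 4)  # stand-in for the trainer's model
for param in model.parameters():
    # With no explicit placements, from_local replicates the local tensor
    # across the 1-D mesh.
    dtensor = DTensor.from_local(param, device_mesh=mesh)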

torchft/device_mesh.py

Lines changed: 0 additions & 340 deletions
This file was deleted.
