
Commit 9200e74

ManagedDeviceMesh: support TP (#155)
Parent: 2c3383f

3 files changed: +115 -11 lines

torchft/device_mesh_test.py

Lines changed: 30 additions & 0 deletions
@@ -75,10 +75,40 @@ def _test_init_device_mesh(world_size: int, rank: int) -> None:
         torch.load(buffer, weights_only=False)
 
     def test_init_device_mesh(self) -> None:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
         with ProcessPoolExecutor(max_workers=4) as executor:
             futures = []
             for i in range(4):
                 future = executor.submit(self._test_init_device_mesh, 4, i)
                 futures.append(future)
             for f in futures:
                 f.result()
+
+    def test_repr_hash(self) -> None:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346)
+        os.environ["RANK"] = str(0)
+        os.environ["WORLD_SIZE"] = str(1)
+
+        manager = Mock(spec=Manager)
+        manager._pg = ProcessGroupGloo()
+
+        device_mesh = ft_init_device_mesh(
+            device_type="cpu",
+            mesh_shape=(1, 1),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+
+        self.assertIsInstance(repr(device_mesh), str)
+        self.assertIsInstance(str(device_mesh), str)
+        self.assertEqual(hash(device_mesh), hash(device_mesh))
+        self.assertIsInstance(hash(device_mesh), int)
+
+        dist.destroy_process_group()
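
The new test_repr_hash covers the __repr__ and __hash__ methods added to ManagedDeviceMesh in this commit. As a hedged illustration of why hashing matters (hypothetical usage, reusing a device_mesh built exactly as in the test above), the mesh and its slices can now act as dictionary keys, which DeviceMesh bookkeeping relies on:

# Hypothetical usage sketch, assuming `device_mesh` was created with
# ft_init_device_mesh exactly as in test_repr_hash above.
cache = {device_mesh: device_mesh["dp_replicate"]}  # needs __hash__
print(repr(device_mesh))                            # "ManagedDeviceMesh(mesh=...)"
assert hash(device_mesh) == hash(device_mesh)       # stable: the hash is cached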

torchft/fsdp_test.py

Lines changed: 52 additions & 6 deletions
@@ -30,17 +30,34 @@
 )
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)
 
 from torchft.manager import Manager
-from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh
+from torchft.process_group import (
+    ManagedProcessGroup,
+    ProcessGroupGloo,
+    ft_init_device_mesh,
+)
 
 
 class FSDPTest(unittest.TestCase):
     @staticmethod
-    def _test_fsdp(world_size: int, rank: int) -> None:
+    def _test_fsdp(
+        world_size: int,
+        rank: int,
+        dp_replicate: int = 2,
+        dp_shard: int = 2,
+        tp: int = 1,
+    ) -> None:
         torch.cuda.set_device(rank)
 
-        group_size = world_size // 2
+        group_size = world_size // dp_replicate
         group = rank // group_size
         group_rank = rank % group_size
 
@@ -50,17 +67,28 @@ def _test_fsdp(world_size: int, rank: int) -> None:
         os.environ["WORLD_SIZE"] = str(group_size)
 
         manager = Mock(spec=Manager)
+        manager._pg = ProcessGroupGloo()
         device_mesh = ft_init_device_mesh(
             device_type="cuda",
-            mesh_shape=(2, 2),
-            mesh_dim_names=("dp_replicate", "dp_shard"),
+            mesh_shape=(dp_replicate, dp_shard, tp),
+            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
             replicate_dim=0,
             manager=manager,
         )
         manager.num_participants.return_value = 1
         model = nn.Linear(128, 128).cuda()
         batch = torch.randn(4, 128).cuda()
-        shard_model = fully_shard(model, mesh=device_mesh)
+
+        fsdp_mesh = device_mesh["dp_replicate", "dp_shard"]
+
+        if tp > 1:
+            tp_mesh = device_mesh["tp"]
+            model = parallelize_module(
+                model,
+                tp_mesh,
+                ColwiseParallel(),
+            )
+        shard_model = fully_shard(model, mesh=fsdp_mesh)
         shard_model(batch).mean().backward()
 
         # pyre-ignore[56]: Pyre was not able to infer the type of argument
@@ -72,3 +100,21 @@ def test_fsdp(self) -> None:
             for i in range(4):
                 future = executor.submit(self._test_fsdp, 4, i)
                 futures.append(future)
+
+            for fut in futures:
+                fut.result()
+
+    # pyre-ignore[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
+    def test_fsdp_tp(self) -> None:
+        context = multiprocessing.get_context("spawn")
+        with ProcessPoolExecutor(max_workers=4, mp_context=context) as executor:
+            futures = []
+            for i in range(4):
+                future = executor.submit(
+                    self._test_fsdp, 4, i, dp_replicate=1, dp_shard=2, tp=2
+                )
+                futures.append(future)
+
+            for fut in futures:
+                fut.result()
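
The new test_fsdp_tp path exercises the standard TP-then-FSDP composition on a 3D mesh. Below is a minimal standalone sketch of that pattern under stated assumptions: a 4-GPU host launched with torchrun, the stock init_device_mesh used in place of torchft's managed mesh (ft_init_device_mesh would be substituted for fault tolerance), and shard_model_3d as an illustrative name that is not part of the test.

# Minimal sketch of the TP + FSDP composition exercised by test_fsdp_tp.
# Assumes `torchrun --nproc_per_node=4` on a 4-GPU host and uses the plain
# init_device_mesh rather than torchft's managed mesh.
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def shard_model_3d() -> nn.Module:
    # Each torchrun worker drives one GPU.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 3D layout: replica groups x FSDP shards x tensor-parallel ranks.
    mesh = init_device_mesh(
        "cuda",
        (1, 2, 2),
        mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
    )

    model = nn.Linear(128, 128).cuda()

    # 1) Apply tensor parallelism over the "tp" sub-mesh.
    model = parallelize_module(model, mesh["tp"], ColwiseParallel())

    # 2) Then shard with FSDP over the data-parallel dims only.
    return fully_shard(model, mesh=mesh["dp_replicate", "dp_shard"])

The ordering mirrors _test_fsdp above: parallelize_module runs before fully_shard, so FSDP shards the already TP-partitioned parameters.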

torchft/process_group.py

Lines changed: 33 additions & 5 deletions
@@ -64,6 +64,7 @@
     ReduceScatterOptions,
     Work,
 )
+from torch.distributed.tensor.device_mesh import _mesh_resources
 from torch.futures import Future
 from torch.utils._pytree import tree_any
 
@@ -1790,6 +1791,7 @@ def __init__(
         self.device_type = parent.device_type
         self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
         self._thread_id: Optional[int] = None
+        self._hash: Optional[int] = None
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -1804,36 +1806,43 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
     def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
         if isinstance(mesh_dim_names, str):
             if mesh_dim_names == self.replicate_dim_name:
-                return ManagedDeviceMesh(
+                res_submesh = ManagedDeviceMesh(
                     mesh=None,
                     mesh_dim_names=(mesh_dim_names,),
                     replicate_pg=self.replicate_pg,
                     replicate_dim=0,
                     parent=self,
                 )
             elif mesh_dim_names in self.flatten_meshes:
-                return self.flatten_meshes[mesh_dim_names]
+                res_submesh = self.flatten_meshes[mesh_dim_names]
             else:
                 assert self.mesh is not None
-                return self.mesh[mesh_dim_names]
+                res_submesh = self.mesh[mesh_dim_names]
         else:
             assert isinstance(mesh_dim_names, tuple)
             if self.replicate_dim_name not in mesh_dim_names:
                 assert self.mesh is not None
-                return self.mesh[mesh_dim_names]
+                res_submesh = self.mesh[mesh_dim_names]
             else:
                 mesh_dim_names_wo_replicate = tuple(
                     n for n in mesh_dim_names if n != self.replicate_dim_name
                 )
                 assert self.mesh is not None
-                return ManagedDeviceMesh(
+                res_submesh = ManagedDeviceMesh(
                     self.mesh[mesh_dim_names_wo_replicate],
                     mesh_dim_names,
                     self.replicate_pg,
                     mesh_dim_names.index(self.replicate_dim_name),
                     parent=self,
                 )
 
+        # TODO: find a better way to do this that doesn't depend on device mesh
+        # internals
+        root = _mesh_resources.get_root_mesh(self)
+        _mesh_resources.child_to_root_mapping[res_submesh] = root
+
+        return res_submesh
+
     def _real_mesh_dim(self, mesh_dim: int) -> int:
         return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
 
@@ -1937,6 +1946,25 @@ def get_coordinate(self) -> Optional[List[int]]:
     def get_all_groups(self) -> List[BaseProcessGroup]:
         raise NotImplementedError
 
+    def __repr__(self) -> str:
+        return f"ManagedDeviceMesh(mesh={self.mesh})"
+
+    def __hash__(self) -> int:
+        # lazily compute hash
+        if not self._hash:
+            self._hash = hash(
+                (
+                    self.mesh,
+                    self.mesh_dim_names,
+                    self.replicate_pg,
+                    self.replicate_dim,
+                    self.replicate_dim_name,
+                    self.parent,
+                    self.device_type,
+                )
+            )
+        return self._hash
+
 
 class _FlattenDeviceMesh(DeviceMesh):
     def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
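
Two patterns in the process_group.py changes are worth calling out: every sub-mesh returned by __getitem__ is now registered in DeviceMesh's private child-to-root mapping (flagged with a TODO because it leans on _mesh_resources internals), and __hash__ caches the tuple hash after the first call. Below is a generic sketch of that lazy-hash idiom; the class and field names are illustrative only, not torchft API.

# Generic lazy-hash caching sketch; LazilyHashed is an illustrative name,
# not part of torchft or PyTorch.
from typing import Optional, Tuple


class LazilyHashed:
    def __init__(self, fields: Tuple[object, ...]) -> None:
        self._fields = fields
        self._hash: Optional[int] = None  # filled in on first hash() call

    def __hash__(self) -> int:
        # Compare against None so a value that legitimately hashes to 0
        # is still cached rather than recomputed on every call.
        if self._hash is None:
            self._hash = hash(self._fields)
        return self._hash

Caching matters because a ManagedDeviceMesh gets hashed repeatedly once it appears as a key in DeviceMesh's internal mappings.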
