
Commit 6e6c0cc

More fixes for hvd
1 parent c769298 commit 6e6c0cc

File tree

3 files changed (+51 -28 lines)


ignite/distributed/comp_models/horovod.py

Lines changed: 25 additions & 6 deletions
@@ -1,3 +1,4 @@
+import os
 import warnings
 from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
 
@@ -23,6 +24,9 @@
 if has_hvd_support:
     HOROVOD = "horovod"
 
+    # Enables dynamic process sets: new_group methods and passing group into collective ops
+    os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"
+
     class _HorovodDistModel(ComputationModel):
         """Private class for `Horovod <https://horovod.readthedocs.io/en/stable/>`_ distributed computation model."""
 
@@ -155,6 +159,15 @@ def spawn(
                 **kwargs,
             )
 
+        def _setup_group(self, group: Any) -> hvd.ProcessSet:
+            if isinstance(group, list) and all(isinstance(item, int) for item in group):
+                group = self._do_new_group(group)
+            if not isinstance(group, hvd.ProcessSet):
+                raise ValueError(
+                    f"Argument group should be list of int or hvd.ProcessSet, got {type(group)}, group={group}"
+                )
+            return group
+
         _reduce_op_map = {
             "SUM": hvd.mpi_ops.Sum,
             "AVERAGE": hvd.mpi_ops.Average,
@@ -186,21 +199,25 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
             return reduced_res[0]
 
         def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
+            if group is not None:
+                group = self._setup_group(group)
+            if self._rank_not_in_group(group):
+                return tensor
             if tensor.ndimension() == 0:
                 tensor = tensor.unsqueeze(0)
-            if group is None:
-                return hvd.allgather(tensor)
-            else:
+            if group is not None:
                 return hvd.allgather(tensor, process_set=group)
+            else:
+                return hvd.allgather(tensor)
 
         def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
             if group is not None:
                 raise NotImplementedError("all_gather with group for horovod is not implemented")
 
             return hvd.allgather_object(tensor)
 
-        def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
-            return hvd.ProcessSet(ranks)
+        def _do_new_group(self, ranks: List[int], **kwargs: Any) -> hvd.ProcessSet:
+            return hvd.add_process_set(ranks)
 
         def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
             return hvd.broadcast(tensor, root_rank=src)
@@ -210,5 +227,7 @@ def barrier(self) -> None:
             # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
             hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
 
-        def _rank_not_in_group(self, group: Any) -> bool:
+        def _rank_not_in_group(self, group: Optional[Any]) -> bool:
+            if group is None:
+                return False
             return not group.included()
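
Note on the change above: Horovod's dynamic process sets back the new group handling. The sketch below is illustrative only and not part of the commit; it assumes a Horovod build with the torch extension and process-set support, and mirrors what _do_new_group (hvd.add_process_set) and _do_all_gather (process_set kwarg plus the _rank_not_in_group early exit) now do.

# Sketch only (not from this commit): standalone Horovod usage mirroring the new group handling.
import os

# Must be exported before hvd.init() so process sets can be created at runtime,
# which is why the commit sets it at module import time.
os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"

import torch
import horovod.torch as hvd

hvd.init()

# _do_new_group now registers the set with the Horovod backend via
# hvd.add_process_set instead of building a bare hvd.ProcessSet object.
subset = hvd.add_process_set([0, 1])  # illustrative ranks

t = torch.tensor([hvd.rank()])
if subset.included():
    # Ranks inside the set gather only among themselves.
    res = hvd.allgather(t, process_set=subset)
else:
    # Ranks outside the set get their tensor back unchanged,
    # matching the _rank_not_in_group early exit in _do_all_gather.
    res = t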

ignite/distributed/comp_models/native.py

Lines changed: 4 additions & 4 deletions
@@ -408,7 +408,7 @@ def spawn(
             **spawn_kwargs,
         )
 
-    def _setup_group(self, group: Optional[Any]) -> dist.ProcessGroup:
+    def _setup_group(self, group: Any) -> dist.ProcessGroup:
         if isinstance(group, list) and all(isinstance(item, int) for item in group):
             group = self._do_new_group(group)
         if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):
@@ -442,7 +442,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         if group is not None:
             group = self._setup_group(group)
-        if group == dist.GroupMember.NON_GROUP_MEMBER:
+        if self._rank_not_in_group(group):
             return tensor
         if group is None:
             group_size = self.get_world_size()
@@ -466,7 +466,7 @@ def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Lis
             )
         if group is not None:
             group = self._setup_group(group)
-        if group == dist.GroupMember.NON_GROUP_MEMBER:
+        if self._rank_not_in_group(group):
             return tensor
         if group is None:
             group_size = self.get_world_size()
@@ -491,7 +491,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
     def barrier(self) -> None:
         dist.barrier()
 
-    def _rank_not_in_group(self, group: Any) -> bool:
+    def _rank_not_in_group(self, group: Optional[Any]) -> bool:
         return dist._rank_not_in_group(group)
 
 def _expand_hostlist(nodelist: str) -> List[str]:
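
For comparison with the Horovod changes, a minimal sketch of what the native backend's _setup_group and _rank_not_in_group resolve to in plain torch.distributed: a rank list becomes a ProcessGroup via dist.new_group, and ranks outside it receive dist.GroupMember.NON_GROUP_MEMBER. The backend choice and ranks below are illustrative, not from the commit.

# Sketch only: rank-list -> ProcessGroup mapping used by the native model.
import torch
import torch.distributed as dist

# Assumes the default process group was already initialized, e.g. by torchrun
# together with dist.init_process_group("gloo").
group = dist.new_group(ranks=[0, 1])  # what _do_new_group does with a rank list

t = torch.tensor([dist.get_rank()])
if dist._rank_not_in_group(group):
    # Non-members get GroupMember.NON_GROUP_MEMBER back from new_group,
    # so the collective is skipped and the tensor is returned as-is.
    res = t
else:
    out = [torch.zeros_like(t) for _ in range(dist.get_world_size(group=group))]
    dist.all_gather(out, t, group=group)
    res = torch.cat(out)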

tests/ignite/distributed/utils/__init__.py

Lines changed: 22 additions & 18 deletions
@@ -238,7 +238,10 @@ def _test_distrib_all_gather_group(device):
         assert res == t
 
     t = torch.tensor([rank], device=device)
-    res = idist.all_gather(t, group=ranks)
+    if bnd == "horovod":
+        res = idist.all_gather(t, group=group)
+    else:
+        res = idist.all_gather(t, group=ranks)
     if rank in ranks:
         assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
     else:
@@ -252,6 +255,9 @@ def _test_distrib_all_gather_group(device):
     if bnd in ("xla-tpu"):
         with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
             res = idist.all_gather(t, group=ranks)
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=group)
     else:
         res = idist.all_gather(t, group=ranks)
         if rank in ranks:
@@ -273,15 +279,13 @@ def _test_distrib_all_gather_group(device):
     else:
         assert res == t
 
+    t = torch.tensor([rank], device=device)
     if bnd in ("nccl", "gloo", "mpi", "horovod"):
-        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
     elif bnd in ("xla-tpu"):
         with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
-    elif bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group="abc")
 
 
 def _test_idist_all_gather_tensors_with_shapes(device):
@@ -309,9 +313,8 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
     torch.manual_seed(41)
 
     rank = idist.get_rank()
-    ranks = list(range(1, idist.get_world_size()))
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     ws = idist.get_world_size()
-    bnd = idist.backend()
     if rank in ranks:
         reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
         rank_tensor = reference[
@@ -321,17 +324,18 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
         ]
     else:
         rank_tensor = torch.tensor([rank], device=device)
-    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-    if rank in ranks:
-        for r in ranks:
-            r_tensor = reference[
-                r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-            ]
-            assert (r_tensor == tensors[r - 1]).all()
-    else:
-        assert [rank_tensor] == tensors
+
+    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+    if rank in ranks:
+        for r in ranks:
+            r_tensor = reference[
+                r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+            ]
+            assert torch.equal(r_tensor, tensors[r - 1])
+    else:
+        assert [rank_tensor] == tensors
 
 
 def _test_distrib_broadcast(device):
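
Usage note on what the updated tests exercise: on Horovod the test passes the previously created process set object as the group, while the other backends accept a plain rank list. A hypothetical driver using the public ignite.distributed API could look like the sketch below; the ranks and helper function are illustrative and not part of the commit.

# Sketch only: public API usage corresponding to the updated tests.
import torch
import ignite.distributed as idist


def gather_on_subset():
    rank = idist.get_rank()
    ranks = [0, 1]  # illustrative subset of ranks
    t = torch.tensor([rank])

    if idist.backend() == "horovod":
        # Horovod path exercised above: pass the process set created once
        # via idist.new_group rather than a raw rank list.
        group = idist.new_group(ranks)
        return idist.all_gather(t, group=group)

    # Native / XLA path: a list of ranks is accepted directly.
    return idist.all_gather(t, group=ranks)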
