Skip to content

Commit 947ef84

Browse files
committed
few more hvd fixes
1 parent 7f1eaa1 commit 947ef84

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

tests/ignite/distributed/utils/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _test_idist_all_gather_tensors_with_shapes(device):
292292
torch.manual_seed(41)
293293
rank = idist.get_rank()
294294
ws = idist.get_world_size()
295-
reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
295+
reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
296296
rank_tensor = reference[
297297
rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
298298
rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
@@ -305,7 +305,7 @@ def _test_idist_all_gather_tensors_with_shapes(device):
305305
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
306306
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
307307
]
308-
assert (r_tensor == tensors[r]).all()
308+
assert r_tensor.allclose(tensors[r])
309309

310310

311311
def _test_idist_all_gather_tensors_with_shapes_group(device):
@@ -316,7 +316,7 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
316316
ranks = sorted(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [1, 2, 3]
317317
ws = idist.get_world_size()
318318
if rank in ranks:
319-
reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
319+
reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
320320
rank_tensor = reference[
321321
rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
322322
rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
@@ -327,13 +327,13 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
327327

328328
tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
329329
if rank in ranks:
330-
for r in ranks:
330+
for i, r in enumerate(ranks):
331331
r_tensor = reference[
332332
r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
333333
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
334334
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
335335
]
336-
assert torch.equal(r_tensor, tensors[r - 1])
336+
assert r_tensor.allclose(tensors[i])
337337
else:
338338
assert [rank_tensor] == tensors
339339

@@ -403,7 +403,7 @@ def _test_distrib_barrier(device):
403403

404404

405405
def _test_distrib_group(device):
406-
ranks = [0, 1]
406+
ranks = sorted(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [1, 2, 3]
407407
if idist.get_world_size() > 1 and idist.backend() is not None:
408408
bnd = idist.backend()
409409
rank = idist.get_rank()

tests/ignite/distributed/utils/test_horovod.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_idist_methods_overhead_hvd(gloo_hvd_executor):
242242
sync_model = False
243243
gloo_hvd_executor(_test_idist_methods_overhead, (ok_factor, sync_model), np=np, do_init=True)
244244

245-
ok_factor = 2.5
245+
ok_factor = 3.0
246246
sync_model = True
247247
gloo_hvd_executor(_test_idist_methods_overhead, (ok_factor, sync_model), np=np, do_init=True)
248248

0 commit comments

Comments
 (0)