@@ -292,7 +292,7 @@ def _test_idist_all_gather_tensors_with_shapes(device):
292
292
torch .manual_seed (41 )
293
293
rank = idist .get_rank ()
294
294
ws = idist .get_world_size ()
295
- reference = torch .randn (ws * ( ws + 1 ) // 2 , ws * ( ws + 3 ) // 2 , ws * ( ws + 5 ) // 2 , device = device )
295
+ reference = torch .randn (ws * 5 , ws * 5 , ws * 5 , device = device )
296
296
rank_tensor = reference [
297
297
rank * (rank + 1 ) // 2 : rank * (rank + 1 ) // 2 + rank + 1 ,
298
298
rank * (rank + 3 ) // 2 : rank * (rank + 3 ) // 2 + rank + 2 ,
@@ -305,7 +305,7 @@ def _test_idist_all_gather_tensors_with_shapes(device):
305
305
r * (r + 3 ) // 2 : r * (r + 3 ) // 2 + r + 2 ,
306
306
r * (r + 5 ) // 2 : r * (r + 5 ) // 2 + r + 3 ,
307
307
]
308
- assert ( r_tensor == tensors [r ]). all ( )
308
+ assert r_tensor . allclose ( tensors [r ])
309
309
310
310
311
311
def _test_idist_all_gather_tensors_with_shapes_group (device ):
@@ -316,7 +316,7 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
316
316
ranks = sorted (range (idist .get_world_size () - 1 , 0 , - 1 )) # [0, 1, 2, 3] -> [1, 2, 3]
317
317
ws = idist .get_world_size ()
318
318
if rank in ranks :
319
- reference = torch .randn (ws * ( ws + 1 ) // 2 , ws * ( ws + 3 ) // 2 , ws * ( ws + 5 ) // 2 , device = device )
319
+ reference = torch .randn (ws * 5 , ws * 5 , ws * 5 , device = device )
320
320
rank_tensor = reference [
321
321
rank * (rank + 1 ) // 2 : rank * (rank + 1 ) // 2 + rank + 1 ,
322
322
rank * (rank + 3 ) // 2 : rank * (rank + 3 ) // 2 + rank + 2 ,
@@ -327,13 +327,13 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
327
327
328
328
tensors = all_gather_tensors_with_shapes (rank_tensor , [[r + 1 , r + 2 , r + 3 ] for r in ranks ], ranks )
329
329
if rank in ranks :
330
- for r in ranks :
330
+ for i , r in enumerate ( ranks ) :
331
331
r_tensor = reference [
332
332
r * (r + 1 ) // 2 : r * (r + 1 ) // 2 + r + 1 ,
333
333
r * (r + 3 ) // 2 : r * (r + 3 ) // 2 + r + 2 ,
334
334
r * (r + 5 ) // 2 : r * (r + 5 ) // 2 + r + 3 ,
335
335
]
336
- assert torch . equal ( r_tensor , tensors [r - 1 ])
336
+ assert r_tensor . allclose ( tensors [i ])
337
337
else :
338
338
assert [rank_tensor ] == tensors
339
339
@@ -403,7 +403,7 @@ def _test_distrib_barrier(device):
403
403
404
404
405
405
def _test_distrib_group (device ):
406
- ranks = [0 , 1 ]
406
+ ranks = sorted ( range ( idist . get_world_size () - 1 , 0 , - 1 )) # [0, 1, 2, 3] -> [1, 2, 3 ]
407
407
if idist .get_world_size () > 1 and idist .backend () is not None :
408
408
bnd = idist .backend ()
409
409
rank = idist .get_rank ()
0 commit comments