@@ -238,7 +238,10 @@ def _test_distrib_all_gather_group(device):
         assert res == t
 
     t = torch.tensor([rank], device=device)
-    res = idist.all_gather(t, group=ranks)
+    if bnd == "horovod":
+        res = idist.all_gather(t, group=group)
+    else:
+        res = idist.all_gather(t, group=ranks)
     if rank in ranks:
         assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
     else:
@@ -252,6 +255,9 @@ def _test_distrib_all_gather_group(device):
     if bnd in ("xla-tpu"):
         with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
             res = idist.all_gather(t, group=ranks)
+    elif bnd in ("horovod"):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            res = idist.all_gather(t, group=group)
     else:
         res = idist.all_gather(t, group=ranks)
         if rank in ranks:
@@ -273,15 +279,13 @@ def _test_distrib_all_gather_group(device):
         else:
             assert res == t
 
+    t = torch.tensor([rank], device=device)
     if bnd in ("nccl", "gloo", "mpi", "horovod"):
-        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
     elif bnd in ("xla-tpu"):
         with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
-    elif bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group="abc")
 
 
 def _test_idist_all_gather_tensors_with_shapes(device):
@@ -309,9 +313,8 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
     torch.manual_seed(41)
 
     rank = idist.get_rank()
-    ranks = list(range(1, idist.get_world_size()))
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     ws = idist.get_world_size()
-    bnd = idist.backend()
     if rank in ranks:
         reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
         rank_tensor = reference[
@@ -321,17 +324,18 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
         ]
     else:
         rank_tensor = torch.tensor([rank], device=device)
-    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-    if rank in ranks:
-        for r in ranks:
-            r_tensor = reference[
-                r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-            ]
-            assert (r_tensor == tensors[r - 1]).all()
-    else:
-        assert [rank_tensor] == tensors
+
+    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+    if rank in ranks:
+        for r in ranks:
+            r_tensor = reference[
+                r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+            ]
+            assert torch.equal(r_tensor, tensors[r - 1])
+    else:
+        assert [rank_tensor] == tensors
 
 
 def _test_distrib_broadcast(device):
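
For context, a minimal sketch (not part of this commit) of the backend-dependent calling convention these tests pin down: native torch and xla-tpu backends take a plain list of ranks for `group`, while horovod takes a group handle. The helper name, the `device` argument, and the `[0, 1]` sub-group are assumptions for illustration; `idist.new_group` is assumed to be the way the test builds its `group` handle, which this hunk does not show.

import torch
import ignite.distributed as idist

def gather_ranks_over_subgroup(device):
    # Hypothetical helper, for illustration only.
    rank = idist.get_rank()
    ranks = [0, 1]  # assumed sub-group; the tests derive their own ranks
    t = torch.tensor([rank], device=device)
    if idist.backend() == "horovod":
        # Horovod takes a group handle (assumed built via idist.new_group);
        # per the tests above, all_gather on a non-tensor object with a
        # group raises NotImplementedError on this backend.
        group = idist.new_group(ranks)
        return idist.all_gather(t, group=group)
    # Other backends accept a plain list of ranks.
    return idist.all_gather(t, group=ranks)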