@@ -215,11 +215,9 @@ def get_test_cases():
215
215
def _test_distrib_integration_binary_input (device ):
216
216
217
217
rank = idist .get_rank ()
218
- torch .manual_seed (12 )
219
218
n_iters = 80
220
- s = 16
219
+ batch_size = 16
221
220
n_classes = 2
222
- offset = n_iters * s
223
221
224
222
def _test (y_preds , y_true , n_epochs , metric_device , update_fn ):
225
223
metric_device = torch .device (metric_device )
@@ -232,6 +230,9 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
232
230
data = list (range (n_iters ))
233
231
engine .run (data = data , max_epochs = n_epochs )
234
232
233
+ y_true = idist .all_gather (y_true )
234
+ y_preds = idist .all_gather (y_preds )
235
+
235
236
assert "ap" in engine .state .metrics
236
237
237
238
res = engine .state .metrics ["ap" ]
@@ -240,24 +241,25 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
240
241
assert pytest .approx (res ) == true_res
241
242
242
243
def get_tests (is_N ):
244
+ torch .manual_seed (12 + rank )
243
245
if is_N :
244
- y_true = torch .randint (0 , n_classes , size = (offset * idist . get_world_size () ,)).to (device )
245
- y_preds = torch .rand (offset * idist . get_world_size () ).to (device )
246
+ y_true = torch .randint (0 , n_classes , size = (n_iters * batch_size ,)).to (device )
247
+ y_preds = torch .rand (n_iters * batch_size ).to (device )
246
248
247
249
def update_fn (engine , i ):
248
250
return (
249
- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset ],
250
- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset ],
251
+ y_preds [i * batch_size : (i + 1 ) * batch_size ],
252
+ y_true [i * batch_size : (i + 1 ) * batch_size ],
251
253
)
252
254
253
255
else :
254
- y_true = torch .randint (0 , n_classes , size = (offset * idist . get_world_size () , 10 )).to (device )
255
- y_preds = torch .randint (0 , n_classes , size = (offset * idist . get_world_size () , 10 )).to (device )
256
+ y_true = torch .randint (0 , n_classes , size = (n_iters * batch_size , 10 )).to (device )
257
+ y_preds = torch .randint (0 , n_classes , size = (n_iters * batch_size , 10 )).to (device )
256
258
257
259
def update_fn (engine , i ):
258
260
return (
259
- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , :],
260
- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset , :],
261
+ y_preds [i * batch_size : (i + 1 ) * batch_size , :],
262
+ y_true [i * batch_size : (i + 1 ) * batch_size , :],
261
263
)
262
264
263
265
return y_preds , y_true , update_fn
0 commit comments