Skip to content

Commit 795dccb

Browse files
BanzaiTokyovfdev-5
andauthored
adds available_device to test_multilabel_confusion_matrix #3335 (#3366)
* adds available_device to test_multilabel_confusion_matrix #3335 * moves tensors to cpu --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 79a5d28 commit 795dccb

File tree

1 file changed

+43
-32
lines changed

1 file changed

+43
-32
lines changed

tests/ignite/metrics/test_multilabel_confusion_matrix.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,10 @@ def get_y_true_y_pred():
8787
return y_true, y_pred
8888

8989

90-
def test_multiclass_images():
90+
def test_multiclass_images(available_device):
9191
num_classes = 3
92-
cm = MultiLabelConfusionMatrix(num_classes=num_classes)
92+
cm = MultiLabelConfusionMatrix(num_classes=num_classes, device=available_device)
93+
assert cm._device == torch.device(available_device)
9394

9495
y_true, y_pred = get_y_true_y_pred()
9596

@@ -107,7 +108,8 @@ def test_multiclass_images():
107108
assert np.all(ignite_CM == sklearn_CM)
108109

109110
# Another test on batch of 2 images
110-
cm = MultiLabelConfusionMatrix(num_classes=num_classes)
111+
cm = MultiLabelConfusionMatrix(num_classes=num_classes, device=available_device)
112+
assert cm._device == torch.device(available_device)
111113

112114
# Create a batch of two images:
113115
th_y_true1 = torch.tensor(y_true)
@@ -208,7 +210,7 @@ def _test_distrib_accumulator_device(device):
208210
), f"{type(cm.confusion_matrix.device)}:{cm._num_correct.device} vs {type(metric_device)}:{metric_device}"
209211

210212

211-
def test_simple_2D_input():
213+
def test_simple_2D_input(available_device):
212214
# Tests for 2D inputs with normalized = True and False
213215

214216
num_iters = 5
@@ -218,19 +220,21 @@ def test_simple_2D_input():
218220
for _ in range(num_iters):
219221
target = torch.randint(0, 2, size=(num_samples, num_classes))
220222
prediction = torch.randint(0, 2, size=(num_samples, num_classes))
221-
sklearn_CM = multilabel_confusion_matrix(target.numpy(), prediction.numpy())
222-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
223+
sklearn_CM = multilabel_confusion_matrix(target.cpu().numpy(), prediction.cpu().numpy())
224+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
225+
assert mlcm._device == torch.device(available_device)
223226
mlcm.update([prediction, target])
224-
ignite_CM = mlcm.compute().numpy()
227+
ignite_CM = mlcm.compute().cpu().numpy()
225228
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
226-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=True)
229+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=True, device=available_device)
230+
assert mlcm._device == torch.device(available_device)
227231
mlcm.update([prediction, target])
228-
ignite_CM_normalized = mlcm.compute().numpy()
232+
ignite_CM_normalized = mlcm.compute().cpu().numpy()
229233
sklearn_CM_normalized = sklearn_CM / sklearn_CM.sum(axis=(1, 2))[:, None, None]
230234
assert np.allclose(sklearn_CM_normalized, ignite_CM_normalized)
231235

232236

233-
def test_simple_ND_input():
237+
def test_simple_ND_input(available_device):
234238
num_iters = 5
235239
num_samples = 100
236240
num_classes = 10
@@ -240,82 +244,88 @@ def test_simple_ND_input():
240244
for _ in range(num_iters): # 3D tests
241245
target = torch.randint(0, 2, size=(num_samples, num_classes, size_3d))
242246
prediction = torch.randint(0, 2, size=(num_samples, num_classes, size_3d))
243-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
247+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
248+
assert mlcm._device == torch.device(available_device)
244249
mlcm.update([prediction, target])
245-
ignite_CM = mlcm.compute().numpy()
250+
ignite_CM = mlcm.compute().cpu().numpy()
246251
target_reshaped = target.permute(0, 2, 1).reshape(size_3d * num_samples, num_classes)
247252
prediction_reshaped = prediction.permute(0, 2, 1).reshape(size_3d * num_samples, num_classes)
248-
sklearn_CM = multilabel_confusion_matrix(target_reshaped.numpy(), prediction_reshaped.numpy())
253+
sklearn_CM = multilabel_confusion_matrix(target_reshaped.cpu().numpy(), prediction_reshaped.cpu().numpy())
249254
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
250255

251256
size_4d = 4
252257
for _ in range(num_iters): # 4D tests
253258
target = torch.randint(0, 2, size=(num_samples, num_classes, size_3d, size_4d))
254259
prediction = torch.randint(0, 2, size=(num_samples, num_classes, size_3d, size_4d))
255-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
260+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
261+
assert mlcm._device == torch.device(available_device)
256262
mlcm.update([prediction, target])
257-
ignite_CM = mlcm.compute().numpy()
263+
ignite_CM = mlcm.compute().cpu().numpy()
258264
target_reshaped = target.permute(0, 2, 3, 1).reshape(size_3d * size_4d * num_samples, num_classes)
259265
prediction_reshaped = prediction.permute(0, 2, 3, 1).reshape(size_3d * size_4d * num_samples, num_classes)
260-
sklearn_CM = multilabel_confusion_matrix(target_reshaped.numpy(), prediction_reshaped.numpy())
266+
sklearn_CM = multilabel_confusion_matrix(target_reshaped.cpu().numpy(), prediction_reshaped.cpu().numpy())
261267
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
262268

263269
size_5d = 4
264270
for _ in range(num_iters): # 5D tests
265271
target = torch.randint(0, 2, size=(num_samples, num_classes, size_3d, size_4d, size_5d))
266272
prediction = torch.randint(0, 2, size=(num_samples, num_classes, size_3d, size_4d, size_5d))
267-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
273+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
274+
assert mlcm._device == torch.device(available_device)
268275
mlcm.update([prediction, target])
269-
ignite_CM = mlcm.compute().numpy()
276+
ignite_CM = mlcm.compute().cpu().numpy()
270277
target_reshaped = target.permute(0, 2, 3, 4, 1).reshape(size_3d * size_4d * size_5d * num_samples, num_classes)
271278
prediction_reshaped = prediction.permute(0, 2, 3, 4, 1).reshape(
272279
size_3d * size_4d * size_5d * num_samples, num_classes
273280
)
274-
sklearn_CM = multilabel_confusion_matrix(target_reshaped.numpy(), prediction_reshaped.numpy())
281+
sklearn_CM = multilabel_confusion_matrix(target_reshaped.cpu().numpy(), prediction_reshaped.cpu().numpy())
275282
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
276283

277284

278-
def test_simple_batched():
285+
def test_simple_batched(available_device):
279286
num_iters = 5
280287
num_samples = 100
281288
num_classes = 10
282289
batch_size = 1
283290
torch.manual_seed(0)
284291

285292
for _ in range(num_iters): # 2D tests
286-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
293+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
294+
assert mlcm._device == torch.device(available_device)
287295
targets = torch.randint(0, 2, size=(int(num_samples / batch_size), batch_size, num_classes))
288296
predictions = torch.randint(0, 2, size=(int(num_samples / batch_size), batch_size, num_classes))
289297
for i in range(int(num_samples / batch_size)):
290298
target_sample = targets[i]
291299
prediction_sample = predictions[i]
292300
mlcm.update([prediction_sample, target_sample])
293301

294-
ignite_CM = mlcm.compute().numpy()
302+
ignite_CM = mlcm.compute().cpu().numpy()
295303
targets_reshaped = targets.reshape(-1, num_classes)
296304
predictions_reshaped = predictions.reshape(-1, num_classes)
297-
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.numpy(), predictions_reshaped.numpy())
305+
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.cpu().numpy(), predictions_reshaped.cpu().numpy())
298306
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
299307

300308
size_3d = 4
301309
for _ in range(num_iters): # 3D tests
302-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
310+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
311+
assert mlcm._device == torch.device(available_device)
303312
targets = torch.randint(0, 2, size=(int(num_samples / batch_size), batch_size, num_classes, size_3d))
304313
predictions = torch.randint(0, 2, size=(int(num_samples / batch_size), batch_size, num_classes, size_3d))
305314
for i in range(int(num_samples / batch_size)):
306315
target_sample = targets[i]
307316
prediction_sample = predictions[i]
308317
mlcm.update([prediction_sample, target_sample])
309318

310-
ignite_CM = mlcm.compute().numpy()
319+
ignite_CM = mlcm.compute().cpu().numpy()
311320
targets_reshaped = targets.permute(0, 1, 3, 2).reshape(-1, num_classes)
312321
predictions_reshaped = predictions.permute(0, 1, 3, 2).reshape(-1, num_classes)
313-
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.numpy(), predictions_reshaped.numpy())
322+
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.cpu().numpy(), predictions_reshaped.cpu().numpy())
314323
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
315324

316325
size_4d = 4
317326
for _ in range(num_iters): # 4D tests
318-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
327+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
328+
assert mlcm._device == torch.device(available_device)
319329
targets = torch.randint(0, 2, size=(int(num_samples / batch_size), batch_size, num_classes, size_3d, size_4d))
320330
predictions = torch.randint(
321331
0, 2, size=(int(num_samples / batch_size), batch_size, num_classes, size_3d, size_4d)
@@ -325,15 +335,16 @@ def test_simple_batched():
325335
prediction_sample = predictions[i]
326336
mlcm.update([prediction_sample, target_sample])
327337

328-
ignite_CM = mlcm.compute().numpy()
338+
ignite_CM = mlcm.compute().cpu().numpy()
329339
targets_reshaped = targets.permute(0, 1, 3, 4, 2).reshape(-1, num_classes)
330340
predictions_reshaped = predictions.permute(0, 1, 3, 4, 2).reshape(-1, num_classes)
331-
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.numpy(), predictions_reshaped.numpy())
341+
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.cpu().numpy(), predictions_reshaped.cpu().numpy())
332342
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
333343

334344
size_5d = 4
335345
for _ in range(num_iters): # 5D tests
336-
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False)
346+
mlcm = MultiLabelConfusionMatrix(num_classes, normalized=False, device=available_device)
347+
assert mlcm._device == torch.device(available_device)
337348
targets = torch.randint(
338349
0, 2, size=(int(num_samples / batch_size), batch_size, num_classes, size_3d, size_4d, size_5d)
339350
)
@@ -345,10 +356,10 @@ def test_simple_batched():
345356
prediction_sample = predictions[i]
346357
mlcm.update([prediction_sample, target_sample])
347358

348-
ignite_CM = mlcm.compute().numpy()
359+
ignite_CM = mlcm.compute().cpu().numpy()
349360
targets_reshaped = targets.permute(0, 1, 3, 4, 5, 2).reshape(-1, num_classes)
350361
predictions_reshaped = predictions.permute(0, 1, 3, 4, 5, 2).reshape(-1, num_classes)
351-
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.numpy(), predictions_reshaped.numpy())
362+
sklearn_CM = multilabel_confusion_matrix(targets_reshaped.cpu().numpy(), predictions_reshaped.cpu().numpy())
352363
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))
353364

354365

0 commit comments

Comments
 (0)