Test cm #2618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 4 commits, Jul 18, 2022

Changes from all commits

tests/ignite/metrics/test_confusion_matrix.py (184 changes: 88 additions & 96 deletions)

@@ -53,43 +53,41 @@ def test_multiclass_wrong_inputs():
         ConfusionMatrix.normalize(None, None)


-def test_multiclass_input():
-    def _test(y_pred, y, num_classes, cm, batch_size):
-        cm.reset()
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                cm.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            cm.update((y_pred, y))
-
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-    def get_test_cases():
-        return [
-            # Multiclass input data of shape (N, )
-            (torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 4, 1),
-            (torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 10, 1),
-            (torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 2, 1),
-            (torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 5, 16),
-            # Multiclass input data of shape (N, L)
-            (torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 4, 1),
-            (torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 10, 1),
-            (torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 9, 16),
-            # Multiclass input data of shape (N, H, W, ...)
-            (torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 5, 1),
-            (torch.rand(4, 5, 10, 12, 8), torch.randint(0, 5, size=(4, 10, 12, 8)).long(), 5, 1),
-            (torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 3, 16),
-        ]
-
-    # check multiple random inputs as random exact occurrences are rare
-    for _ in range(5):
-        for y_pred, y, num_classes, batch_size in get_test_cases():
-            cm = ConfusionMatrix(num_classes=num_classes)
-            _test(y_pred, y, num_classes, cm, batch_size)
+@pytest.fixture(params=[item for item in range(10)])
+def test_data(request):
+    return [
+        # Multiclass input data of shape (N, )
+        (torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 4, 1),
+        (torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 10, 1),
+        (torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 2, 1),
+        (torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 5, 16),
+        # Multiclass input data of shape (N, L)
+        (torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 4, 1),
+        (torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 10, 1),
+        (torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 9, 16),
+        # Multiclass input data of shape (N, H, W, ...)
+        (torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 5, 1),
+        (torch.rand(4, 5, 10, 12, 8), torch.randint(0, 5, size=(4, 10, 12, 8)).long(), 5, 1),
+        (torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 3, 16),
+    ][request.param]
+
+
+@pytest.mark.parametrize("n_times", range(5))
+def test_multiclass_input(n_times, test_data):
+    y_pred, y, num_classes, batch_size = test_data
+    cm = ConfusionMatrix(num_classes=num_classes)
+    cm.reset()
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            cm.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        cm.update((y_pred, y))
+
+    np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
+    np_y = y.numpy().ravel()
+    assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())


 def test_ignored_out_of_num_classes_indices():
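
For reference, the pattern adopted in this hunk combines a parametrized fixture (one instantiation per test case) with @pytest.mark.parametrize (one repeat counter), so pytest collects the full cross product (here 5 repeats x 10 cases = 50 test items) in place of the old nested loops. A minimal standalone sketch of that mechanism; the names case, repeat, and test_square are illustrative, not from this PR:

import pytest


@pytest.fixture(params=range(3))
def case(request):
    # One fixture instantiation per entry in `params`;
    # request.param is the index of the selected case.
    return [(1, 1), (2, 4), (3, 9)][request.param]


@pytest.mark.parametrize("repeat", range(2))
def test_square(repeat, case):
    # pytest collects the cross product: 2 repeats x 3 cases = 6 items,
    # replacing the explicit for-loops of the old helper style.
    x, expected = case
    assert x * x == expected

Each generated item shows up individually in the test report, which is the main practical gain over the old _test helper.
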
@@ -202,43 +200,40 @@ def test_iou_wrong_input():
         IoU(cm, ignore_index=11)


-def test_iou():
-    def _test(average=None):
-        y_true, y_pred = get_y_true_y_pred()
-        th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)
-
-        true_res = [0, 0, 0]
-        for index in range(3):
-            bin_y_true = y_true == index
-            bin_y_pred = y_pred == index
-            intersection = bin_y_true & bin_y_pred
-            union = bin_y_true | bin_y_pred
-            true_res[index] = intersection.sum() / union.sum()
-
-        cm = ConfusionMatrix(num_classes=3, average=average)
-        iou_metric = IoU(cm)
-
-        # Update metric
-        output = (th_y_logits, th_y_true)
-        cm.update(output)
-
-        res = iou_metric.compute().numpy()
-
-        assert np.all(res == true_res)
-
-        for ignore_index in range(3):
-            cm = ConfusionMatrix(num_classes=3)
-            iou_metric = IoU(cm, ignore_index=ignore_index)
-            # Update metric
-            output = (th_y_logits, th_y_true)
-            cm.update(output)
-            res = iou_metric.compute().numpy()
-            true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
-            assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"
-
-    _test()
-    _test(average="samples")
+@pytest.mark.parametrize("average", [None, "samples"])
+def test_iou(average):
+
+    y_true, y_pred = get_y_true_y_pred()
+    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)
+
+    true_res = [0, 0, 0]
+    for index in range(3):
+        bin_y_true = y_true == index
+        bin_y_pred = y_pred == index
+        intersection = bin_y_true & bin_y_pred
+        union = bin_y_true | bin_y_pred
+        true_res[index] = intersection.sum() / union.sum()
+
+    cm = ConfusionMatrix(num_classes=3, average=average)
+    iou_metric = IoU(cm)
+
+    # Update metric
+    output = (th_y_logits, th_y_true)
+    cm.update(output)
+
+    res = iou_metric.compute().numpy()
+
+    assert np.all(res == true_res)
+
+    for ignore_index in range(3):
+        cm = ConfusionMatrix(num_classes=3)
+        iou_metric = IoU(cm, ignore_index=ignore_index)
+        # Update metric
+        output = (th_y_logits, th_y_true)
+        cm.update(output)
+        res = iou_metric.compute().numpy()
+        true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
+        assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"

     with pytest.raises(ValueError, match=r"ConfusionMatrix should have average attribute either"):
         cm = ConfusionMatrix(num_classes=3, average="precision")
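
For context on what these assertions check: per-class IoU can be read directly off a confusion matrix, since for class i the intersection is cm[i, i] and the union is the row sum plus the column sum minus cm[i, i]. A minimal NumPy sketch of that identity (the matrix values are made up for illustration; this is not code from the PR):

import numpy as np

# Hypothetical 3-class confusion matrix (rows = true class, cols = predicted).
cm = np.array([[5, 1, 0],
               [2, 3, 1],
               [0, 0, 8]], dtype=float)

diag = np.diag(cm)  # true positives per class
# IoU_i = TP_i / (TP_i + FP_i + FN_i) = cm[i, i] / (row_i + col_i - cm[i, i])
iou = diag / (cm.sum(axis=1) + cm.sum(axis=0) - diag)
print(iou)  # [0.625      0.42857143 0.88888889]

The ignore_index loop in the test then simply drops one entry of this per-class vector and compares against the rest.
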
@@ -543,43 +538,40 @@ def _test_distrib_accumulator_device(device):
        ), f"{type(cm.confusion_matrix.device)}:{cm._num_correct.device} vs {type(metric_device)}:{metric_device}"


-def test_jaccard_index():
-    def _test(average=None):
-        y_true, y_pred = get_y_true_y_pred()
-        th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)
-
-        true_res = [0, 0, 0]
-        for index in range(3):
-            bin_y_true = y_true == index
-            bin_y_pred = y_pred == index
-            intersection = bin_y_true & bin_y_pred
-            union = bin_y_true | bin_y_pred
-            true_res[index] = intersection.sum() / union.sum()
-
-        cm = ConfusionMatrix(num_classes=3, average=average)
-        jaccard_index = JaccardIndex(cm)
-
-        # Update metric
-        output = (th_y_logits, th_y_true)
-        cm.update(output)
-
-        res = jaccard_index.compute().numpy()
-
-        assert np.all(res == true_res)
-
-        for ignore_index in range(3):
-            cm = ConfusionMatrix(num_classes=3)
-            jaccard_index_metric = JaccardIndex(cm, ignore_index=ignore_index)
-            # Update metric
-            output = (th_y_logits, th_y_true)
-            cm.update(output)
-            res = jaccard_index_metric.compute().numpy()
-            true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
-            assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"
-
-    _test()
-    _test(average="samples")
+@pytest.mark.parametrize("average", [None, "samples"])
+def test_jaccard_index(average):
+
+    y_true, y_pred = get_y_true_y_pred()
+    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)
+
+    true_res = [0, 0, 0]
+    for index in range(3):
+        bin_y_true = y_true == index
+        bin_y_pred = y_pred == index
+        intersection = bin_y_true & bin_y_pred
+        union = bin_y_true | bin_y_pred
+        true_res[index] = intersection.sum() / union.sum()
+
+    cm = ConfusionMatrix(num_classes=3, average=average)
+    jaccard_index = JaccardIndex(cm)
+
+    # Update metric
+    output = (th_y_logits, th_y_true)
+    cm.update(output)
+
+    res = jaccard_index.compute().numpy()
+
+    assert np.all(res == true_res)
+
+    for ignore_index in range(3):
+        cm = ConfusionMatrix(num_classes=3)
+        jaccard_index_metric = JaccardIndex(cm, ignore_index=ignore_index)
+        # Update metric
+        output = (th_y_logits, th_y_true)
+        cm.update(output)
+        res = jaccard_index_metric.compute().numpy()
+        true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
+        assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"

     with pytest.raises(ValueError, match=r"ConfusionMatrix should have average attribute either"):
         cm = ConfusionMatrix(num_classes=3, average="precision")
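
test_jaccard_index mirrors test_iou above because the two metrics compute the same quantity: for each class, the Jaccard index of the predicted and true masks is |A intersect B| / |A union B|, i.e. intersection over union. A minimal sketch of that identity (the label vectors are made-up illustration, not from the PR):

import numpy as np

# Hypothetical true and predicted labels for 6 samples, 3 classes.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

for index in range(3):
    a = y_true == index
    b = y_pred == index
    # Jaccard index = |A intersect B| / |A union B| = intersection over union.
    print(index, (a & b).sum() / (a | b).sum())  # 1/3, 2/3, 1/2

This is the same per-class computation the test's true_res loop performs, which is why both tests share the identical expected-value code.
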