5919 fix generalized dice issue (Project-MONAI#5929)
Signed-off-by: Yiheng Wang <vennw@nvidia.com>

Fixes Project-MONAI#5919.

### Description

This PR fixes a device issue in `compute_generalized_dice`: CUDA tensor inputs no longer raise device-mismatch errors.
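
For context, a minimal sketch of the pattern behind the fix (not MONAI code; the tensor names and values below are illustrative stand-ins, and CUDA is used only when available): the scalar tensors passed to `torch.where` are now created on the device of the score tensor, so they cannot clash with CUDA inputs.

```python
import torch

# Hypothetical stand-ins for the tensors used inside compute_generalized_dice.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
generalized_dice_score = torch.rand(4, device=device)
denom_zeros = torch.tensor([True, False, True, False], device=device)
y_pred_o = torch.tensor([0.0, 1.0, 2.0, 0.0], device=device)

# Previously the scalars were built as torch.tensor(1.0) / torch.tensor(0.0),
# i.e. on the CPU, which could mismatch CUDA inputs; pinning them to the
# score's device keeps every operand of torch.where on one device.
generalized_dice_score[denom_zeros] = torch.where(
    (y_pred_o == 0)[denom_zeros],
    torch.tensor(1.0, device=generalized_dice_score.device),
    torch.tensor(0.0, device=generalized_dice_score.device),
)
```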

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
yiheng-wang-nv authored and wyli committed Feb 2, 2023
1 parent c1f3c73 commit 34c8f0d
Showing 12 changed files with 55 additions and 25 deletions.
4 changes: 3 additions & 1 deletion monai/metrics/generalized_dice.py
@@ -176,7 +176,9 @@ def compute_generalized_dice(
y_pred_o = y_pred_o.sum(dim=-1)
denom_zeros = denom == 0
generalized_dice_score[denom_zeros] = torch.where(
- (y_pred_o == 0)[denom_zeros], torch.tensor(1.0), torch.tensor(0.0)
+ (y_pred_o == 0)[denom_zeros],
+ torch.tensor(1.0, device=generalized_dice_score.device),
+ torch.tensor(0.0, device=generalized_dice_score.device),
)

return generalized_dice_score
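
With this change, the functional metric can be called directly on GPU tensors. A short usage sketch, assuming a CUDA-capable environment and borrowing the inputs from `TEST_CASE_1` in the updated tests (expected score about 0.8 per that case):

```python
import torch
from monai.metrics import compute_generalized_dice

device = "cuda:0" if torch.cuda.is_available() else "cpu"
y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device)
y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device)

# No device-mismatch error is raised, and the score stays on the input device.
score = compute_generalized_dice(y_pred=y_pred, y=y, include_background=True)
assert score.device == y_pred.device
```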
2 changes: 2 additions & 0 deletions tests/test_compute_confusion_matrix.py
@@ -14,6 +14,7 @@
import unittest
from typing import Any

+ import numpy as np
import torch
from parameterized import parameterized

@@ -220,6 +221,7 @@ def test_value(self, input_data, expected_value):
input_data["include_background"] = False
result = get_confusion_matrix(**input_data)
assert_allclose(result, expected_value[:, 1:, :], atol=1e-4, rtol=1e-4)
+ np.testing.assert_equal(result.device, input_data["y_pred"].device)

@parameterized.expand(TEST_CASES_COMPUTE_SAMPLE)
def test_compute_sample(self, input_data, expected_value):
18 changes: 11 additions & 7 deletions tests/test_compute_f_beta.py
@@ -13,20 +13,24 @@

import unittest

- import numpy
+ import numpy as np
import torch

from monai.metrics import FBetaScore
from tests.utils import assert_allclose

+ _device = "cuda:0" if torch.cuda.is_available() else "cpu"


class TestFBetaScore(unittest.TestCase):
- def test_expecting_success(self):
+ def test_expecting_success_and_device(self):
metric = FBetaScore()
- metric(
- y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
- )
- assert_allclose(metric.aggregate()[0], torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
+ y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device)
+ y = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], device=_device)
+ metric(y_pred=y_pred, y=y)
+ result = metric.aggregate()[0]
+ assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
+ np.testing.assert_equal(result.device, y_pred.device)

def test_expecting_success2(self):
metric = FBetaScore(beta=0.5)
@@ -58,7 +62,7 @@ def test_with_nan_values(self):
metric = FBetaScore(get_not_nans=True)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]),
- y=torch.Tensor([[1, 0, 1], [numpy.NaN, numpy.NaN, numpy.NaN], [1, 0, 1]]),
+ y=torch.Tensor([[1, 0, 1], [np.NaN, np.NaN, np.NaN], [1, 0, 1]]),
)
assert_allclose(metric.aggregate()[0][0], torch.Tensor([0.727273]), atol=1e-6, rtol=1e-6)

13 changes: 10 additions & 3 deletions tests/test_compute_generalized_dice.py
@@ -19,11 +19,13 @@

from monai.metrics import GeneralizedDiceScore, compute_generalized_dice

_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# keep background
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1)
{
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
"include_background": True,
},
[0.8],
@@ -116,7 +118,12 @@
TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]]


- class TestComputeMeanDice(unittest.TestCase):
+ class TestComputeGeneralizedDiceScore(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1])
+ def test_device(self, input_data, _expected_value):
+ result = compute_generalized_dice(**input_data)
+ np.testing.assert_equal(result.device, input_data["y_pred"].device)

# Functional part tests
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
def test_value(self, input_data, expected_value):
1 change: 1 addition & 0 deletions tests/test_compute_meandice.py
@@ -191,6 +191,7 @@ class TestComputeMeanDice(unittest.TestCase):
def test_value(self, input_data, expected_value):
result = compute_dice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+ np.testing.assert_equal(result.device, input_data["y_pred"].device)

@parameterized.expand([TEST_CASE_3])
def test_nans(self, input_data, expected_value):
1 change: 1 addition & 0 deletions tests/test_compute_meaniou.py
@@ -191,6 +191,7 @@ class TestComputeMeanIoU(unittest.TestCase):
def test_value(self, input_data, expected_value):
result = compute_meaniou(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+ np.testing.assert_equal(result.device, input_data["y_pred"].device)

@parameterized.expand([TEST_CASE_3])
def test_nans(self, input_data, expected_value):
1 change: 1 addition & 0 deletions tests/test_compute_panoptic_quality.py
@@ -96,6 +96,7 @@ class TestPanopticQualityMetric(unittest.TestCase):
def test_value(self, input_params, expected_value):
result = compute_panoptic_quality(**input_params)
np.testing.assert_allclose(result.cpu().detach().item(), expected_value, atol=1e-4)
+ np.testing.assert_equal(result.device, input_params["pred"].device)

@parameterized.expand([TEST_CLS_CASE_1, TEST_CLS_CASE_2, TEST_CLS_CASE_3, TEST_CLS_CASE_4, TEST_CLS_CASE_5])
def test_value_class(self, input_params, y_pred, y_gt, expected_value):
1 change: 1 addition & 0 deletions tests/test_compute_variance.py
@@ -113,6 +113,7 @@ class TestComputeVariance(unittest.TestCase):
def test_value(self, input_data, expected_value):
result = compute_variance(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+ np.testing.assert_equal(result.device, input_data["y_pred"].device)

@parameterized.expand([TEST_CASE_5, TEST_CASE_6])
def test_spatial_case(self, input_data, expected_value):
9 changes: 6 additions & 3 deletions tests/test_hausdorff_distance.py
@@ -19,6 +19,8 @@

from monai.metrics import HausdorffDistanceMetric

_device = "cuda:0" if torch.cuda.is_available() else "cpu"


def create_spherical_seg_3d(
radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99)
@@ -116,8 +118,8 @@ def test_value(self, input_data, expected_value):
else:
[seg_1, seg_2] = input_data
ct = 0
- seg_1 = torch.tensor(seg_1)
- seg_2 = torch.tensor(seg_2)
+ seg_1 = torch.tensor(seg_1, device=_device)
+ seg_2 = torch.tensor(seg_2, device=_device)
for metric in ["euclidean", "chessboard", "taxicab"]:
for directed in [True, False]:
hd_metric = HausdorffDistanceMetric(
@@ -130,7 +132,8 @@ def test_value(self, input_data, expected_value):
hd_metric(batch_seg_1, batch_seg_2)
result = hd_metric.aggregate(reduction="mean")
expected_value_curr = expected_value[ct]
- np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7)
+ np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-7)
+ np.testing.assert_equal(result.device, seg_1.device)
ct += 1

@parameterized.expand(TEST_CASES_NANS)
1 change: 1 addition & 0 deletions tests/test_label_quality_score.py
@@ -103,6 +103,7 @@ class TestLabelQualityScore(unittest.TestCase):
def test_value(self, input_data, expected_value):
result = label_quality_score(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+ np.testing.assert_equal(result.device, input_data["y_pred"].device)

@parameterized.expand([TEST_CASE_6, TEST_CASE_7])
def test_spatial_case(self, input_data, expected_value):
20 changes: 12 additions & 8 deletions tests/test_surface_dice.py
@@ -19,13 +19,15 @@

from monai.metrics.surface_dice import SurfaceDiceMetric

_device = "cuda:0" if torch.cuda.is_available() else "cpu"


class TestAllSurfaceDiceMetrics(unittest.TestCase):
def test_tolerance_euclidean_distance(self):
batch_size = 2
n_class = 2
- predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64)
- labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64)
+ predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device)
+ labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device)
predictions[0, :, 50:] = 1
labels[0, :, 60:] = 1 # 10 px shift
predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2)
@@ -38,8 +40,10 @@ def test_tolerance_euclidean_distance(self):
res0_nans = sd0_nans(predictions_hot, labels_hot)
agg0_nans, not_nans = sd0_nans.aggregate()

- np.testing.assert_array_equal(res0, res0_nans)
- np.testing.assert_array_equal(agg0, agg0_nans)
+ np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu())
+ np.testing.assert_equal(res0.device, predictions.device)
+ np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu())
+ np.testing.assert_equal(agg0.device, predictions.device)

res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot)
res9 = SurfaceDiceMetric(class_thresholds=[9, 9], include_background=True)(predictions_hot, labels_hot)
@@ -51,17 +55,17 @@

assert res0[0, 0] < res1[0, 0] < res9[0, 0] < res10[0, 0]
assert res0[0, 1] < res1[0, 1] < res9[0, 1] < res10[0, 1]
- np.testing.assert_array_equal(res10, res11)
+ np.testing.assert_array_equal(res10.cpu(), res11.cpu())

expected_res0 = np.zeros((batch_size, n_class))
expected_res0[0, 1] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 588 * 2 + 578 * 2)
expected_res0[0, 0] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 48 * 2 + 58 * 2)
expected_res0[1, 0] = 1
expected_res0[1, 1] = np.nan
for b, c in np.ndindex(batch_size, n_class):
- np.testing.assert_allclose(expected_res0[b, c], res0[b, c])
- np.testing.assert_array_equal(agg0, np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))
- np.testing.assert_equal(not_nans, torch.tensor(2))
+ np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu())
+ np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))
+ np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))

def test_tolerance_all_distances(self):
batch_size = 1
9 changes: 6 additions & 3 deletions tests/test_surface_distance.py
@@ -19,6 +19,8 @@

from monai.metrics import SurfaceDistanceMetric

_device = "cuda:0" if torch.cuda.is_available() else "cpu"


def create_spherical_seg_3d(
radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99)
@@ -111,8 +113,8 @@ def test_value(self, input_data, expected_value):
[seg_1, seg_2] = input_data
metric = "euclidean"
ct = 0
- seg_1 = torch.tensor(seg_1)
- seg_2 = torch.tensor(seg_2)
+ seg_1 = torch.tensor(seg_1, device=_device)
+ seg_2 = torch.tensor(seg_2, device=_device)
for symmetric in [True, False]:
sur_metric = SurfaceDistanceMetric(include_background=False, symmetric=symmetric, distance_metric=metric)
# shape of seg_1, seg_2 are: HWD, converts to BNHWD
@@ -122,7 +124,8 @@
sur_metric(batch_seg_1, batch_seg_2)
result = sur_metric.aggregate()
expected_value_curr = expected_value[ct]
- np.testing.assert_allclose(expected_value_curr, result, rtol=1e-5)
+ np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-5)
+ np.testing.assert_equal(result.device, seg_1.device)
ct += 1

@parameterized.expand(TEST_CASES_NANS)
