Skip to content

Commit 85243f5

Browse files
authored
Support return dice for each class in DiceMetric (#7163)
Fixes #7162 Fixes #7164 ### Description Add `return_with_label`: if set to `True` or to a list of label names, the metrics will be returned as a dictionary keyed by the corresponding label names; this only works when reduction="mean_batch". https://github.com/pytorch/ignite/blob/47b95d087a0f8713a9d24bcfe3a539b08101ba7a/ignite/metrics/metric.py#L424 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent cc20c9b commit 85243f5

File tree

4 files changed

+100
-3
lines changed

4 files changed

+100
-3
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def generate_apidocs(*args):
126126
{"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"},
127127
],
128128
"collapse_navigation": True,
129+
"navigation_with_keys": True,
129130
"navigation_depth": 1,
130131
"show_toc_level": 1,
131132
"footer_start": ["copyright"],

monai/handlers/mean_dice.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
num_classes: int | None = None,
3131
output_transform: Callable = lambda x: x,
3232
save_details: bool = True,
33+
return_with_label: bool | list[str] = False,
3334
) -> None:
3435
"""
3536
@@ -50,9 +51,18 @@ def __init__(
5051
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
5152
save_details: whether to save metric computation details per image, for example: mean dice of every image.
5253
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
54+
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
55+
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
56+
the index begins at "0", otherwise at "1". It can also take a list of label names.
57+
The outcome will then be returned as a dictionary.
5358
5459
See also:
5560
:py:meth:`monai.metrics.meandice.compute_dice`
5661
"""
57-
metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes)
62+
metric_fn = DiceMetric(
63+
include_background=include_background,
64+
reduction=reduction,
65+
num_classes=num_classes,
66+
return_with_label=return_with_label,
67+
)
5868
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)

monai/metrics/meandice.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class DiceMetric(CumulativeIterationMetric):
5050
num_classes: number of input channels (always including the background). When this is None,
5151
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
5252
single-channel class indices and the number of classes is not automatically inferred from data.
53+
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
54+
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
55+
the index begins at "0", otherwise at "1". It can also take a list of label names.
56+
The outcome will then be returned as a dictionary.
5357
5458
"""
5559

@@ -60,13 +64,15 @@ def __init__(
6064
get_not_nans: bool = False,
6165
ignore_empty: bool = True,
6266
num_classes: int | None = None,
67+
return_with_label: bool | list[str] = False,
6368
) -> None:
6469
super().__init__()
6570
self.include_background = include_background
6671
self.reduction = reduction
6772
self.get_not_nans = get_not_nans
6873
self.ignore_empty = ignore_empty
6974
self.num_classes = num_classes
75+
self.return_with_label = return_with_label
7076
self.dice_helper = DiceHelper(
7177
include_background=self.include_background,
7278
reduction=MetricReduction.NONE,
@@ -112,6 +118,16 @@ def aggregate(
112118

113119
# do metric reduction
114120
f, not_nans = do_metric_reduction(data, reduction or self.reduction)
121+
if self.reduction == MetricReduction.MEAN_BATCH and self.return_with_label:
122+
_f = {}
123+
if isinstance(self.return_with_label, bool):
124+
for i, v in enumerate(f):
125+
_label_key = f"label_{i+1}" if not self.include_background else f"label_{i}"
126+
_f[_label_key] = round(v.item(), 4)
127+
else:
128+
for key, v in zip(self.return_with_label, f):
129+
_f[key] = round(v.item(), 4)
130+
f = _f
115131
return (f, not_nans) if self.get_not_nans else f
116132

117133

tests/test_compute_meandice.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,71 @@
185185
[[0.0000, 0.0000], [0.0000, 0.0000]],
186186
]
187187

188+
# test return_with_label
189+
TEST_CASE_13 = [
190+
{
191+
"include_background": True,
192+
"reduction": "mean_batch",
193+
"get_not_nans": True,
194+
"return_with_label": ["bg", "fg0", "fg1"],
195+
},
196+
{
197+
"y_pred": torch.tensor(
198+
[
199+
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
200+
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
201+
]
202+
),
203+
"y": torch.tensor(
204+
[
205+
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
206+
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
207+
]
208+
),
209+
},
210+
{"bg": 0.6786, "fg0": 0.4000, "fg1": 0.6667},
211+
]
212+
213+
# test return_with_label, include_background
214+
TEST_CASE_14 = [
215+
{"include_background": True, "reduction": "mean_batch", "get_not_nans": True, "return_with_label": True},
216+
{
217+
"y_pred": torch.tensor(
218+
[
219+
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
220+
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
221+
]
222+
),
223+
"y": torch.tensor(
224+
[
225+
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
226+
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
227+
]
228+
),
229+
},
230+
{"label_0": 0.6786, "label_1": 0.4000, "label_2": 0.6667},
231+
]
232+
233+
# test return_with_label, not include_background
234+
TEST_CASE_15 = [
235+
{"include_background": False, "reduction": "mean_batch", "get_not_nans": True, "return_with_label": True},
236+
{
237+
"y_pred": torch.tensor(
238+
[
239+
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
240+
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
241+
]
242+
),
243+
"y": torch.tensor(
244+
[
245+
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
246+
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
247+
]
248+
),
249+
},
250+
{"label_1": 0.4000, "label_2": 0.6667},
251+
]
252+
188253

189254
class TestComputeMeanDice(unittest.TestCase):
190255
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
@@ -223,12 +288,17 @@ def test_value_class(self, input_data, expected_value):
223288
result = dice_metric.aggregate(reduction="none")
224289
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
225290

226-
@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8])
291+
@parameterized.expand(
292+
[TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]
293+
)
227294
def test_nans_class(self, params, input_data, expected_value):
228295
dice_metric = DiceMetric(**params)
229296
dice_metric(**input_data)
230297
result, _ = dice_metric.aggregate()
231-
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
298+
if isinstance(result, dict):
299+
self.assertEqual(result, expected_value)
300+
else:
301+
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
232302

233303

234304
if __name__ == "__main__":

0 commit comments

Comments
 (0)