Skip to content

Commit a9c70ce

Browse files
authored
Merge branch 'dev' into minor-fixes
2 parents 0ad9279 + 1a018a7 commit a9c70ce

File tree

7 files changed

+292
-4
lines changed

7 files changed

+292
-4
lines changed

docs/source/handlers.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ Peak signal to noise ratio metrics handler
101101
:members:
102102

103103

104+
Metrics reloaded binary handler
105+
-------------------------------
106+
.. autoclass:: MetricsReloadedBinaryHandler
107+
:members:
108+
109+
110+
Metrics reloaded categorical handler
111+
------------------------------------
112+
.. autoclass:: MetricsReloadedCategoricalHandler
113+
:members:
114+
115+
104116
Metric logger
105117
-------------
106118
.. autoclass:: MetricLogger

docs/source/metrics.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ Metrics
145145
.. autoclass:: CumulativeAverage
146146
:members:
147147

148+
`Metrics reloaded binary`
149+
-------------------------
150+
.. autoclass:: MetricsReloadedBinary
151+
:members:
152+
153+
`Metrics reloaded categorical`
154+
------------------------------
155+
.. autoclass:: MetricsReloadedCategorical
156+
:members:
157+
148158
Utilities
149159
---------
150160
.. automodule:: monai.metrics.utils

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .mean_dice import MeanDice
2626
from .mean_iou import MeanIoUHandler
2727
from .metric_logger import MetricLogger, MetricLoggerKeys
28+
from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler
2829
from .metrics_saver import MetricsSaver
2930
from .mlflow_handler import MLFlowHandler
3031
from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from collections.abc import Callable
15+
16+
from monai.handlers.ignite_metric import IgniteMetric
17+
from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical
18+
from monai.utils.enums import MetricReduction
19+
20+
21+
class MetricsReloadedBinaryHandler(IgniteMetric):
22+
"""
23+
Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded.
24+
"""
25+
26+
def __init__(
27+
self,
28+
metric_name: str,
29+
include_background: bool = True,
30+
reduction: MetricReduction | str = MetricReduction.MEAN,
31+
get_not_nans: bool = False,
32+
output_transform: Callable = lambda x: x,
33+
save_details: bool = True,
34+
) -> None:
35+
"""
36+
37+
Args:
38+
metric_name: Name of a binary metric from the MetricsReloaded package.
39+
include_background: whether to skip computation on the first channel of
40+
the predicted output. Defaults to ``True``.
41+
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
42+
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
43+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
44+
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
45+
Here `not_nans` count the number of not nans for the metric,
46+
thus its shape equals to the shape of the metric.
47+
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
48+
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
49+
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
50+
`engine.state` and `output_transform` inherit from the ignite concept:
51+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
52+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
53+
save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image.
54+
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
55+
56+
See also:
57+
:py:meth:`monai.metrics.wrapper`
58+
"""
59+
metric_fn = MetricsReloadedBinary(
60+
metric_name=metric_name,
61+
include_background=include_background,
62+
reduction=reduction,
63+
get_not_nans=get_not_nans,
64+
)
65+
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
66+
67+
68+
class MetricsReloadedCategoricalHandler(IgniteMetric):
69+
"""
70+
Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded.
71+
"""
72+
73+
def __init__(
74+
self,
75+
metric_name: str,
76+
include_background: bool = True,
77+
reduction: MetricReduction | str = MetricReduction.MEAN,
78+
get_not_nans: bool = False,
79+
smooth_dr: float = 1e-5,
80+
output_transform: Callable = lambda x: x,
81+
save_details: bool = True,
82+
) -> None:
83+
"""
84+
85+
Args:
86+
metric_name: Name of a categorical metric from the MetricsReloaded package.
87+
include_background: whether to skip computation on the first channel of
88+
the predicted output. Defaults to ``True``.
89+
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
90+
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
91+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
92+
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
93+
Here `not_nans` count the number of not nans for the metric,
94+
thus its shape equals to the shape of the metric.
95+
smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero.
96+
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
97+
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
98+
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
99+
`engine.state` and `output_transform` inherit from the ignite concept:
100+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
101+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
102+
save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image.
103+
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
104+
105+
See also:
106+
:py:meth:`monai.metrics.wrapper`
107+
"""
108+
metric_fn = MetricsReloadedCategorical(
109+
metric_name=metric_name,
110+
include_background=include_background,
111+
reduction=reduction,
112+
get_not_nans=get_not_nans,
113+
smooth_dr=smooth_dr,
114+
)
115+
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)

monai/metrics/wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class MetricsReloadedWrapper(CumulativeIterationMetric):
3030
"""Base class for defining MetricsReloaded metrics as a CumulativeIterationMetric.
3131
3232
Args:
33-
metric_name: Name of a binary metric from the MetricsReloaded package.
34-
include_background: whether to skip Dice computation on the first channel of
33+
metric_name: Name of a metric from the MetricsReloaded package.
34+
include_background: whether to skip computation on the first channel of
3535
the predicted output. Defaults to ``True``.
3636
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
3737
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
@@ -80,7 +80,7 @@ class MetricsReloadedBinary(MetricsReloadedWrapper):
8080
8181
Args:
8282
metric_name: Name of a binary metric from the MetricsReloaded package.
83-
include_background: whether to skip Dice computation on the first channel of
83+
include_background: whether to skip computation on the first channel of
8484
the predicted output. Defaults to ``True``.
8585
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
8686
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
@@ -185,7 +185,7 @@ class MetricsReloadedCategorical(MetricsReloadedWrapper):
185185
186186
Args:
187187
metric_name: Name of a categorical metric from the MetricsReloaded package.
188-
include_background: whether to skip Dice computation on the first channel of
188+
include_background: whether to skip computation on the first channel of
189189
the predicted output. Defaults to ``True``.
190190
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
191191
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def run_testsuit():
6767
"test_global_mutual_information_loss",
6868
"test_grid_patch",
6969
"test_gmm",
70+
"test_handler_metrics_reloaded",
7071
"test_handler_checkpoint_loader",
7172
"test_handler_checkpoint_saver",
7273
"test_handler_classification_saver",
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
from ignite.engine import Engine, Events
18+
from parameterized import parameterized
19+
20+
from monai.handlers import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler, from_engine
21+
from monai.utils import optional_import
22+
from tests.utils import assert_allclose
23+
24+
_, has_metrics = optional_import("MetricsReloaded")
25+
26+
TEST_CASE_BIN_1 = [
27+
{"metric_name": "Volume Difference"},
28+
[torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])],
29+
[torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])],
30+
0.3333,
31+
]
32+
33+
TEST_CASE_BIN_2 = [
34+
{"metric_name": "Boundary IoU"},
35+
[torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])],
36+
[torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])],
37+
0.6667,
38+
]
39+
40+
TEST_CASE_BIN_3 = [
41+
{"metric_name": "xTh Percentile Hausdorff Distance"},
42+
[torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])],
43+
[torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])],
44+
0.9,
45+
]
46+
47+
TEST_CASE_CAT_1 = [
48+
{"metric_name": "Weighted Cohens Kappa"},
49+
[
50+
torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),
51+
torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),
52+
],
53+
[
54+
torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),
55+
torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),
56+
],
57+
0.272727,
58+
]
59+
60+
TEST_CASE_CAT_2 = [
61+
{"metric_name": "Matthews Correlation Coefficient"},
62+
[
63+
torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),
64+
torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),
65+
],
66+
[
67+
torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),
68+
torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),
69+
],
70+
0.387298,
71+
]
72+
73+
74+
@unittest.skipIf(not has_metrics, "MetricsReloaded not available.")
75+
class TestHandlerMetricsReloadedBinary(unittest.TestCase):
76+
@parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3])
77+
def test_compute(self, input_params, y_pred, y, expected_value):
78+
input_params["output_transform"] = from_engine(["pred", "label"])
79+
metric = MetricsReloadedBinaryHandler(**input_params)
80+
81+
# set up engine
82+
83+
def _val_func(engine, batch):
84+
pass
85+
86+
engine = Engine(_val_func)
87+
metric.attach(engine=engine, name=input_params["metric_name"])
88+
engine.state.output = {"pred": y_pred, "label": y}
89+
engine.fire_event(Events.ITERATION_COMPLETED)
90+
91+
engine.state.output = {"pred": y_pred, "label": y}
92+
engine.fire_event(Events.ITERATION_COMPLETED)
93+
94+
engine.fire_event(Events.EPOCH_COMPLETED)
95+
assert_allclose(
96+
engine.state.metrics[input_params["metric_name"]], expected_value, atol=1e-4, rtol=1e-4, type_test=False
97+
)
98+
99+
@parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3])
100+
def test_shape_mismatch(self, input_params, _y_pred, _y, _expected_value):
101+
input_params["output_transform"] = from_engine(["pred", "label"])
102+
metric = MetricsReloadedBinaryHandler(**input_params)
103+
with self.assertRaises((AssertionError, ValueError)):
104+
y_pred = torch.Tensor([[0, 1], [1, 0]])
105+
y = torch.ones((2, 3))
106+
metric.update([y_pred, y])
107+
108+
with self.assertRaises((AssertionError, ValueError)):
109+
y_pred = [torch.ones((2, 1, 1)), torch.ones((1, 1, 1))]
110+
y = [torch.ones((2, 1, 1)), torch.ones((1, 1, 1))]
111+
metric.update([y_pred, y])
112+
113+
114+
@unittest.skipIf(not has_metrics, "MetricsReloaded not available.")
115+
class TestMetricsReloadedCategorical(unittest.TestCase):
116+
@parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2])
117+
def test_compute(self, input_params, y_pred, y, expected_value):
118+
input_params["output_transform"] = from_engine(["pred", "label"])
119+
metric = MetricsReloadedCategoricalHandler(**input_params)
120+
121+
# set up engine
122+
123+
def _val_func(engine, batch):
124+
pass
125+
126+
engine = Engine(_val_func)
127+
metric.attach(engine=engine, name=input_params["metric_name"])
128+
engine.state.output = {"pred": y_pred, "label": y}
129+
engine.fire_event(Events.ITERATION_COMPLETED)
130+
131+
engine.state.output = {"pred": y_pred, "label": y}
132+
engine.fire_event(Events.ITERATION_COMPLETED)
133+
134+
engine.fire_event(Events.EPOCH_COMPLETED)
135+
assert_allclose(
136+
engine.state.metrics[input_params["metric_name"]], expected_value, atol=1e-4, rtol=1e-4, type_test=False
137+
)
138+
139+
@parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2])
140+
def test_shape_mismatch(self, input_params, y_pred, y, _expected_value):
141+
input_params["output_transform"] = from_engine(["pred", "label"])
142+
metric = MetricsReloadedCategoricalHandler(**input_params)
143+
with self.assertRaises((AssertionError, ValueError)):
144+
y_pred[0] = torch.zeros([3, 2, 1])
145+
metric.update([y_pred, y])
146+
147+
148+
if __name__ == "__main__":
149+
unittest.main()

0 commit comments

Comments
 (0)