Skip to content

Commit 6602303

Browse files
added 2D (normalized) surface dice metric (#4050)
* added 2D (normalized) surface dice metric Signed-off-by: Silvia Seidlitz <s.seidlitz@dkfz-heidelberg.de> * exclude from min tests Signed-off-by: Wenqi Li <wenqil@nvidia.com> * more detailled docstring Signed-off-by: Silvia Seidlitz <s.seidlitz@dkfz-heidelberg.de>
1 parent 17529e7 commit 6602303

File tree

5 files changed

+537
-0
lines changed

5 files changed

+537
-0
lines changed

docs/source/metrics.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ Metrics
6969
.. autoclass:: SurfaceDistanceMetric
7070
:members:
7171

72+
`Surface dice`
73+
--------------
74+
.. autofunction:: compute_surface_dice
75+
76+
.. autoclass:: SurfaceDiceMetric
77+
:members:
78+
7279
`Mean squared error`
7380
--------------------
7481
.. autoclass:: MSEMetric

monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@
1717
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
1818
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
1919
from .rocauc import ROCAUCMetric, compute_roc_auc
20+
from .surface_dice import SurfaceDiceMetric, compute_surface_dice
2021
from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance
2122
from .utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background

monai/metrics/surface_dice.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import warnings
13+
from typing import List, Union
14+
15+
import numpy as np
16+
import torch
17+
18+
from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background
19+
from monai.utils import MetricReduction, convert_data_type
20+
21+
from .metric import CumulativeIterationMetric
22+
23+
24+
class SurfaceDiceMetric(CumulativeIterationMetric):
25+
"""
26+
Computes the Normalized Surface Distance (NSD) for each batch sample and class of
27+
predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
28+
This implementation supports 2D images. For 3D images, please refer to DeepMind's implementation
29+
https://github.com/deepmind/surface-distance.
30+
31+
The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`.
32+
33+
Args:
34+
class_thresholds: List of class-specific thresholds.
35+
The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
36+
Each threshold needs to be a finite, non-negative number.
37+
include_background: Whether to skip NSD computation on the first channel of the predicted output.
38+
Defaults to ``False``.
39+
distance_metric: The metric used to compute surface distances.
40+
One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
41+
Defaults to ``"euclidean"``.
42+
reduction: The mode to aggregate metrics.
43+
One of [``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``,
44+
``"none"``].
45+
Defaults to ``"mean"``.
46+
If ``"none"`` is chosen, no aggregation will be performed.
47+
The aggregation will ignore nan values.
48+
get_not_nans: whether to return the `not_nans` count.
49+
Defaults to ``False``.
50+
`not_nans` is the number of batch samples for which not all class-specific NSD values were nan values.
51+
If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count.
52+
If set to ``False``, `aggregate` will only return the aggregated NSD.
53+
"""
54+
55+
def __init__(
56+
self,
57+
class_thresholds: List[float],
58+
include_background: bool = False,
59+
distance_metric: str = "euclidean",
60+
reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
61+
get_not_nans: bool = False,
62+
) -> None:
63+
super().__init__()
64+
self.class_thresholds = class_thresholds
65+
self.include_background = include_background
66+
self.distance_metric = distance_metric
67+
self.reduction = reduction
68+
self.get_not_nans = get_not_nans
69+
70+
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore
71+
r"""
72+
Args:
73+
y_pred: Predicted segmentation, typically segmentation model output.
74+
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
75+
y: Reference segmentation.
76+
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
77+
78+
Returns:
79+
Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch
80+
index :math:`b` and class :math:`c`.
81+
"""
82+
return compute_surface_dice(
83+
y_pred=y_pred,
84+
y=y,
85+
class_thresholds=self.class_thresholds,
86+
include_background=self.include_background,
87+
distance_metric=self.distance_metric,
88+
)
89+
90+
def aggregate(self):
91+
r"""
92+
Aggregates the output of `_compute_tensor`.
93+
94+
Returns:
95+
If `get_not_nans` is set to ``True``, this function returns the aggregated NSD and the `not_nans` count.
96+
If `get_not_nans` is set to ``False``, this function returns only the aggregated NSD.
97+
"""
98+
data = self.get_buffer()
99+
if not isinstance(data, torch.Tensor):
100+
raise ValueError("the data to aggregate must be PyTorch Tensor.")
101+
102+
# do metric reduction
103+
f, not_nans = do_metric_reduction(data, self.reduction)
104+
return (f, not_nans) if self.get_not_nans else f
105+
106+
107+
def compute_surface_dice(
108+
y_pred: torch.Tensor,
109+
y: torch.Tensor,
110+
class_thresholds: List[float],
111+
include_background: bool = False,
112+
distance_metric: str = "euclidean",
113+
):
114+
r"""
115+
This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
116+
:math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
117+
boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the
118+
reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in
119+
pixels. The NSD is bounded between 0 and 1.
120+
121+
This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
122+
The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:
123+
124+
.. math::
125+
\operatorname {NSD}_{b,c} \left(Y_{b,c}, \hat{Y}_{b,c}\right) = \frac{\left|\mathcal{D}_{Y_{b,c}}^{'}\right| +
126+
\left| \mathcal{D}_{\hat{Y}_{b,c}}^{'} \right|}{\left|\mathcal{D}_{Y_{b,c}}\right| +
127+
\left|\mathcal{D}_{\hat{Y}_{b,c}}\right|}
128+
:label: nsd
129+
130+
with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
131+
distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation
132+
boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
133+
:math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
134+
acceptable distance :math:`\tau_c`:
135+
136+
.. math::
137+
\mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.
138+
139+
140+
In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value
141+
will be returned for this class. In the case of a class being present in only one of predicted segmentation or
142+
reference segmentation, the class NSD will be 0.
143+
144+
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images.
145+
Be aware that the computation of boundaries is different from DeepMind's implementation
146+
https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
147+
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
148+
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
149+
150+
Args:
151+
y_pred: Predicted segmentation, typically segmentation model output.
152+
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
153+
y: Reference segmentation.
154+
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
155+
class_thresholds: List of class-specific thresholds.
156+
The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
157+
Each threshold needs to be a finite, non-negative number.
158+
include_background: Whether to skip the surface dice computation on the first channel of
159+
the predicted output. Defaults to ``False``.
160+
distance_metric: The metric used to compute surface distances.
161+
One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
162+
Defaults to ``"euclidean"``.
163+
164+
Raises:
165+
ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
166+
ValueError: If `y_pred` and/or `y` do not have four dimensions.
167+
ValueError: If `y_pred` and/or `y` have different shapes.
168+
ValueError: If `y_pred` and/or `y` are not one-hot encoded
169+
ValueError: If the number of channels of `y_pred` and/or `y` is different from the number of class thresholds.
170+
ValueError: If any class threshold is not finite.
171+
ValueError: If any class threshold is negative.
172+
173+
Returns:
174+
Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index
175+
:math:`b` and class :math:`c`.
176+
"""
177+
178+
if not include_background:
179+
y_pred, y = ignore_background(y_pred=y_pred, y=y)
180+
181+
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
182+
raise ValueError("y_pred and y must be PyTorch Tensor.")
183+
184+
if y_pred.ndimension() != 4 or y.ndimension() != 4:
185+
raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].")
186+
187+
if y_pred.shape != y.shape:
188+
raise ValueError(
189+
f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
190+
)
191+
192+
if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y):
193+
raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).")
194+
if torch.any(y_pred > 1) or torch.any(y > 1):
195+
raise ValueError("y_pred and y should be one-hot encoded.")
196+
197+
y = y.float()
198+
y_pred = y_pred.float()
199+
200+
batch_size, n_class = y_pred.shape[:2]
201+
202+
if n_class != len(class_thresholds):
203+
raise ValueError(
204+
f"number of classes ({n_class}) does not match number of class thresholds ({len(class_thresholds)})."
205+
)
206+
207+
if any(~np.isfinite(class_thresholds)):
208+
raise ValueError("All class thresholds need to be finite.")
209+
210+
if any(np.array(class_thresholds) < 0):
211+
raise ValueError("All class thresholds need to be >= 0.")
212+
213+
nsd = np.empty((batch_size, n_class))
214+
215+
for b, c in np.ndindex(batch_size, n_class):
216+
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False)
217+
if not np.any(edges_gt):
218+
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
219+
if not np.any(edges_pred):
220+
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
221+
222+
distances_pred_gt = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
223+
distances_gt_pred = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric)
224+
225+
boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
226+
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
227+
distances_gt_pred <= class_thresholds[c]
228+
)
229+
230+
if boundary_complete == 0:
231+
# the class is neither present in the prediction, nor in the reference segmentation
232+
nsd[b, c] = np.nan
233+
else:
234+
nsd[b, c] = boundary_correct / boundary_complete
235+
236+
return convert_data_type(nsd, torch.Tensor)[0]

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def run_testsuit():
145145
"test_spacingd",
146146
"test_splitdimd",
147147
"test_surface_distance",
148+
"test_surface_dice",
148149
"test_testtimeaugmentation",
149150
"test_torchvision",
150151
"test_torchvisiond",

0 commit comments

Comments
 (0)