|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +import warnings |
| 13 | +from typing import List, Union |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +import torch |
| 17 | + |
| 18 | +from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background |
| 19 | +from monai.utils import MetricReduction, convert_data_type |
| 20 | + |
| 21 | +from .metric import CumulativeIterationMetric |
| 22 | + |
| 23 | + |
class SurfaceDiceMetric(CumulativeIterationMetric):
    """
    Computes the Normalized Surface Distance (NSD) for each batch sample and class of
    predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
    This implementation supports 2D images. For 3D images, please refer to DeepMind's implementation
    https://github.com/deepmind/surface-distance.

    The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`.

    Args:
        class_thresholds: List of class-specific thresholds.
            The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
            Each threshold needs to be a finite, non-negative number.
            The number of thresholds must match the number of channels that are actually evaluated, i.e. the
            channels remaining after the background handling controlled by `include_background`.
        include_background: Whether to include the NSD computation on the first channel of the predicted output.
            If ``False``, the first channel is excluded from the computation.
            Defaults to ``False``.
        distance_metric: The metric used to compute surface distances.
            One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
            Defaults to ``"euclidean"``.
        reduction: The mode to aggregate metrics.
            One of [``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``,
            ``"none"``].
            Defaults to ``"mean"``.
            If ``"none"`` is chosen, no aggregation will be performed.
            The aggregation will ignore nan values.
        get_not_nans: whether to return the `not_nans` count.
            Defaults to ``False``.
            `not_nans` is the number of batch samples for which not all class-specific NSD values were nan values.
            If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count.
            If set to ``False``, `aggregate` will only return the aggregated NSD.
    """

    def __init__(
        self,
        class_thresholds: List[float],
        include_background: bool = False,
        distance_metric: str = "euclidean",
        reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
        get_not_nans: bool = False,
    ) -> None:
        super().__init__()
        # The settings are only stored here; they are validated by `compute_surface_dice` on every call.
        self.class_thresholds = class_thresholds
        self.include_background = include_background
        self.distance_metric = distance_metric
        self.reduction = reduction
        self.get_not_nans = get_not_nans

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor):  # type: ignore
        r"""
        Compute the per-sample, per-class NSD by delegating to `compute_surface_dice` with the stored settings.

        Args:
            y_pred: Predicted segmentation, typically segmentation model output.
                It must be a one-hot encoded, batch-first tensor [B,C,H,W].
            y: Reference segmentation.
                It must be a one-hot encoded, batch-first tensor [B,C,H,W].

        Returns:
            Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch
            index :math:`b` and class :math:`c`.
        """
        return compute_surface_dice(
            y_pred=y_pred,
            y=y,
            class_thresholds=self.class_thresholds,
            include_background=self.include_background,
            distance_metric=self.distance_metric,
        )

    def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
        r"""
        Aggregates the output of `_compute_tensor`.

        Args:
            reduction: Optional reduction mode used for this call only.
                Accepts the same values as the `reduction` constructor argument.
                Defaults to ``None``, in which case the mode given at construction time is used.

        Returns:
            If `get_not_nans` is set to ``True``, this function returns the aggregated NSD and the `not_nans` count.
            If `get_not_nans` is set to ``False``, this function returns only the aggregated NSD.

        Raises:
            ValueError: If the buffered data is not a PyTorch tensor.
        """
        data = self.get_buffer()
        if not isinstance(data, torch.Tensor):
            raise ValueError("the data to aggregate must be PyTorch Tensor.")

        # A per-call override takes precedence over the mode configured in `__init__`.
        mode = self.reduction if reduction is None else reduction
        f, not_nans = do_metric_reduction(data, mode)
        return (f, not_nans) if self.get_not_nans else f
| 105 | + |
| 106 | + |
def compute_surface_dice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    class_thresholds: List[float],
    include_background: bool = False,
    distance_metric: str = "euclidean",
):
    r"""
    This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
    :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
    boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the
    reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation
    in pixels. The NSD is bounded between 0 and 1.

    This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
    The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:

    .. math::
        \operatorname {NSD}_{b,c} \left(Y_{b,c}, \hat{Y}_{b,c}\right) = \frac{\left|\mathcal{D}_{Y_{b,c}}^{'}\right| +
        \left| \mathcal{D}_{\hat{Y}_{b,c}}^{'} \right|}{\left|\mathcal{D}_{Y_{b,c}}\right| +
        \left|\mathcal{D}_{\hat{Y}_{b,c}}\right|}
        :label: nsd

    with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
    distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference
    segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
    :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
    acceptable distance :math:`\tau_c`:

    .. math::
        \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.


    In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan
    value will be returned for this class. In the case of a class being present in only one of predicted segmentation or
    reference segmentation, the class NSD will be 0.

    This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images.
    Be aware that the computation of boundaries is different from DeepMind's implementation
    https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
    interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
    depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).

    Args:
        y_pred: Predicted segmentation, typically segmentation model output.
            It must be a one-hot encoded, batch-first tensor [B,C,H,W].
        y: Reference segmentation.
            It must be a one-hot encoded, batch-first tensor [B,C,H,W].
        class_thresholds: List of class-specific thresholds.
            The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
            Each threshold needs to be a finite, non-negative number.
            The number of thresholds must match the number of channels that are actually evaluated, i.e. the
            channels remaining after the background handling controlled by `include_background`.
        include_background: Whether to include the surface dice computation on the first channel of
            the predicted output. If ``False``, the first channel is excluded from the computation.
            Defaults to ``False``.
        distance_metric: The metric used to compute surface distances.
            One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
            Defaults to ``"euclidean"``.

    Raises:
        ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
        ValueError: If `y_pred` and/or `y` do not have four dimensions.
        ValueError: If `y_pred` and `y` have different shapes.
        ValueError: If `y_pred` and/or `y` are not one-hot encoded.
        ValueError: If the number of evaluated channels of `y_pred` and `y` is different from the number of class
            thresholds.
        ValueError: If any class threshold is not finite.
        ValueError: If any class threshold is negative.

    Returns:
        Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index
        :math:`b` and evaluated class :math:`c`.
    """

    # Structural validation comes first: the background handling below indexes into the channel dimension and
    # would otherwise fail with an unrelated error on non-tensor or low-rank inputs.
    if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
        raise ValueError("y_pred and y must be PyTorch Tensor.")

    if y_pred.ndimension() != 4 or y.ndimension() != 4:
        raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].")

    if y_pred.shape != y.shape:
        raise ValueError(
            f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
        )

    # Convert the thresholds once; the array is reused for validation and inside the loop.
    thresholds = np.asarray(class_thresholds)

    if not np.all(np.isfinite(thresholds)):
        raise ValueError("All class thresholds need to be finite.")

    if np.any(thresholds < 0):
        raise ValueError("All class thresholds need to be >= 0.")

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    # Value checks are applied to the evaluated channels only (i.e. after the background handling).
    if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y):
        raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).")
    if torch.any(y_pred > 1) or torch.any(y > 1):
        raise ValueError("y_pred and y should be one-hot encoded.")

    batch_size, n_class = y_pred.shape[:2]

    # One threshold is required per evaluated channel.
    if n_class != len(class_thresholds):
        raise ValueError(
            f"number of classes ({n_class}) does not match number of class thresholds ({len(class_thresholds)})."
        )

    y = y.float()
    y_pred = y_pred.float()

    nsd = np.empty((batch_size, n_class))

    for b, c in np.ndindex(batch_size, n_class):
        (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False)
        if not np.any(edges_gt):
            warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
        if not np.any(edges_pred):
            warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

        # Nearest-neighbor distances in both directions: prediction -> reference and reference -> prediction.
        distances_pred_gt = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
        distances_gt_pred = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric)

        threshold = thresholds[c]
        boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
        boundary_correct = np.sum(distances_pred_gt <= threshold) + np.sum(distances_gt_pred <= threshold)

        if boundary_complete == 0:
            # the class is neither present in the prediction, nor in the reference segmentation
            nsd[b, c] = np.nan
        else:
            nsd[b, c] = boundary_correct / boundary_complete

    return convert_data_type(nsd, torch.Tensor)[0]
0 commit comments