Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6409 support class indices y_pred DiceHelper #6412

Merged
merged 3 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
self,
include_background: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
num_classes: int | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
) -> None:
Expand All @@ -38,6 +39,9 @@ def __init__(
reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
Expand All @@ -50,5 +54,5 @@ def __init__(
See also:
:py:meth:`monai.metrics.meandice.compute_dice`
"""
metric_fn = DiceMetric(include_background=include_background, reduction=reduction)
metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
44 changes: 34 additions & 10 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class DiceMetric(CumulativeIterationMetric):
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.

"""

Expand All @@ -56,18 +59,21 @@ def __init__(
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
ignore_empty: bool = True,
num_classes: int | None = None,
) -> None:
super().__init__()
self.include_background = include_background
self.reduction = reduction
self.get_not_nans = get_not_nans
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
softmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -110,20 +116,26 @@ def aggregate(


def compute_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
) -> torch.Tensor:
"""Computes Dice score metric for a batch of predictions.

Args:
y_pred: input data to compute, typical segmentation model output.
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
should be binarized.
`y_pred` can be single-channel class indices or in the one-hot format.
y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
include_background: whether to skip Dice computation on the first channel of
the predicted output. Defaults to True.
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -135,13 +147,14 @@ def compute_dice(
get_not_nans=False,
softmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
)(y_pred=y_pred, y=y)


class DiceHelper:
"""
Compute Dice score between two tensors `y_pred` and `y`.
`y_pred` must have N channels, `y` can be single-channel class indices or in the one-hot format.
`y_pred` and `y` can be single-channel class indices or in the one-hot format.

Example:

Expand Down Expand Up @@ -170,6 +183,7 @@ def __init__(
get_not_nans: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
ignore_empty: bool = True,
num_classes: int | None = None,
) -> None:
"""

Expand All @@ -186,6 +200,9 @@ def __init__(
reduction: define mode of reduction to the metrics
ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
"""
self.sigmoid = sigmoid
self.reduction = reduction
Expand All @@ -194,6 +211,7 @@ def __init__(
self.softmax = not sigmoid if softmax is None else softmax
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
""""""
Expand All @@ -211,17 +229,23 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
"""

Args:
y_pred: input predictions with shape (batch_size, num_classes, spatial_dims...).
the number of channels is inferred from ``y_pred.shape[1]``.
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
"""
n_pred_ch = y_pred.shape[1]

if self.softmax:
_softmax, _sigmoid = self.softmax, self.sigmoid
if self.num_classes is None:
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
else:
n_pred_ch = self.num_classes
if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices
_softmax = _sigmoid = False

if _softmax:
if n_pred_ch > 1:
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)

elif self.sigmoid:
elif _sigmoid:
if self.activate:
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5
Expand Down
8 changes: 6 additions & 2 deletions tests/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,16 @@ def test_helper(self, input_data, _unused):
result = DiceHelper(softmax=True, get_not_nans=False)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)

num_classes = vals["y_pred"].shape[1]
vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)

# DiceMetric class tests
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10])
def test_value_class(self, input_data, expected_value):
# same test as for compute_dice
vals = {}
vals["y_pred"] = input_data.pop("y_pred")
vals = {"y_pred": input_data.pop("y_pred")}
vals["y"] = input_data.pop("y")
dice_metric = DiceMetric(**input_data)
dice_metric(**vals)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_handler_mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class TestHandlerMeanDice(unittest.TestCase):
def test_compute(self, input_params, expected_avg, details_shape):
dice_metric = MeanDice(**input_params)

# set up engine

def _val_func(engine, batch):
pass

Expand Down Expand Up @@ -71,6 +69,30 @@ def test_shape_mismatch(self, input_params, _expected_avg, _details_shape):
y = torch.ones((3, 2))
dice_metric.update([y_pred, y])

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_compute_n_class(self, input_params, expected_avg, details_shape):
dice_metric = MeanDice(num_classes=2, **input_params)

def _val_func(engine, batch):
pass

engine = Engine(_val_func)
dice_metric.attach(engine=engine, name="mean_dice")
# test input a list of channel-first tensor
y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])]
y = torch.Tensor([[[0], [1]], [[0], [1]]])
engine.state.output = {"pred": y_pred, "label": y}
engine.fire_event(Events.ITERATION_COMPLETED)

y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred
y = torch.Tensor([[[1]], [[0]]]) # class indices y
engine.state.output = {"pred": y_pred, "label": y}
engine.fire_event(Events.ITERATION_COMPLETED)

engine.fire_event(Events.EPOCH_COMPLETED)
assert_allclose(engine.state.metrics["mean_dice"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape)


if __name__ == "__main__":
unittest.main()