Skip to content

Commit

Permalink
New segmentation metric: Hausdorff Distance (#2122)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Bas Krahmer <baskrahmer@gmail.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
  • Loading branch information
6 people authored Oct 14, 2024
1 parent 13f7b94 commit 1335c7b
Show file tree
Hide file tree
Showing 12 changed files with 538 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776))


- Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122))


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
21 changes: 21 additions & 0 deletions docs/source/segmentation/hausdorff_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Hausdorff Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
:tags: segmentation

.. include:: ../links.rst

##################
Hausdorff Distance
##################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.HausdorffDistance
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.hausdorff_distance
Empty file added requirements/integrate.txt
Empty file.
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance
from torchmetrics.functional.segmentation.mean_iou import mean_iou

__all__ = ["generalized_dice_score", "mean_iou"]
__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance"]
114 changes: 114 additions & 0 deletions src/torchmetrics/functional/segmentation/hausdorff_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Literal, Optional, Union

import torch
from torch import Tensor

from torchmetrics.functional.segmentation.utils import (
_ignore_background,
edge_surface_distance,
)
from torchmetrics.utilities.checks import _check_same_shape


def _hausdorff_distance_validate_args(
num_classes: int,
include_background: bool,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
directed: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> None:
"""Validate the arguments of `hausdorff_distance` function."""
if num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if distance_metric not in ["euclidean", "chessboard", "taxicab"]:
raise ValueError(
f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}."
)
if spacing is not None and not isinstance(spacing, (list, Tensor)):
raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.")
if not isinstance(directed, bool):
raise ValueError(f"Expected argument `directed` must be a boolean, but got {directed}.")
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def hausdorff_distance(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = False,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
directed: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Calculate `Hausdorff Distance`_ for semantic segmentation.
Args:
preds: predicted binarized segmentation map
target: target binarized segmentation map
num_classes: number of classes
include_background: whether to include background class in calculation
distance_metric: distance metric to calculate surface distance. Choose one of `"euclidean"`,
`"chessboard"` or `"taxicab"`
spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1
directed: whether to calculate directed or undirected Hausdorff distance
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
Returns:
Hausdorff Distance for each class and batch element
Example:
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import hausdorff_distance
>>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> hausdorff_distance(preds, target, num_classes=5)
tensor([[2.0000, 1.4142, 2.0000, 2.0000],
[1.4142, 2.0000, 2.0000, 2.0000],
[2.0000, 2.0000, 1.4142, 2.0000],
[2.0000, 2.8284, 2.0000, 2.2361]])
"""
_hausdorff_distance_validate_args(num_classes, include_background, distance_metric, spacing, directed, input_format)
_check_same_shape(preds, target)

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
preds, target = _ignore_background(preds, target)

distances = torch.zeros(preds.shape[0], preds.shape[1], device=preds.device)

# TODO: add support for batched inputs
for b in range(preds.shape[0]):
for c in range(preds.shape[1]):
dist = edge_surface_distance(
preds=preds[b, c],
target=target[b, c],
distance_metric=distance_metric,
spacing=spacing,
symmetric=not directed,
)
distances[b, c] = torch.max(dist) if directed else torch.max(dist[0].max(), dist[1].max()) # type: ignore
return distances
57 changes: 44 additions & 13 deletions src/torchmetrics/functional/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:


def check_if_binarized(x: Tensor) -> None:
"""Check if the input is binarized.
"""Check if tensor is binarized.
Example:
>>> from torchmetrics.functional.segmentation.utils import check_if_binarized
Expand Down Expand Up @@ -200,9 +200,8 @@ def distance_transform(
Args:
x: The binary tensor to calculate the distance transform of.
sampling: Only relevant when distance is calculated using the euclidean distance. The sampling refers to the
pixel spacing in the image, i.e. the distance between two adjacent pixels. If not provided, the pixel
spacing is assumed to be 1.
sampling: The sampling refers to the pixel spacing in the image, i.e. the distance between two adjacent pixels.
If not provided, the pixel spacing is assumed to be 1.
metric: The distance to use for the distance transform. Can be one of ``"euclidean"``, ``"chessboard"``
or ``"taxicab"``.
engine: The engine to use for the distance transform. Can be one of ``["pytorch", "scipy"]``. In general,
Expand Down Expand Up @@ -249,25 +248,25 @@ def distance_transform(
raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.")

if engine == "pytorch":
x = x.float()
# calculate distance from every foreground pixel to every background pixel
i0, j0 = torch.where(x == 0)
i1, j1 = torch.where(x == 1)
dis_row = (i1.unsqueeze(1) - i0.unsqueeze(0)).abs_().mul_(sampling[0])
dis_col = (j1.unsqueeze(1) - j0.unsqueeze(0)).abs_().mul_(sampling[1])
dis_row = (i1.view(-1, 1) - i0.view(1, -1)).abs()
dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs()

# # calculate distance
h, _ = x.shape
if metric == "euclidean":
dis_row = dis_row.float()
dis_row.pow_(2).add_(dis_col.pow_(2)).sqrt_()
dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt()
if metric == "chessboard":
dis_row = dis_row.max(dis_col)
dis = torch.max(sampling[0] * dis_row, sampling[1] * dis_col).float()
if metric == "taxicab":
dis_row.add_(dis_col)
dis = (sampling[0] * dis_row + sampling[1] * dis_col).float()

# select only the closest distance
mindis, _ = torch.min(dis_row, dim=1)
z = torch.zeros_like(x, dtype=mindis.dtype).view(-1)
mindis, _ = torch.min(dis, dim=1)
z = torch.zeros_like(x).view(-1)
z[i1 * h + j1] = mindis
return z.view(x.shape)

Expand All @@ -279,7 +278,7 @@ def distance_transform(

if metric == "euclidean":
return ndimage.distance_transform_edt(x.cpu().numpy(), sampling)
return ndimage.distance_transform_cdt(x.cpu().numpy(), metric=metric)
return ndimage.distance_transform_cdt(x.cpu().numpy(), sampling, metric=metric)


def mask_edges(
Expand Down Expand Up @@ -390,6 +389,38 @@ def surface_distance(
return dis[preds]


def edge_surface_distance(
preds: Tensor,
target: Tensor,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
symmetric: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Extracts the edges from the input masks and calculates the surface distance between them.
Args:
preds: The predicted binary edge mask.
target: The target binary edge mask.
distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`.
spacing: The spacing between pixels along each spatial dimension.
symmetric: Whether to calculate the symmetric distance between the edges.
Returns:
A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the
distance from the corresponding edge in `preds` to the closest edge in `target`. If `symmetric` is `True`, the
function returns a tuple containing the distances from the predicted edges to the target edges and vice versa.
"""
output = mask_edges(preds, target)
edges_preds, edges_target = output[0].bool(), output[1].bool()
if symmetric:
return (
surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing),
surface_distance(edges_target, edges_preds, distance_metric=distance_metric, spacing=spacing),
)
return surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing)


@functools.lru_cache
def get_neighbour_tables(
spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance
from torchmetrics.segmentation.mean_iou import MeanIoU

__all__ = ["GeneralizedDiceScore", "MeanIoU"]
__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance"]
Loading

0 comments on commit 1335c7b

Please sign in to comment.