diff --git a/CHANGELOG.md b/CHANGELOG.md index 39f70bf82b6..584b59be6f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) +- Added new metric `ProcrustesDistance` to new domain Shape ([#2723](https://github.com/Lightning-AI/torchmetrics/pull/2723) + + ### Changed - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) diff --git a/docs/source/index.rst b/docs/source/index.rst index 58dac2c6fb1..46670de00e4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -231,6 +231,14 @@ Or directly from conda segmentation/* +.. toctree:: + :maxdepth: 2 + :name: shape + :caption: Shape + :glob: + + shape/* + .. toctree:: :maxdepth: 2 :name: text diff --git a/docs/source/links.rst b/docs/source/links.rst index 14597ae3f37..035cbbec8b7 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -172,3 +172,4 @@ .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 .. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 +.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index 7fddd9f316d..9e5544db349 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Mean Intersection over Union (mIoU) :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg - :tags: segmentation + :tags: Segmentation ################################### Mean Intersection over Union (mIoU) diff --git a/docs/source/shape/procrustes.rst b/docs/source/shape/procrustes.rst new file mode 100644 index 00000000000..e69357c6473 --- /dev/null +++ b/docs/source/shape/procrustes.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Procrustes Disparity + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: shape + +.. include:: ../links.rst + +#################### +Procrustes Disparity +#################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.shape.ProcrustesDisparity + :exclude-members: update, compute + + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.shape.procrustes_disparity diff --git a/src/torchmetrics/functional/shape/__init__.py b/src/torchmetrics/functional/shape/__init__.py new file mode 100644 index 00000000000..7cf4118b053 --- /dev/null +++ b/src/torchmetrics/functional/shape/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.functional.shape.procrustes import procrustes_disparity + +__all__ = ["procrustes_disparity"] diff --git a/src/torchmetrics/functional/shape/procrustes.py b/src/torchmetrics/functional/shape/procrustes.py new file mode 100644 index 00000000000..08068fd2454 --- /dev/null +++ b/src/torchmetrics/functional/shape/procrustes.py @@ -0,0 +1,66 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from torch import Tensor, linalg + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.prints import rank_zero_warn + + +def procrustes_disparity( + point_cloud1: Tensor, point_cloud2: Tensor, return_all: bool = False +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: + """Runs procrustrus analysis on a batch of data points. + + Works similar ``scipy.spatial.procrustes`` but for batches of data points. + + Args: + point_cloud1: The first set of data points + point_cloud2: The second set of data points + return_all: If True, returns the scale and rotation matrices along with the disparity + + """ + _check_same_shape(point_cloud1, point_cloud2) + if point_cloud1.ndim != 3: + raise ValueError( + "Expected both datasets to be 3D tensors of shape (N, M, D), where N is the batch size, M is the number of" + f" data points and D is the dimensionality of the data points, but got {point_cloud1.ndim} dimensions." + ) + + point_cloud1 = point_cloud1 - point_cloud1.mean(dim=1, keepdim=True) + point_cloud2 = point_cloud2 - point_cloud2.mean(dim=1, keepdim=True) + point_cloud1 /= linalg.norm(point_cloud1, dim=[1, 2], keepdim=True) + point_cloud2 /= linalg.norm(point_cloud2, dim=[1, 2], keepdim=True) + + try: + u, w, v = linalg.svd( + torch.matmul(point_cloud2.transpose(1, 2), point_cloud1).transpose(1, 2), full_matrices=False + ) + except Exception as ex: + rank_zero_warn( + f"SVD calculation in procrustes_disparity failed with exception {ex}. Returning 0 disparity and identity" + " scale/rotation.", + UserWarning, + ) + return torch.tensor(0.0), torch.ones(point_cloud1.shape[0]), torch.eye(point_cloud1.shape[2]) + + rotation = torch.matmul(u, v) + scale = w.sum(1, keepdim=True) + point_cloud2 = scale[:, None] * torch.matmul(point_cloud2, rotation.transpose(1, 2)) + disparity = (point_cloud1 - point_cloud2).square().sum(dim=[1, 2]) + if return_all: + return disparity, scale, rotation + return disparity diff --git a/src/torchmetrics/shape/__init__.py b/src/torchmetrics/shape/__init__.py new file mode 100644 index 00000000000..263a1e395a2 --- /dev/null +++ b/src/torchmetrics/shape/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.shape.procrustes import ProcrustesDisparity + +__all__ = ["ProcrustesDisparity"] diff --git a/src/torchmetrics/shape/procrustes.py b/src/torchmetrics/shape/procrustes.py new file mode 100644 index 00000000000..a924fb48a4a --- /dev/null +++ b/src/torchmetrics/shape/procrustes.py @@ -0,0 +1,137 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics import Metric +from torchmetrics.functional.shape.procrustes import procrustes_disparity +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ProcrustesDisparity.plot"] + + +class ProcrustesDisparity(Metric): + r"""Compute the `Procrustes Disparity`_. + + The Procrustes Disparity is defined as the sum of the squared differences between two datasets after + applying a Procrustes transformation. The Procrustes Disparity is useful to compare two datasets + that are similar but not aligned. + + The metric works similar to ``scipy.spatial.procrustes`` but for batches of data points. The disparity is + aggregated over the batch, thus to get the individual disparities please use the functional version of this + metric: ``torchmetrics.functional.shape.procrustes.procrustes_disparity``. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``point_cloud1`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size, + ``M`` the number of data points and ``D`` the dimensionality of the data points. + - ``point_cloud2`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size, + ``M`` the number of data points and ``D`` the dimensionality of the data points. + + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``gds`` (:class:`~torch.Tensor`): A scalar tensor with the Procrustes Disparity. + + Args: + reduction: Determines whether to return the mean disparity or the sum of the disparities. + Can be one of ``"mean"`` or ``"sum"``. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: If ``average`` is not one of ``"mean"`` or ``"sum"``. + + Example: + >>> from torch import randn + >>> from torchmetrics.shape import ProcrustesDisparity + >>> metric = ProcrustesDisparity() + >>> point_cloud1 = randn(10, 50, 2) + >>> point_cloud2 = randn(10, 50, 2) + >>> metric(point_cloud1, point_cloud2) + tensor(0.9770) + + """ + + disparity: Tensor + total: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def __init__(self, reduction: Literal["mean", "sum"] = "mean", **kwargs: Any) -> None: + super().__init__(**kwargs) + if reduction not in ("mean", "sum"): + raise ValueError(f"Argument `reduction` must be one of ['mean', 'sum'], got {reduction}") + self.reduction = reduction + self.add_state("disparity", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, point_cloud1: torch.Tensor, point_cloud2: torch.Tensor) -> None: + """Update the Procrustes Disparity with the given datasets.""" + disparity: Tensor = procrustes_disparity(point_cloud1, point_cloud2) # type: ignore[assignment] + self.disparity += disparity.sum() + self.total += disparity.numel() + + def compute(self) -> torch.Tensor: + """Computes the Procrustes Disparity.""" + if self.reduction == "mean": + return self.disparity / self.total + return self.disparity + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.shape import ProcrustesDisparity + >>> metric = ProcrustesDisparity() + >>> metric.update(torch.randn(10, 50, 2), torch.randn(10, 50, 2)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.shape import ProcrustesDisparity + >>> metric = ProcrustesDisparity() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randn(10, 50, 2), torch.randn(10, 50, 2))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/tests/unittests/shape/__init__.py b/tests/unittests/shape/__init__.py new file mode 100644 index 00000000000..94f1dec4a9f --- /dev/null +++ b/tests/unittests/shape/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/shape/test_procrustes.py b/tests/unittests/shape/test_procrustes.py new file mode 100644 index 00000000000..a3b89e13eb7 --- /dev/null +++ b/tests/unittests/shape/test_procrustes.py @@ -0,0 +1,95 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.spatial import procrustes as scipy_procrustes +from torchmetrics.functional.shape.procrustes import procrustes_disparity +from torchmetrics.shape.procrustes import ProcrustesDisparity + +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester + +seed_all(42) + +NUM_TARGETS = 5 + + +_inputs = _Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 50, EXTRA_DIM), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 50, EXTRA_DIM), +) + + +def _reference_procrustes(point_cloud1, point_cloud2, reduction=None): + point_cloud1 = point_cloud1.numpy() + point_cloud2 = point_cloud2.numpy() + + if reduction is None: + return np.array([scipy_procrustes(d1, d2)[2] for d1, d2 in zip(point_cloud1, point_cloud2)]) + + disparity = 0 + for d1, d2 in zip(point_cloud1, point_cloud2): + disparity += scipy_procrustes(d1, d2)[2] + if reduction == "mean": + return disparity / len(point_cloud1) + return disparity + + +@pytest.mark.parametrize("point_cloud1, point_cloud2", [(_inputs.preds, _inputs.target)]) +class TestProcrustesDisparity(MetricTester): + """Test class for `ProcrustesDisparity` metric.""" + + @pytest.mark.parametrize("reduction", ["sum", "mean"]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_procrustes_disparity(self, reduction, point_cloud1, point_cloud2, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + point_cloud1, + point_cloud2, + ProcrustesDisparity, + partial(_reference_procrustes, reduction=reduction), + metric_args={"reduction": reduction}, + ) + + def test_procrustes_disparity_functional(self, point_cloud1, point_cloud2): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + point_cloud1, + point_cloud2, + procrustes_disparity, + _reference_procrustes, + ) + + +def test_error_on_different_shape(): + """Test that error is raised on different shapes of input.""" + metric = ProcrustesDisparity() + with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): + metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)) + with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): + procrustes_disparity(torch.randn(10, 100, 2), torch.randn(10, 50, 2)) + + +def test_error_on_non_3d_input(): + """Test that error is raised if input is not 3-dimensional.""" + metric = ProcrustesDisparity() + with pytest.raises(ValueError, match="Expected both datasets to be 3D tensors of shape"): + metric(torch.randn(100), torch.randn(100)) + with pytest.raises(ValueError, match="Expected both datasets to be 3D tensors of shape"): + procrustes_disparity(torch.randn(100), torch.randn(100)) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 465ed2d55e5..aacbf050dac 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -150,6 +150,7 @@ RetrievalRecallAtFixedPrecision, RetrievalRPrecision, ) +from torchmetrics.shape import ProcrustesDisparity from torchmetrics.text import ( BERTScore, BLEUScore, @@ -609,6 +610,12 @@ pytest.param(CalinskiHarabaszScore, lambda: torch.randn(100, 3), _nominal_input, id="calinski harabasz score"), pytest.param(NormalizedMutualInfoScore, _nominal_input, _nominal_input, id="normalized mutual info score"), pytest.param(DunnIndex, lambda: torch.randn(100, 3), _nominal_input, id="dunn index"), + pytest.param( + ProcrustesDisparity, + lambda: torch.randn(1, 100, 3), + lambda: torch.randn(1, 100, 3), + id="procrustes disparity", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 3])