Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Newmetric: ProcrustesDisparity #2723

Merged
merged 27 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
35598bc
some docs
SkafteNicki Sep 3, 2024
955613e
some code
SkafteNicki Sep 3, 2024
7e058b8
docs
SkafteNicki Sep 4, 2024
d146ee0
nearly working code
SkafteNicki Sep 4, 2024
71bb5e4
improve tests
SkafteNicki Sep 7, 2024
f8d87da
fix src
SkafteNicki Sep 7, 2024
e4004f1
changelog
SkafteNicki Sep 7, 2024
8a27c6f
fix bare except
SkafteNicki Sep 7, 2024
65a9039
Merge branch 'master' into newmetric/procrustes
SkafteNicki Sep 7, 2024
31e313c
Merge branch 'master' into newmetric/procrustes
SkafteNicki Sep 9, 2024
edf19ad
fix doctest
SkafteNicki Sep 9, 2024
4c1235c
fix mypy
SkafteNicki Sep 9, 2024
8871678
Update src/torchmetrics/functional/shape/procrustes.py
SkafteNicki Sep 14, 2024
abd5ffb
Update src/torchmetrics/functional/shape/procrustes.py
SkafteNicki Sep 14, 2024
8059b5f
Update src/torchmetrics/functional/shape/procrustes.py
SkafteNicki Sep 14, 2024
5e771da
Update src/torchmetrics/shape/procrustes.py
SkafteNicki Sep 14, 2024
aa1f3d7
Merge branch 'master' into newmetric/procrustes
Borda Sep 16, 2024
ce3665f
Merge branch 'master' into newmetric/procrustes
Borda Sep 17, 2024
5f11a27
Merge branch 'master' into newmetric/procrustes
Borda Sep 23, 2024
fe201df
Merge branch 'master' into newmetric/procrustes
SkafteNicki Oct 9, 2024
3623efb
Merge branch 'master' into newmetric/procrustes
SkafteNicki Oct 10, 2024
8c16836
Apply suggestions from code review
Borda Oct 10, 2024
f110ffa
Merge branch 'master' into newmetric/procrustes
SkafteNicki Oct 10, 2024
1d5f20c
Merge branch 'master' into newmetric/procrustes
Borda Oct 10, 2024
9d6fee2
Merge branch 'master' into newmetric/procrustes
Borda Oct 10, 2024
86d4c3b
Merge branch 'master' into newmetric/procrustes
mergify[bot] Oct 11, 2024
5069f15
rename input variables
SkafteNicki Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605))


- Added new metric `ProcrustesDistance` to new domain Shape ([#2723](https://github.com/Lightning-AI/torchmetrics/pull/2723)


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
Expand Down
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ Or directly from conda

segmentation/*

.. toctree::
:maxdepth: 2
:name: shape
:caption: Shape
:glob:

shape/*

.. toctree::
:maxdepth: 2
:name: text
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,4 @@
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
2 changes: 1 addition & 1 deletion docs/source/segmentation/mean_iou.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.. customcarditem::
:header: Mean Intersection over Union (mIoU)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg
:tags: segmentation
:tags: Segmentation

###################################
Mean Intersection over Union (mIoU)
Expand Down
22 changes: 22 additions & 0 deletions docs/source/shape/procrustes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Procrustes Disparity
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: shape

.. include:: ../links.rst

####################
Procrustes Disparity
####################

Module Interface
________________

.. autoclass:: torchmetrics.shape.ProcrustesDisparity
:exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.shape.procrustes_disparity
16 changes: 16 additions & 0 deletions src/torchmetrics/functional/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.shape.procrustes import procrustes_disparity

__all__ = ["procrustes_disparity"]
64 changes: 64 additions & 0 deletions src/torchmetrics/functional/shape/procrustes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union

import torch
from torch import Tensor, linalg

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.prints import rank_zero_warn


def procrustes_disparity(
dataset1: Tensor, dataset2: Tensor, return_all: bool = False
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
"""Runs procrustrus analysis on a batch of data points.

Works similar ``scipy.spatial.procrustes`` but for batches of data points.

Args:
dataset1: The first set of data points
dataset2: The second set of data points
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return_all: If True, returns the scale and rotation matrices along with the disparity

"""
_check_same_shape(dataset1, dataset2)
if dataset1.ndim != 3:
raise ValueError(
"Expected both datasets to be 3D tensors of shape (N, M, D), where N is the batch size, M is the number of"
f" data points and D is the dimensionality of the data points, but got {dataset1.ndim} dimensions."
)

dataset1 = dataset1 - dataset1.mean(dim=1, keepdim=True)
dataset2 = dataset2 - dataset2.mean(dim=1, keepdim=True)
dataset1 /= linalg.norm(dataset1, dim=[1, 2], keepdim=True)
dataset2 /= linalg.norm(dataset2, dim=[1, 2], keepdim=True)

try:
u, w, v = linalg.svd(torch.matmul(dataset2.transpose(1, 2), dataset1).transpose(1, 2), full_matrices=False)
except Exception as ex:
rank_zero_warn(
f"SVD calculation in procrustes_disparity failed with exception {ex}. Returning 0 disparity and identity"
" scale/rotation.",
UserWarning,
)
return torch.tensor(0.0), torch.ones(dataset1.shape[0]), torch.eye(dataset1.shape[2])

rotation = torch.matmul(u, v)
scale = w.sum(1, keepdim=True)
dataset2 = scale[:, None] * torch.matmul(dataset2, rotation.transpose(1, 2))
disparity = (dataset1 - dataset2).square().sum(dim=[1, 2])
if return_all:
return disparity, scale, rotation
return disparity
16 changes: 16 additions & 0 deletions src/torchmetrics/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.shape.procrustes import ProcrustesDisparity

__all__ = ["ProcrustesDisparity"]
137 changes: 137 additions & 0 deletions src/torchmetrics/shape/procrustes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics import Metric
from torchmetrics.functional.shape.procrustes import procrustes_disparity
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ProcrustesDisparity.plot"]


class ProcrustesDisparity(Metric):
r"""Compute the `Procrustes Disparity`_.

The Procrustes Disparity is defined as the sum of the squared differences between two datasets after
applying a Procrustes transformation. The Procrustes Disparity is useful to compare two datasets
that are similar but not aligned.

The metric works similar to ``scipy.spatial.procrustes`` but for batches of data points. The disparity is
aggregated over the batch, thus to get the individual disparities please use the functional version of this
metric: ``torchmetrics.functional.shape.procrustes.procrustes_disparity``.
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``dataset1`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size,
``M`` the number of data points and ``D`` the dimensionality of the data points.
- ``dataset2`` (torch.Tensor): A tensor of shape ``(N, M, D)`` with ``N`` being the batch size,
``M`` the number of data points and ``D`` the dimensionality of the data points.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``gds`` (:class:`~torch.Tensor`): A scalar tensor with the Procrustes Disparity.

Args:
reduction: Determines whether to return the mean disparity or the sum of the disparities.
Can be one of ``"mean"`` or ``"sum"``.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError: If ``average`` is not one of ``"mean"`` or ``"sum"``.

Example:
>>> from torch import randn
>>> from torchmetrics.shape import ProcrustesDisparity
>>> metric = ProcrustesDisparity()
>>> dataset1 = randn(10, 50, 2)
>>> dataset2 = randn(10, 50, 2)
>>> metric(dataset1, dataset2)
tensor(0.9770)

"""

disparity: Tensor
total: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(self, reduction: Literal["mean", "sum"] = "mean", **kwargs: Any) -> None:
super().__init__(**kwargs)
if reduction not in ("mean", "sum"):
raise ValueError(f"Argument `reduction` must be one of ['mean', 'sum'], got {reduction}")
self.reduction = reduction
self.add_state("disparity", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, dataset1: torch.Tensor, dataset2: torch.Tensor) -> None:
"""Update the Procrustes Disparity with the given datasets."""
disparity: Tensor = procrustes_disparity(dataset1, dataset2) # type: ignore[assignment]
self.disparity += disparity.sum()
self.total += disparity.numel()

def compute(self) -> torch.Tensor:
"""Computes the Procrustes Disparity."""
if self.reduction == "mean":
return self.disparity / self.total
return self.disparity

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.shape import ProcrustesDisparity
>>> metric = ProcrustesDisparity()
>>> metric.update(torch.randn(10, 50, 2), torch.randn(10, 50, 2))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.shape import ProcrustesDisparity
>>> metric = ProcrustesDisparity()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randn(10, 50, 2), torch.randn(10, 50, 2)))
>>> fig_, ax_ = metric.plot(values)

"""
return self._plot(val, ax)
13 changes: 13 additions & 0 deletions tests/unittests/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading