From c6547bbcc0c1f3b4f2ecb0182a26158f15cc4fc9 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Mon, 18 Sep 2023 16:45:52 +0900 Subject: [PATCH] New metric: Adjusted mutual info score (#2058) Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 + .../clustering/adjusted_mutual_info_score.rst | 21 +++ docs/source/links.rst | 1 + src/torchmetrics/clustering/__init__.py | 2 + .../clustering/adjusted_mutual_info_score.py | 127 ++++++++++++++++++ .../functional/clustering/__init__.py | 2 + .../clustering/adjusted_mutual_info_score.py | 121 +++++++++++++++++ .../normalized_mutual_info_score.py | 17 +-- .../functional/clustering/utils.py | 26 +++- .../test_adjusted_mutual_info_score.py | 101 ++++++++++++++ 10 files changed, 408 insertions(+), 12 deletions(-) create mode 100644 docs/source/clustering/adjusted_mutual_info_score.rst create mode 100644 src/torchmetrics/clustering/adjusted_mutual_info_score.py create mode 100644 src/torchmetrics/functional/clustering/adjusted_mutual_info_score.py create mode 100644 tests/unittests/clustering/test_adjusted_mutual_info_score.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2317ac43604..81db60a16ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `FowlkesMallowsIndex` ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066)) + - `AdjustedMutualInfoScore` ([#2058](https://github.com/Lightning-AI/torchmetrics/pull/2058)) + - `DaviesBouldinScore` ([#2071](https://github.com/Lightning-AI/torchmetrics/pull/2071)) diff --git a/docs/source/clustering/adjusted_mutual_info_score.rst b/docs/source/clustering/adjusted_mutual_info_score.rst new file mode 100644 index 00000000000..0d0c2e27594 --- /dev/null +++ b/docs/source/clustering/adjusted_mutual_info_score.rst @@ -0,0 +1,21 @@ +.. 
customcarditem:: + :header: Adjusted Mutual Information Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :tags: Clustering + +.. include:: ../links.rst + +################################# +Adjusted Mutual Information Score +################################# + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.AdjustedMutualInfoScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.adjusted_mutual_info_score diff --git a/docs/source/links.rst b/docs/source/links.rst index 9c63d351b43..3685eb8ae1a 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -152,6 +152,7 @@ .. _GIOU: https://arxiv.org/abs/1902.09630 .. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information .. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html +.. _Adjusted Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html#sklearn.metrics.adjusted_mutual_info_score .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools .. _Rand Score: https://link.springer.com/article/10.1007/BF01908075 .. _faster-coco-eval: https://github.com/MiXaiLL76/faster_coco_eval diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index ca7f8ed29af..68886cf9d07 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from torchmetrics.clustering.adjusted_mutual_info_score import AdjustedMutualInfoScore from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore @@ -26,6 +27,7 @@ from torchmetrics.clustering.rand_score import RandScore __all__ = [ + "AdjustedMutualInfoScore", "AdjustedRandScore", "CalinskiHarabaszScore", "CompletenessScore", diff --git a/src/torchmetrics/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/clustering/adjusted_mutual_info_score.py new file mode 100644 index 00000000000..eb6d5b2b2e3 --- /dev/null +++ b/src/torchmetrics/clustering/adjusted_mutual_info_score.py @@ -0,0 +1,127 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, List, Literal, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.clustering.mutual_info_score import MutualInfoScore +from torchmetrics.functional.clustering.adjusted_mutual_info_score import ( + _validate_average_method_arg, + adjusted_mutual_info_score, +) +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["AdjustedMutualInfoScore.plot"] + + +class AdjustedMutualInfoScore(MutualInfoScore): + r"""Compute `Adjusted Mutual Information Score`_. + + .. math:: + AMI(U,V) = \frac{MI(U,V) - E(MI(U,V))}{avg(H(U), H(V)) - E(MI(U,V))} + + Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, :math:`avg(H(U), H(V))` + is the generalized mean of the entropies of :math:`U` and :math:`V`, :math:`E(MI(U,V))` is the expected mutual + information, and :math:`MI(U,V)` is the mutual information score between clusters :math:`U` and :math:`V`. The + metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same adjusted mutual information score. + + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering in general is used for unsupervised learning. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``ami_score`` (:class:`~torch.Tensor`): A tensor with the Adjusted Mutual Information Score + + Args: + average_method: Method used to calculate generalized mean for normalization.
Choose between + ``'min'``, ``'geometric'``, ``'arithmetic'``, ``'max'``. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.clustering import AdjustedMutualInfoScore + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> ami_score = AdjustedMutualInfoScore(average_method="arithmetic") + >>> ami_score(preds, target) + tensor(-0.2500) + + """ + + is_differentiable: bool = True + higher_is_better: Optional[bool] = None + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + preds: List[Tensor] + target: List[Tensor] + contingency: Tensor + + def __init__( + self, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic", **kwargs: Any + ) -> None: + super().__init__(**kwargs) + _validate_average_method_arg(average_method) + self.average_method = average_method + + def compute(self) -> Tensor: + """Compute adjusted mutual information over state.""" + return adjusted_mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), self.average_method) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + ..
plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import AdjustedMutualInfoScore + >>> metric = AdjustedMutualInfoScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import AdjustedMutualInfoScore + >>> metric = AdjustedMutualInfoScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index aa83d386f28..28c92755aa2 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score @@ -26,6 +27,7 @@ from torchmetrics.functional.clustering.rand_score import rand_score __all__ = [ + "adjusted_mutual_info_score", "adjusted_rand_score", "calinski_harabasz_score", "completeness_score", diff --git a/src/torchmetrics/functional/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/functional/clustering/adjusted_mutual_info_score.py new file mode 100644 index 00000000000..b70525c1d76 --- /dev/null +++ b/src/torchmetrics/functional/clustering/adjusted_mutual_info_score.py @@ -0,0 +1,121 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Literal + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update +from torchmetrics.functional.clustering.utils import ( + _validate_average_method_arg, + calculate_entropy, + calculate_generalized_mean, +) + + +def adjusted_mutual_info_score( + preds: Tensor, target: Tensor, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic" +) -> Tensor: + """Compute adjusted mutual information between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + average_method: normalizer computation method + + Returns: + Scalar tensor with adjusted mutual info score; the score can be negative, as in the example below + + Example: + >>> from torchmetrics.functional.clustering import adjusted_mutual_info_score + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> adjusted_mutual_info_score(preds, target, "arithmetic") + tensor(-0.2500) + + """ + _validate_average_method_arg(average_method) + contingency = _mutual_info_score_update(preds, target) + mutual_info = _mutual_info_score_compute(contingency) + expected_mutual_info = expected_mutual_info_score(contingency, target.numel()) + normalizer = calculate_generalized_mean( + torch.stack([calculate_entropy(preds), calculate_entropy(target)]), average_method + ) + denominator = normalizer - expected_mutual_info + if denominator < 0: + denominator = torch.min(torch.tensor([denominator, -torch.finfo(denominator.dtype).eps])) + else: + denominator = torch.max(torch.tensor([denominator, torch.finfo(denominator.dtype).eps])) + + return (mutual_info - expected_mutual_info) / denominator + + +def expected_mutual_info_score(contingency: Tensor, n_samples: int) -> Tensor: + """Calculate the expected mutual information score between two clusterings. + + Implementation taken from sklearn/metrics/cluster/_expected_mutual_info_fast.pyx.
+ + Args: + contingency: contingency matrix + n_samples: number of samples + + Returns: + expected_mutual_info_score: expected mutual information score + + """ + n_rows, n_cols = contingency.shape + a = torch.ravel(contingency.sum(dim=1)) + b = torch.ravel(contingency.sum(dim=0)) + + # Check if preds or target labels only have one cluster + if a.numel() == 1 or b.numel() == 1: + return tensor(0.0, device=a.device) + + nijs = torch.arange(0, max([a.max().item(), b.max().item()]) + 1, device=a.device) + nijs[0] = 1 + + term1 = nijs / n_samples + log_a = torch.log(a) + log_b = torch.log(b) + + log_nnij = torch.log(torch.tensor(n_samples, device=a.device)) + torch.log(nijs) + + gln_a = torch.lgamma(a + 1) + gln_b = torch.lgamma(b + 1) + gln_na = torch.lgamma(n_samples - a + 1) + gln_nb = torch.lgamma(n_samples - b + 1) + gln_nnij = torch.lgamma(nijs + 1) + torch.lgamma(torch.tensor(n_samples + 1, dtype=a.dtype, device=a.device)) + + emi = tensor(0.0, device=a.device) + for i in range(n_rows): + for j in range(n_cols): + start = int(max(1, a[i].item() - n_samples + b[j].item())) + end = int(min(a[i].item(), b[j].item()) + 1) + + for nij in range(start, end): + term2 = log_nnij[nij] - log_a[i] - log_b[j] + gln = ( + gln_a[i] + + gln_b[j] + + gln_na[i] + + gln_nb[j] + - gln_nnij[nij] + - torch.lgamma(a[i] - nij + 1) + - torch.lgamma(b[j] - nij + 1) + - torch.lgamma(n_samples - a[i] - b[j] + nij + 1) + ) + term3 = torch.exp(gln) + emi += term1[nij] * term2 * term3 + + return emi diff --git a/src/torchmetrics/functional/clustering/normalized_mutual_info_score.py b/src/torchmetrics/functional/clustering/normalized_mutual_info_score.py index dfa9db61c5e..5cdc81b960d 100644 --- a/src/torchmetrics/functional/clustering/normalized_mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/normalized_mutual_info_score.py @@ -17,17 +17,12 @@ from torch import Tensor from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score -from 
torchmetrics.functional.clustering.utils import calculate_entropy, calculate_generalized_mean, check_cluster_labels - - -def _validate_average_method_arg( - average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic" -) -> None: - if average_method not in ("min", "geometric", "arithmetic", "max"): - raise ValueError( - "Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`," - f"but got {average_method}" - ) +from torchmetrics.functional.clustering.utils import ( + _validate_average_method_arg, + calculate_entropy, + calculate_generalized_mean, + check_cluster_labels, +) def normalized_mutual_info_score( diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 6ccca75f569..463bf675c60 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -20,6 +20,30 @@ from torchmetrics.utilities.checks import _check_same_shape +def is_nonnegative(x: Tensor, atol: float = 1e-5) -> Tensor: + """Return True if all elements of tensor are nonnegative within certain tolerance. + + Args: + x: tensor + atol: absolute tolerance + + Returns: + Boolean tensor indicating if all values are nonnegative + + """ + return torch.logical_or(x > 0.0, torch.abs(x) < atol).all() + + +def _validate_average_method_arg( + average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic" +) -> None: + if average_method not in ("min", "geometric", "arithmetic", "max"): + raise ValueError( + "Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`," + f"but got {average_method}" + ) + + def calculate_entropy(x: Tensor) -> Tensor: """Calculate entropy for a tensor of labels. 
@@ -74,7 +98,7 @@ def calculate_generalized_mean(x: Tensor, p: Union[int, Literal["min", "geometri tensor(1.6438) """ - if torch.is_complex(x) or torch.any(x <= 0.0): + if torch.is_complex(x) or not is_nonnegative(x): raise ValueError("`x` must contain positive real numbers") if isinstance(p, str): diff --git a/tests/unittests/clustering/test_adjusted_mutual_info_score.py b/tests/unittests/clustering/test_adjusted_mutual_info_score.py new file mode 100644 index 00000000000..e686ca0837d --- /dev/null +++ b/tests/unittests/clustering/test_adjusted_mutual_info_score.py @@ -0,0 +1,101 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial + +import pytest +import torch +from sklearn.metrics import adjusted_mutual_info_score as sklearn_ami +from torchmetrics.clustering.adjusted_mutual_info_score import AdjustedMutualInfoScore +from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score + +from unittests import BATCH_SIZE, NUM_CLASSES +from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +ATOL = 1e-5 + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_extrinsic1.preds, _single_target_extrinsic1.target), + (_single_target_extrinsic2.preds, _single_target_extrinsic2.target), + ], +) +@pytest.mark.parametrize( + "average_method", + ["min", "arithmetic", "geometric", "max"], +) +class TestAdjustedMutualInfoScore(MetricTester): + """Test class for `AdjustedMutualInfoScore` metric.""" + + atol = ATOL + + @pytest.mark.parametrize("ddp", [True, False]) + def test_adjusted_mutual_info_score(self, preds, target, average_method, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=AdjustedMutualInfoScore, + reference_metric=partial(sklearn_ami, average_method=average_method), + metric_args={"average_method": average_method}, + ) + + def test_adjusted_mutual_info_score_functional(self, preds, target, average_method): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=adjusted_mutual_info_score, + reference_metric=partial(sklearn_ami, average_method=average_method), + average_method=average_method, + ) + + +@pytest.mark.parametrize("average_method", ["min", "geometric", "arithmetic", "max"]) +def test_adjusted_mutual_info_score_functional_single_cluster(average_method): + """Check 
that for single cluster the metric returns 0.""" + tensor_a = torch.randint(NUM_CLASSES, (BATCH_SIZE,)) + tensor_b = torch.zeros((BATCH_SIZE,), dtype=torch.int) + assert torch.allclose(adjusted_mutual_info_score(tensor_a, tensor_b, average_method), torch.tensor(0.0), atol=ATOL) + assert torch.allclose(adjusted_mutual_info_score(tensor_b, tensor_a, average_method), torch.tensor(0.0), atol=ATOL) + + +@pytest.mark.parametrize("average_method", ["min", "geometric", "arithmetic", "max"]) +def test_adjusted_mutual_info_score_functional_raises_invalid_task(average_method): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs_extrinsic + with pytest.raises(ValueError, match=r"Expected *"): + adjusted_mutual_info_score(preds, target, average_method) + + +@pytest.mark.parametrize( + "average_method", + ["min", "geometric", "arithmetic", "max"], +) +def test_adjusted_mutual_info_score_functional_is_symmetric( + average_method, preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target +): + """Check that the metric functional is symmetric.""" + for p, t in zip(preds, target): + assert torch.allclose( + adjusted_mutual_info_score(p, t, average_method), + adjusted_mutual_info_score(t, p, average_method), + atol=1e-6, + )