Skip to content

Commit

Permalink
New metric: Adjusted mutual info score (#2058)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 18, 2023
1 parent 8da841c commit c6547bb
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `FowlkesMallowsIndex` ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066))

- `AdjustedMutualInfoScore` ([#2058](https://github.com/Lightning-AI/torchmetrics/pull/2058))

- `DaviesBouldinScore` ([#2071](https://github.com/Lightning-AI/torchmetrics/pull/2071))


Expand Down
21 changes: 21 additions & 0 deletions docs/source/clustering/adjusted_mutual_info_score.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Adjusted Mutual Information Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg
:tags: Clustering

.. include:: ../links.rst

#################################
Adjusted Mutual Information Score
#################################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.AdjustedMutualInfoScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.adjusted_mutual_info_score
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
.. _Adjusted Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html#sklearn.metrics.adjusted_mutual_info_score
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
.. _faster-coco-eval: https://github.com/MiXaiLL76/faster_coco_eval
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.adjusted_mutual_info_score import AdjustedMutualInfoScore
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore
Expand All @@ -26,6 +27,7 @@
from torchmetrics.clustering.rand_score import RandScore

__all__ = [
"AdjustedMutualInfoScore",
"AdjustedRandScore",
"CalinskiHarabaszScore",
"CompletenessScore",
Expand Down
127 changes: 127 additions & 0 deletions src/torchmetrics/clustering/adjusted_mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Literal, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.functional.clustering.adjusted_mutual_info_score import (
_validate_average_method_arg,
adjusted_mutual_info_score,
)
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["AdjustedMutualInfoScore.plot"]


class AdjustedMutualInfoScore(MutualInfoScore):
r"""Compute `Adjusted Mutual Information Score`_.
.. math::
AMI(U,V) = \frac{MI(U,V) - E(MI(U,V))}{avg(H(U), H(V)) - E(MI(U,V))}
Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, :math:`M_p(U,V)` is the
generalized mean of order :math:`p` of :math:`U` and :math:`V`, and :math:`MI(U,V)` is the mutual information score
between clusters :math:`U` and :math:`V`. The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields
the same mutual information score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering in generally is used for unsupervised learning.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``ami_score`` (:class:`~torch.Tensor`): A tensor with the Adjusted Mutual Information Score
Args:
average_method: Method used to calculate generalized mean for normalization. Choose between
``'min'``, ``'geometric'``, ``'arithmetic'``, ``'max'``.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torchmetrics.clustering import AdjustedMutualInfoScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> ami_score = AdjustedMutualInfoScore(average_method="arithmetic")
>>> ami_score(preds, target)
tensor(-0.2500)
"""

is_differentiable: bool = True
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(
self, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic", **kwargs: Any
) -> None:
super().__init__(**kwargs)
_validate_average_method_arg(average_method)
self.average_method = average_method

def compute(self) -> Tensor:
"""Compute normalized mutual information over state."""
return adjusted_mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), self.average_method)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import AdjustedMutualInfoScore
>>> metric = AdjustedMutualInfoScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import AdjustedMutualInfoScore
>>> metric = AdjustedMutualInfoScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
"""
return self._plot(val, ax)
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score
from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
Expand All @@ -26,6 +27,7 @@
from torchmetrics.functional.clustering.rand_score import rand_score

__all__ = [
"adjusted_mutual_info_score",
"adjusted_rand_score",
"calinski_harabasz_score",
"completeness_score",
Expand Down
121 changes: 121 additions & 0 deletions src/torchmetrics/functional/clustering/adjusted_mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import torch
from torch import Tensor, tensor

from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update
from torchmetrics.functional.clustering.utils import (
_validate_average_method_arg,
calculate_entropy,
calculate_generalized_mean,
)


def adjusted_mutual_info_score(
preds: Tensor, target: Tensor, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> Tensor:
"""Compute adjusted mutual information between two clusterings.
Args:
preds: predicted cluster labels
target: ground truth cluster labels
average_method: normalizer computation method
Returns:
Scalar tensor with adjusted mutual info score between 0.0 and 1.0
Example:
>>> from torchmetrics.functional.clustering import adjusted_mutual_info_score
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> adjusted_mutual_info_score(preds, target, "arithmetic")
tensor(-0.2500)
"""
_validate_average_method_arg(average_method)
contingency = _mutual_info_score_update(preds, target)
mutual_info = _mutual_info_score_compute(contingency)
expected_mutual_info = expected_mutual_info_score(contingency, target.numel())
normalizer = calculate_generalized_mean(
torch.stack([calculate_entropy(preds), calculate_entropy(target)]), average_method
)
denominator = normalizer - expected_mutual_info
if denominator < 0:
denominator = torch.min(torch.tensor([denominator, -torch.finfo(denominator.dtype).eps]))
else:
denominator = torch.max(torch.tensor([denominator, torch.finfo(denominator.dtype).eps]))

return (mutual_info - expected_mutual_info) / denominator


def expected_mutual_info_score(contingency: Tensor, n_samples: int) -> Tensor:
"""Calculated expected mutual information score between two clusterings.
Implementation taken from sklearn/metrics/cluster/_expected_mutual_info_fast.pyx.
Args:
contingency: contingency matrix
n_samples: number of samples
Returns:
expected_mutual_info_score: expected mutual information score
"""
n_rows, n_cols = contingency.shape
a = torch.ravel(contingency.sum(dim=1))
b = torch.ravel(contingency.sum(dim=0))

# Check if preds or target labels only have one cluster
if a.numel() == 1 or b.numel() == 1:
return tensor(0.0, device=a.device)

nijs = torch.arange(0, max([a.max().item(), b.max().item()]) + 1, device=a.device)
nijs[0] = 1

term1 = nijs / n_samples
log_a = torch.log(a)
log_b = torch.log(b)

log_nnij = torch.log(torch.tensor(n_samples, device=a.device)) + torch.log(nijs)

gln_a = torch.lgamma(a + 1)
gln_b = torch.lgamma(b + 1)
gln_na = torch.lgamma(n_samples - a + 1)
gln_nb = torch.lgamma(n_samples - b + 1)
gln_nnij = torch.lgamma(nijs + 1) + torch.lgamma(torch.tensor(n_samples + 1, dtype=a.dtype, device=a.device))

emi = tensor(0.0, device=a.device)
for i in range(n_rows):
for j in range(n_cols):
start = int(max(1, a[i].item() - n_samples + b[j].item()))
end = int(min(a[i].item(), b[j].item()) + 1)

for nij in range(start, end):
term2 = log_nnij[nij] - log_a[i] - log_b[j]
gln = (
gln_a[i]
+ gln_b[j]
+ gln_na[i]
+ gln_nb[j]
- gln_nnij[nij]
- torch.lgamma(a[i] - nij + 1)
- torch.lgamma(b[j] - nij + 1)
- torch.lgamma(n_samples - a[i] - b[j] + nij + 1)
)
term3 = torch.exp(gln)
emi += term1[nij] * term2 * term3

return emi
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@
from torch import Tensor

from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.utils import calculate_entropy, calculate_generalized_mean, check_cluster_labels


def _validate_average_method_arg(
average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> None:
if average_method not in ("min", "geometric", "arithmetic", "max"):
raise ValueError(
"Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`,"
f"but got {average_method}"
)
from torchmetrics.functional.clustering.utils import (
_validate_average_method_arg,
calculate_entropy,
calculate_generalized_mean,
check_cluster_labels,
)


def normalized_mutual_info_score(
Expand Down
26 changes: 25 additions & 1 deletion src/torchmetrics/functional/clustering/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,30 @@
from torchmetrics.utilities.checks import _check_same_shape


def is_nonnegative(x: Tensor, atol: float = 1e-5) -> Tensor:
"""Return True if all elements of tensor are nonnegative within certain tolerance.
Args:
x: tensor
atol: absolute tolerance
Returns:
Boolean tensor indicating if all values are nonnegative
"""
return torch.logical_or(x > 0.0, torch.abs(x) < atol).all()


def _validate_average_method_arg(
average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> None:
if average_method not in ("min", "geometric", "arithmetic", "max"):
raise ValueError(
"Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`,"
f"but got {average_method}"
)


def calculate_entropy(x: Tensor) -> Tensor:
"""Calculate entropy for a tensor of labels.
Expand Down Expand Up @@ -74,7 +98,7 @@ def calculate_generalized_mean(x: Tensor, p: Union[int, Literal["min", "geometri
tensor(1.6438)
"""
if torch.is_complex(x) or torch.any(x <= 0.0):
if torch.is_complex(x) or not is_nonnegative(x):
raise ValueError("`x` must contain positive real numbers")

if isinstance(p, str):
Expand Down
Loading

0 comments on commit c6547bb

Please sign in to comment.