New metric: Davies bouldin score #2071

Merged: 16 commits, Sep 15, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `FowlkesMallowsIndex` ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066))

- `DaviesBouldinScore` ([#2071](https://github.com/Lightning-AI/torchmetrics/pull/2071))


### Changed

-
21 changes: 21 additions & 0 deletions docs/source/clustering/davies_bouldin_score.rst
@@ -0,0 +1,21 @@
.. customcarditem::
    :header: Davies Bouldin Score
    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg
    :tags: Clustering

.. include:: ../links.rst

####################
Davies Bouldin Score
####################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.DaviesBouldinScore
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.davies_bouldin_score
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -158,4 +158,5 @@
.. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi
.. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index
.. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore
from torchmetrics.clustering.dunn_index import DunnIndex
from torchmetrics.clustering.fowlkes_mallows_index import FowlkesMallowsIndex
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
@@ -22,6 +23,7 @@
__all__ = [
"AdjustedRandScore",
"CalinskiHarabaszScore",
"DaviesBouldinScore",
"DunnIndex",
"FowlkesMallowsIndex",
"MutualInfoScore",
12 changes: 6 additions & 6 deletions src/torchmetrics/clustering/calinski_harabasz_score.py
@@ -106,20 +106,20 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

            >>> # Example plotting a single value
            >>> import torch
-           >>> from torchmetrics.clustering import RandScore
-           >>> metric = RandScore()
-           >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
+           >>> from torchmetrics.clustering import CalinskiHarabaszScore
+           >>> metric = CalinskiHarabaszScore()
+           >>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
-           >>> from torchmetrics.clustering import RandScore
-           >>> metric = RandScore()
+           >>> from torchmetrics.clustering import CalinskiHarabaszScore
+           >>> metric = CalinskiHarabaszScore()
            >>> for _ in range(10):
-           ...     metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
+           ...     metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        """
Expand Down
136 changes: 136 additions & 0 deletions src/torchmetrics/clustering/davies_bouldin_score.py
@@ -0,0 +1,136 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["DaviesBouldinScore.plot"]


class DaviesBouldinScore(Metric):
    r"""Compute `Davies-Bouldin Score`_ for clustering algorithms.

    Given the following quantities:

    .. math::
        S_i = \frac{1}{T_i} \sum_{j=1}^{T_i} ||X_j - A_i||_2

    where :math:`T_i` is the number of samples in cluster :math:`i`, :math:`X_j` is the :math:`j`-th sample in cluster
    :math:`i`, and :math:`A_i` is the centroid of cluster :math:`i`. This quantity is the average distance between all
    the samples in cluster :math:`i` and its centroid. Let

    .. math::
        M_{i,j} = ||A_i - A_j||_2

    i.e. the distance between the centroids of cluster :math:`i` and cluster :math:`j`. Then the Davies-Bouldin score
    is defined as:

    .. math::
        DB = \frac{1}{n_{clusters}} \sum_{i=1}^{n_{clusters}} \max_{j \neq i} \left( \frac{S_i + S_j}{M_{i,j}} \right)

    This clustering metric is an intrinsic measure, because it does not rely on ground truth labels for the
    evaluation. Instead it examines how well the clusters are separated from each other: the score is lower when
    clusters are dense and well separated, which relates to a standard concept of a cluster, and its minimum
    value is zero.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``data`` (:class:`~torch.Tensor`): float tensor with shape ``(N,d)`` with the embedded data. ``d`` is the
      dimensionality of the embedding space.
    - ``labels`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with cluster labels

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``dbs`` (:class:`~torch.Tensor`): A tensor with the Davies-Bouldin score

    Args:
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> from torchmetrics.clustering import DaviesBouldinScore
        >>> _ = torch.manual_seed(42)
        >>> data = torch.randn(10, 3)
        >>> labels = torch.randint(3, (10,))
        >>> metric = DaviesBouldinScore()
        >>> metric(data, labels)
        tensor(1.2540)

    """
    is_differentiable: bool = True
    higher_is_better: bool = False  # a lower Davies-Bouldin score indicates better clustering
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    data: List[Tensor]
    labels: List[Tensor]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        self.add_state("data", default=[], dist_reduce_fx="cat")
        self.add_state("labels", default=[], dist_reduce_fx="cat")

    def update(self, data: Tensor, labels: Tensor) -> None:
        """Update metric state with new data and labels."""
        self.data.append(data)
        self.labels.append(labels)

    def compute(self) -> Tensor:
        """Compute the Davies-Bouldin score over all data and labels."""
        return davies_bouldin_score(dim_zero_cat(self.data), dim_zero_cat(self.labels))

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.clustering import DaviesBouldinScore
            >>> metric = DaviesBouldinScore()
            >>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.clustering import DaviesBouldinScore
            >>> metric = DaviesBouldinScore()
            >>> for _ in range(10):
            ...     metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        """
        return self._plot(val, ax)
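
A minimal standalone sketch (not part of the diff) showing how the docstring quantities fit together: it computes the centroids A_i, the mean intra-cluster distances S_i, and the pairwise centroid distances M_ij with plain tensor ops, and should agree with DaviesBouldinScore on the same inputs.

import torch

# Reproduce the doctest inputs
torch.manual_seed(42)
data = torch.randn(10, 3)          # (N, d) embeddings
labels = torch.randint(3, (10,))   # (N,) cluster assignments

unique_labels, labels = torch.unique(labels, return_inverse=True)
n = len(unique_labels)

# A_i: cluster centroids; S_i: mean distance of cluster members to A_i
centroids = torch.stack([data[labels == i].mean(dim=0) for i in range(n)])
s = torch.stack([(data[labels == i] - centroids[i]).norm(dim=1).mean() for i in range(n)])

# M_ij: pairwise centroid distances; mask the diagonal so the max runs over j != i
m = torch.cdist(centroids, centroids)
m.fill_diagonal_(float("inf"))

# DB = mean over clusters of max_j (S_i + S_j) / M_ij
db = ((s.unsqueeze(0) + s.unsqueeze(1)) / m).max(dim=1).values.mean()
print(db)  # should print tensor(1.2540), matching the doctest above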
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score
from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
from torchmetrics.functional.clustering.dunn_index import dunn_index
from torchmetrics.functional.clustering.fowlkes_mallows_index import fowlkes_mallows_index
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
@@ -22,6 +23,7 @@
__all__ = [
"adjusted_rand_score",
"calinski_harabasz_score",
"davies_bouldin_score",
"dunn_index",
"fowlkes_mallows_index",
"mutual_info_score",
src/torchmetrics/functional/clustering/calinski_harabasz_score.py
@@ -14,15 +14,10 @@
import torch
from torch import Tensor


-def _calinski_harabasz_score_validate_input(data: Tensor, labels: Tensor) -> None:
-    """Validate that the input data and labels have correct shape and type."""
-    if data.ndim != 2:
-        raise ValueError(f"Expected 2D data, got {data.ndim}D data instead")
-    if not data.is_floating_point():
-        raise ValueError(f"Expected floating point data, got {data.dtype} data instead")
-    if labels.ndim != 1:
-        raise ValueError(f"Expected 1D labels, got {labels.ndim}D labels instead")
+from torchmetrics.functional.clustering.utils import (
+    _validate_intrinsic_cluster_data,
+    _validate_intrinsic_labels_to_samples,
+)


def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor:
@@ -45,19 +40,13 @@ def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor:
        tensor(3.4998)

    """
-    _calinski_harabasz_score_validate_input(data, labels)
+    _validate_intrinsic_cluster_data(data, labels)

    # convert to zero indexed labels
    unique_labels, labels = torch.unique(labels, return_inverse=True)
    n_labels = len(unique_labels)

    n_samples = data.shape[0]

-    if not 1 < n_labels < n_samples:
-        raise ValueError(
-            "Number of detected clusters must be greater than one and less than the number of samples."
-            f"Got {n_labels} clusters and {n_samples} samples."
-        )
+    _validate_intrinsic_labels_to_samples(n_labels, n_samples)

    mean = data.mean(dim=0)
    between_cluster_dispersion = torch.tensor(0.0, device=data.device)
67 changes: 67 additions & 0 deletions src/torchmetrics/functional/clustering/davies_bouldin_score.py
@@ -0,0 +1,67 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.functional.clustering.utils import (
    _validate_intrinsic_cluster_data,
    _validate_intrinsic_labels_to_samples,
)


def davies_bouldin_score(data: Tensor, labels: Tensor) -> Tensor:
    """Compute the Davies-Bouldin score for clustering algorithms.

    Args:
        data: float tensor with shape ``(N,d)`` with the embedded data.
        labels: single integer tensor with shape ``(N,)`` with cluster labels

    Returns:
        Scalar tensor with the Davies-Bouldin score

    Example:
        >>> import torch
        >>> from torchmetrics.functional.clustering import davies_bouldin_score
        >>> _ = torch.manual_seed(42)
        >>> data = torch.randn(10, 3)
        >>> labels = torch.randint(0, 2, (10,))
        >>> davies_bouldin_score(data, labels)
        tensor(1.3249)

    """
    _validate_intrinsic_cluster_data(data, labels)

    # convert to zero indexed labels
    unique_labels, labels = torch.unique(labels, return_inverse=True)
    n_labels = len(unique_labels)
    n_samples, dim = data.shape
    _validate_intrinsic_labels_to_samples(n_labels, n_samples)

    intra_dists = torch.zeros(n_labels, device=data.device)
    centroids = torch.zeros((n_labels, dim), device=data.device)
    for k in range(n_labels):
        cluster_k = data[labels == k, :]
        centroids[k] = cluster_k.mean(dim=0)
        intra_dists[k] = (cluster_k - centroids[k]).pow(2.0).sum(dim=1).sqrt().mean()
    centroid_distances = torch.cdist(centroids, centroids)

    cond1 = torch.allclose(intra_dists, torch.zeros_like(intra_dists))
    cond2 = torch.allclose(centroid_distances, torch.zeros_like(centroid_distances))
    if cond1 or cond2:
        return torch.tensor(0.0, device=data.device, dtype=torch.float32)

    centroid_distances[centroid_distances == 0] = float("inf")
    combined_intra_dists = intra_dists.unsqueeze(0) + intra_dists.unsqueeze(1)
    scores = (combined_intra_dists / centroid_distances).max(dim=1).values
    return scores.mean()
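
As a quick sanity check (not part of the diff), the functional version can be compared against scikit-learn's reference implementation, which uses the same mean-distance formulation of S_i:

import torch
from sklearn.metrics import davies_bouldin_score as sk_davies_bouldin
from torchmetrics.functional.clustering import davies_bouldin_score

torch.manual_seed(42)
data = torch.randn(10, 3)
labels = torch.randint(0, 2, (10,))

tm_score = davies_bouldin_score(data, labels)               # tensor(1.3249), as in the doctest
sk_score = sk_davies_bouldin(data.numpy(), labels.numpy())

# the two implementations should agree up to floating point precision
assert torch.isclose(tm_score, torch.tensor(sk_score, dtype=tm_score.dtype), atol=1e-4)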
19 changes: 19 additions & 0 deletions src/torchmetrics/functional/clustering/utils.py
@@ -171,6 +171,25 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
raise ValueError(f"Expected real, discrete values for x but received {preds.dtype} and {target.dtype}.")


def _validate_intrinsic_cluster_data(data: Tensor, labels: Tensor) -> None:
    """Validate that the input data and labels have correct shape and type."""
    if data.ndim != 2:
        raise ValueError(f"Expected 2D data, got {data.ndim}D data instead")
    if not data.is_floating_point():
        raise ValueError(f"Expected floating point data, got {data.dtype} data instead")
    if labels.ndim != 1:
        raise ValueError(f"Expected 1D labels, got {labels.ndim}D labels instead")


def _validate_intrinsic_labels_to_samples(n_labels: int, n_samples: int) -> None:
    """Validate that the number of labels is in the correct range."""
    if not 1 < n_labels < n_samples:
        raise ValueError(
            "Number of detected clusters must be greater than one and less than the number of samples. "
            f"Got {n_labels} clusters and {n_samples} samples."
        )


def calcualte_pair_cluster_confusion_matrix(
    preds: Optional[Tensor] = None,
    target: Optional[Tensor] = None,
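
For completeness, a small usage sketch of the shared validators (not part of the diff); both raise a ValueError on malformed input before any distance computation runs:

import torch
from torchmetrics.functional.clustering.utils import (
    _validate_intrinsic_cluster_data,
    _validate_intrinsic_labels_to_samples,
)

data = torch.randn(10, 3)
labels = torch.randint(0, 2, (10,))
_validate_intrinsic_cluster_data(data, labels)  # passes silently: 2D float data, 1D labels

try:
    _validate_intrinsic_cluster_data(data.long(), labels)  # integer data is rejected
except ValueError as err:
    print(err)  # Expected floating point data, got torch.int64 data instead

try:
    _validate_intrinsic_labels_to_samples(n_labels=1, n_samples=10)  # a single cluster is rejected
except ValueError as err:
    print(err)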