Skip to content

Commit 7611b4c

Browse files
authored
New metric: Davies bouldin score (#2071)
* implementation * fix error in other metric * links + init + utils * add tests * changelog * fix inf * changelog * docs
1 parent 77a5317 commit 7611b4c

File tree

11 files changed

+319
-23
lines changed

11 files changed

+319
-23
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727

2828
- `FowlkesMallowsIndex` ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066))
2929

30+
- `DaviesBouldinScore` ([#2071](https://github.com/Lightning-AI/torchmetrics/pull/2071))
31+
32+
3033
### Changed
3134

3235
-
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
.. customcarditem::
2+
:header: Davies Bouldin Score
3+
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg
4+
:tags: Clustering
5+
6+
.. include:: ../links.rst
7+
8+
####################
9+
Davies Bouldin Score
10+
####################
11+
12+
Module Interface
13+
________________
14+
15+
.. autoclass:: torchmetrics.clustering.DaviesBouldinScore
16+
:exclude-members: update, compute
17+
18+
Functional Interface
19+
____________________
20+
21+
.. autofunction:: torchmetrics.functional.clustering.davies_bouldin_score

docs/source/links.rst

+1
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,5 @@
158158
.. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi
159159
.. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index
160160
.. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index
161+
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
161162
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score

src/torchmetrics/clustering/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
1515
from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
16+
from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore
1617
from torchmetrics.clustering.dunn_index import DunnIndex
1718
from torchmetrics.clustering.fowlkes_mallows_index import FowlkesMallowsIndex
1819
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
@@ -22,6 +23,7 @@
2223
__all__ = [
2324
"AdjustedRandScore",
2425
"CalinskiHarabaszScore",
26+
"DaviesBouldinScore",
2527
"DunnIndex",
2628
"FowlkesMallowsIndex",
2729
"MutualInfoScore",

src/torchmetrics/clustering/calinski_harabasz_score.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,20 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
106106
107107
>>> # Example plotting a single value
108108
>>> import torch
109-
>>> from torchmetrics.clustering import RandScore
110-
>>> metric = RandScore()
111-
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
109+
>>> from torchmetrics.clustering import CalinskiHarabaszScore
110+
>>> metric = CalinskiHarabaszScore()
111+
>>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
112112
>>> fig_, ax_ = metric.plot(metric.compute())
113113
114114
.. plot::
115115
:scale: 75
116116
117117
>>> # Example plotting multiple values
118118
>>> import torch
119-
>>> from torchmetrics.clustering import RandScore
120-
>>> metric = RandScore()
119+
>>> from torchmetrics.clustering import CalinskiHarabaszScore
120+
>>> metric = CalinskiHarabaszScore()
121121
>>> for _ in range(10):
122-
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
122+
... metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
123123
>>> fig_, ax_ = metric.plot(metric.compute())
124124
125125
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, List, Optional, Sequence, Union
15+
16+
from torch import Tensor
17+
18+
from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
19+
from torchmetrics.metric import Metric
20+
from torchmetrics.utilities.data import dim_zero_cat
21+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
22+
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
23+
24+
if not _MATPLOTLIB_AVAILABLE:
25+
__doctest_skip__ = ["DaviesBouldinScore.plot"]
26+
27+
28+
class DaviesBouldinScore(Metric):
    r"""Compute `Davies-Bouldin Score`_ for clustering algorithms.

    Given the following quantities:

    .. math::
        S_i = \frac{1}{T_i} \sum_{j=1}^{T_i} ||X_j - A_i||_2

    where :math:`T_i` is the number of samples in cluster :math:`i`, :math:`X_j` is the :math:`j`-th sample in cluster
    :math:`i`, and :math:`A_i` is the centroid of cluster :math:`i`. This quantity is the average distance between all
    the samples in cluster :math:`i` and its centroid. Let

    .. math::
        M_{i,j} = ||A_i - A_j||_2

    i.e. the distance between the centroids of cluster :math:`i` and cluster :math:`j`. Then the Davies-Bouldin score
    is defined as:

    .. math::
        DB = \frac{1}{n_{clusters}} \sum_{i=1}^{n_{clusters}} \max_{j \neq i} \left( \frac{S_i + S_j}{M_{i,j}} \right)

    This clustering metric is an intrinsic measure, because it does not rely on ground truth labels for the evaluation.
    Instead it examines how well the clusters are separated from each other. The score is lower when clusters are dense
    and well separated, which relates to a standard concept of a cluster. The minimum score is zero and lower values
    indicate a better clustering.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``data`` (:class:`~torch.Tensor`): float tensor with shape ``(N,d)`` with the embedded data. ``d`` is the
      dimensionality of the embedding space.
    - ``labels`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with cluster labels

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``dbs`` (:class:`~torch.Tensor`): A tensor with the Davies Bouldin Score

    Args:
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> from torchmetrics.clustering import DaviesBouldinScore
        >>> _ = torch.manual_seed(42)
        >>> data = torch.randn(10, 3)
        >>> labels = torch.randint(3, (10,))
        >>> metric = DaviesBouldinScore()
        >>> metric(data, labels)
        tensor(1.2540)

    """
    is_differentiable: bool = True
    # Lower Davies-Bouldin values mean denser, better separated clusters.
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    # Every batch passed to ``update`` is stored; the score is computed over all of them in ``compute``.
    data: List[Tensor]
    labels: List[Tensor]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        self.add_state("data", default=[], dist_reduce_fx="cat")
        self.add_state("labels", default=[], dist_reduce_fx="cat")

    def update(self, data: Tensor, labels: Tensor) -> None:
        """Update metric state with new data and labels."""
        self.data.append(data)
        self.labels.append(labels)

    def compute(self) -> Tensor:
        """Compute the Davies Bouldin Score over all data and labels."""
        return davies_bouldin_score(dim_zero_cat(self.data), dim_zero_cat(self.labels))

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.clustering import DaviesBouldinScore
            >>> metric = DaviesBouldinScore()
            >>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.clustering import DaviesBouldinScore
            >>> metric = DaviesBouldinScore()
            >>> for _ in range(10):
            ...     metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        """
        return self._plot(val, ax)

src/torchmetrics/functional/clustering/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
1515
from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score
16+
from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
1617
from torchmetrics.functional.clustering.dunn_index import dunn_index
1718
from torchmetrics.functional.clustering.fowlkes_mallows_index import fowlkes_mallows_index
1819
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
@@ -22,6 +23,7 @@
2223
__all__ = [
2324
"adjusted_rand_score",
2425
"calinski_harabasz_score",
26+
"davies_bouldin_score",
2527
"dunn_index",
2628
"fowlkes_mallows_index",
2729
"mutual_info_score",

src/torchmetrics/functional/clustering/calinski_harabasz_score.py

+6-17
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,10 @@
1414
import torch
1515
from torch import Tensor
1616

17-
18-
def _calinski_harabasz_score_validate_input(data: Tensor, labels: Tensor) -> None:
19-
"""Validate that the input data and labels have correct shape and type."""
20-
if data.ndim != 2:
21-
raise ValueError(f"Expected 2D data, got {data.ndim}D data instead")
22-
if not data.is_floating_point():
23-
raise ValueError(f"Expected floating point data, got {data.dtype} data instead")
24-
if labels.ndim != 1:
25-
raise ValueError(f"Expected 1D labels, got {labels.ndim}D labels instead")
17+
from torchmetrics.functional.clustering.utils import (
18+
_validate_intrinsic_cluster_data,
19+
_validate_intrinsic_labels_to_samples,
20+
)
2621

2722

2823
def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor:
@@ -45,19 +40,13 @@ def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor:
4540
tensor(3.4998)
4641
4742
"""
48-
_calinski_harabasz_score_validate_input(data, labels)
43+
_validate_intrinsic_cluster_data(data, labels)
4944

5045
# convert to zero indexed labels
5146
unique_labels, labels = torch.unique(labels, return_inverse=True)
5247
n_labels = len(unique_labels)
53-
5448
n_samples = data.shape[0]
55-
56-
if not 1 < n_labels < n_samples:
57-
raise ValueError(
58-
"Number of detected clusters must be greater than one and less than the number of samples."
59-
f"Got {n_labels} clusters and {n_samples} samples."
60-
)
49+
_validate_intrinsic_labels_to_samples(n_labels, n_samples)
6150

6251
mean = data.mean(dim=0)
6352
between_cluster_dispersion = torch.tensor(0.0, device=data.device)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import torch
15+
from torch import Tensor
16+
17+
from torchmetrics.functional.clustering.utils import (
18+
_validate_intrinsic_cluster_data,
19+
_validate_intrinsic_labels_to_samples,
20+
)
21+
22+
23+
def davies_bouldin_score(data: Tensor, labels: Tensor) -> Tensor:
    """Compute the Davies-Bouldin score for clustering algorithms.

    For every cluster the mean Euclidean distance of its samples to the cluster centroid is computed. The score is
    the mean, over all clusters, of the largest ratio between the summed intra-cluster distances of two clusters and
    the distance between their centroids.

    Args:
        data: float tensor with shape ``(N,d)`` with the embedded data.
        labels: single integer tensor with shape ``(N,)`` with cluster labels

    Returns:
        Scalar tensor with the Davies-Bouldin score. ``0.0`` is returned when all intra-cluster distances or all
        centroid distances are (close to) zero.

    Raises:
        ValueError:
            If the input validation helpers reject ``data`` or ``labels`` (wrong number of dimensions, non-floating
            point data, or a number of clusters outside the accepted range).

    Example:
        >>> import torch
        >>> from torchmetrics.functional.clustering import davies_bouldin_score
        >>> _ = torch.manual_seed(42)
        >>> data = torch.randn(10, 3)
        >>> labels = torch.randint(0, 2, (10,))
        >>> davies_bouldin_score(data, labels)
        tensor(1.3249)

    """
    _validate_intrinsic_cluster_data(data, labels)

    # convert to zero indexed labels
    unique_labels, labels = torch.unique(labels, return_inverse=True)
    n_labels = len(unique_labels)
    n_samples, dim = data.shape
    _validate_intrinsic_labels_to_samples(n_labels, n_samples)

    intra_dists = torch.zeros(n_labels, device=data.device)
    centroids = torch.zeros((n_labels, dim), device=data.device)
    for k in range(n_labels):
        cluster_k = data[labels == k, :]
        centroids[k] = cluster_k.mean(dim=0)
        # mean Euclidean distance between the samples of cluster k and its centroid
        intra_dists[k] = (cluster_k - centroids[k]).pow(2.0).sum(dim=1).sqrt().mean()
    centroid_distances = torch.cdist(centroids, centroids)

    # degenerate cases: every cluster collapsed to a point, or all centroids coincide
    cond1 = torch.allclose(intra_dists, torch.zeros_like(intra_dists))
    cond2 = torch.allclose(centroid_distances, torch.zeros_like(centroid_distances))
    if cond1 or cond2:
        return torch.tensor(0.0, device=data.device, dtype=torch.float32)

    # Zero distances (the diagonal, and coinciding centroids) become ``inf`` so that their ratio is 0 and they never
    # win the max below. This is done out-of-place: ``torch.cdist`` may keep its output for the backward pass, and
    # overwriting it in place would invalidate gradient computation.
    centroid_distances = centroid_distances.masked_fill(centroid_distances == 0, float("inf"))
    combined_intra_dists = intra_dists.unsqueeze(0) + intra_dists.unsqueeze(1)
    scores = (combined_intra_dists / centroid_distances).max(dim=1).values
    return scores.mean()

src/torchmetrics/functional/clustering/utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,25 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
171171
raise ValueError(f"Expected real, discrete values for x but received {preds.dtype} and {target.dtype}.")
172172

173173

174+
def _validate_intrinsic_cluster_data(data: Tensor, labels: Tensor) -> None:
175+
"""Validate that the input data and labels have correct shape and type."""
176+
if data.ndim != 2:
177+
raise ValueError(f"Expected 2D data, got {data.ndim}D data instead")
178+
if not data.is_floating_point():
179+
raise ValueError(f"Expected floating point data, got {data.dtype} data instead")
180+
if labels.ndim != 1:
181+
raise ValueError(f"Expected 1D labels, got {labels.ndim}D labels instead")
182+
183+
184+
def _validate_intrinsic_labels_to_samples(n_labels: int, n_samples: int) -> None:
185+
"""Validate that the number of labels are in the correct range."""
186+
if not 1 < n_labels < n_samples:
187+
raise ValueError(
188+
"Number of detected clusters must be greater than one and less than the number of samples."
189+
f"Got {n_labels} clusters and {n_samples} samples."
190+
)
191+
192+
174193
def calcualte_pair_cluster_confusion_matrix(
175194
preds: Optional[Tensor] = None,
176195
target: Optional[Tensor] = None,

0 commit comments

Comments
 (0)