|
| 1 | +# Copyright The Lightning team. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from typing import Any, List, Optional, Sequence, Union |
| 15 | + |
| 16 | +from torch import Tensor |
| 17 | + |
| 18 | +from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score |
| 19 | +from torchmetrics.metric import Metric |
| 20 | +from torchmetrics.utilities.data import dim_zero_cat |
| 21 | +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE |
| 22 | +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE |
| 23 | + |
# The ``plot`` doctest needs matplotlib; list it in ``__doctest_skip__`` when
# the optional dependency is not available.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["DaviesBouldinScore.plot"]
| 26 | + |
| 27 | + |
class DaviesBouldinScore(Metric):
    r"""Compute `Davies-Bouldin Score`_ for clustering algorithms.

    Given the following quantities:

    .. math::
        S_i = \left( \frac{1}{T_i} \sum_{j=1}^{T_i} ||X_j - A_i||^2_2 \right)^{1/2}

    where :math:`T_i` is the number of samples in cluster :math:`i`, :math:`X_j` is the :math:`j`-th sample in cluster
    :math:`i`, and :math:`A_i` is the centroid of cluster :math:`i`. This quantity is the average distance between all
    the samples in cluster :math:`i` and its centroid. Let

    .. math::
        M_{i,j} = ||A_i - A_j||_2

    i.e. the distance between the centroids of cluster :math:`i` and cluster :math:`j`. Then the Davies-Bouldin score
    is defined as:

    .. math::
        DB = \frac{1}{n_{clusters}} \sum_{i=1}^{n_{clusters}} \max_{j \neq i} \left( \frac{S_i + S_j}{M_{i,j}} \right)

    This clustering metric is an intrinsic measure, because it does not rely on ground truth labels for the evaluation.
    Instead it examines how well the clusters are separated from each other. The score is lower when clusters are dense
    and well separated, which relates to a standard concept of a cluster. The minimum possible score is zero.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``data`` (:class:`~torch.Tensor`): float tensor with shape ``(N,d)`` with the embedded data. ``d`` is the
      dimensionality of the embedding space.
    - ``labels`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with cluster labels

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``score`` (:class:`~torch.Tensor`): A tensor with the Davies-Bouldin Score

    Args:
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> from torchmetrics.clustering import DaviesBouldinScore
        >>> _ = torch.manual_seed(42)
        >>> data = torch.randn(10, 3)
        >>> labels = torch.randint(3, (10,))
        >>> metric = DaviesBouldinScore()
        >>> metric(data, labels)
        tensor(1.2540)

    """
    is_differentiable: bool = True
    # Lower Davies-Bouldin values indicate denser, better separated clusters.
    higher_is_better: bool = False
    full_state_update: bool = False
    # The score is a mean of ratios of non-negative distances, so it cannot be below zero.
    plot_lower_bound: float = 0.0
    # List states holding every batch passed to ``update``.
    data: List[Tensor]
    labels: List[Tensor]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        # The score needs all samples at once, so raw batches are kept and
        # registered with the "cat" reduction for distributed syncing.
        self.add_state("data", default=[], dist_reduce_fx="cat")
        self.add_state("labels", default=[], dist_reduce_fx="cat")

    def update(self, data: Tensor, labels: Tensor) -> None:
        """Update metric state with new data and labels."""
        self.data.append(data)
        self.labels.append(labels)

    def compute(self) -> Tensor:
        """Compute the Davies Bouldin Score over all data and labels."""
        return davies_bouldin_score(dim_zero_cat(self.data), dim_zero_cat(self.labels))

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.clustering import DaviesBouldinScore
            >>> metric = DaviesBouldinScore()
            >>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.clustering import DaviesBouldinScore
            >>> metric = DaviesBouldinScore()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
0 commit comments