diff --git a/CHANGELOG.md b/CHANGELOG.md
index 196076fdbe7..8c6efbc16ce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))
 
+- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786))
+
+
 - Added a new audio metric `NISQA` ([#2792](https://github.com/PyTorchLightning/metrics/pull/2792))
 
 
@@ -25,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))
 
+
 ### Deprecated
 
 - Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py
index 9df67a8d802..5c1f5f49276 100644
--- a/src/torchmetrics/clustering/adjusted_rand_score.py
+++ b/src/torchmetrics/clustering/adjusted_rand_score.py
@@ -64,7 +64,7 @@ class AdjustedRandScore(Metric):
 
     is_differentiable = True
     higher_is_better = None
-    full_state_update: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = -0.5
     plot_upper_bound: float = 1.0
     preds: List[Tensor]
diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py
index 9373db1045e..5a85074443d 100644
--- a/src/torchmetrics/clustering/dunn_index.py
+++ b/src/torchmetrics/clustering/dunn_index.py
@@ -64,7 +64,7 @@ class DunnIndex(Metric):
 
     is_differentiable: bool = True
     higher_is_better: bool = True
-    full_state_update: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
     data: List[Tensor]
     labels: List[Tensor]
diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py
index 84f244a1bbe..8ded8b27d0d 100644
--- a/src/torchmetrics/clustering/rand_score.py
+++ b/src/torchmetrics/clustering/rand_score.py
@@ -63,7 +63,7 @@ class RandScore(Metric):
 
     is_differentiable = True
     higher_is_better = None
-    full_state_update: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
     preds: List[Tensor]
     target: List[Tensor]
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 940e393c6d1..dd912aaa28e 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -52,31 +52,34 @@ class Metric(Module, ABC):
     """Base class for all metrics present in the Metrics API.
 
     This class is inherited by all metrics and implements the following functionality:
-    1. Handles the transfer of metric states to correct device
-    2. Handles the synchronization of metric states across processes
-    The three core methods of the base class are
-    * ``add_state()``
-    * ``forward()``
-    * ``reset()``
-
-    which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
-    * ``update()``
-    * ``compute()``
+    1. Handles the transfer of metric states to the correct device.
+    2. Handles the synchronization of metric states across processes.
+    3. Provides properties and methods to control the overall behavior of the metric and its states.
+    The three core methods of the base class are: ``add_state()``, ``forward()`` and ``reset()``, which should almost
+    never be overwritten by child classes.
+    Instead, the following methods should be overwritten: ``update()`` and ``compute()``.
 
     Args:
         kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
-            - compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
-            - dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
-            - process_group: The process group on which the synchronization is called. Default is the world.
-            - dist_sync_fn: Function that performs the allgather option on the metric state. Default is an custom
-              implementation that calls ``torch.distributed.all_gather`` internally.
-            - distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a
-              check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
-            - sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
-            - compute_with_cache: If results from ``compute`` should be cached. Default is ``True``
+            - **compute_on_cpu**:
+                If metric state should be stored on CPU during computations. Only works for list states.
+            - **dist_sync_on_step**:
+                If metric state should synchronize on ``forward()``. Default is ``False``.
+            - **process_group**:
+                The process group on which the synchronization is called. Default is the world.
+            - **dist_sync_fn**:
+                Function that performs the allgather operation on the metric state. Default is a custom
+                implementation that calls ``torch.distributed.all_gather`` internally.
+            - **distributed_available_fn**:
+                Function that checks if the distributed backend is available. Defaults to a
+                check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
+            - **sync_on_compute**:
+                If metric state should synchronize when ``compute`` is called. Default is ``True``.
+            - **compute_with_cache**:
+                If results from ``compute`` should be cached. Default is ``True``.
 
     """
@@ -222,7 +225,7 @@ def add_state(
             persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
                 Default is ``False``.
 
-        Note:
+        .. note::
            Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
            However, there won't be any reduction function applied to the synchronized metric state.
 
@@ -236,11 +239,11 @@ def add_state(
            - If the metric state is a ``list``, the synced value will be a ``list`` containing the combined
              elements from all processes.
 
-        Note:
+        .. note::
            When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
            the format discussed in the above note.
 
-        Note:
+        .. note::
            The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
            device memory to be automatically reallocated, but may produce unexpected effects when referencing list
            states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
@@ -398,6 +401,67 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
 
         return batch_val
 
+    def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None:
+        """Merge incoming metric state into the current state of the metric.
+
+        Args:
+            incoming_state:
+                either a dict containing a metric state similar to the metric itself, or an instance of the
+                metric class.
+
+        Raises:
+            ValueError:
+                If the incoming state is neither a dict nor an instance of the metric class.
+            RuntimeError:
+                If the metric has ``full_state_update=True`` or ``dist_sync_on_step=True``. In these cases, the metric
+                cannot be merged with another metric state in a simple way. The user should overwrite the method in the
+                metric class to handle the merge operation.
+            ValueError:
+                If the incoming state is a metric instance but the class is different from the current metric class.
+
+        Example with a metric instance:
+
+            >>> from torchmetrics.aggregation import SumMetric
+            >>> metric1 = SumMetric()
+            >>> metric2 = SumMetric()
+            >>> metric1.update(1)
+            >>> metric2.update(2)
+            >>> metric1.merge_state(metric2)
+            >>> metric1.compute()
+            tensor(3.)
+
+        Example with a dict:
+
+            >>> from torchmetrics.aggregation import SumMetric
+            >>> metric = SumMetric()
+            >>> metric.update(1)
+            >>> # SumMetric has one state variable called `sum_value`
+            >>> metric.merge_state({"sum_value": torch.tensor(2)})
+            >>> metric.compute()
+            tensor(3.)
+
+        """
+        if not isinstance(incoming_state, (dict, Metric)):
+            raise ValueError(
+                f"Expected incoming state to be a dict or an instance of Metric but got {type(incoming_state)}"
+            )
+
+        if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
+            raise RuntimeError(
+                "``merge_state`` is not supported for metrics with ``full_state_update=True`` or "
+                "``dist_sync_on_step=True``. Please overwrite the merge_state method in the metric class."
+            )
+
+        if isinstance(incoming_state, Metric):
+            this_class = self.__class__
+            if not isinstance(incoming_state, this_class):
+                raise ValueError(
+                    f"Expected incoming state to be an instance of {this_class.__name__} but got {type(incoming_state)}"
+                )
+            incoming_state = incoming_state.metric_state
+
+        self._reduce_states(incoming_state)
+
     def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
         """Add an incoming metric state to the current state of the metric.
 
@@ -407,6 +471,8 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
         """
         for attr in self._defaults:
             local_state = getattr(self, attr)
+            if attr not in incoming_state:
+                raise ValueError(f"Expected state variable {attr} to be present in incoming state {incoming_state}")
             global_state = incoming_state[attr]
             reduce_fn = self._reductions[attr]
             if reduce_fn == dim_zero_sum:
diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py
index 753150478e4..363b2d31a66 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -24,8 +24,11 @@
 import torch
 from torch import Tensor, tensor
 from torch.nn import Module, Parameter
+from torchmetrics.aggregation import MeanMetric, SumMetric
 from torchmetrics.classification import BinaryAccuracy
-from torchmetrics.regression import PearsonCorrCoef
+from torchmetrics.clustering import AdjustedRandScore
+from torchmetrics.image import StructuralSimilarityIndexMeasure
+from torchmetrics.regression import PearsonCorrCoef, R2Score
 
 from unittests._helpers import seed_all
 from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -609,3 +612,76 @@ def test_dtype_property():
     assert metric.dtype == torch.float64  # should not change after initialization
     metric.set_dtype(torch.float32)
     assert metric.dtype == torch.float32
+
+
+def test_merge_state_feature_basic():
+    """Check the merge_state method works as expected for a basic metric."""
+    metric1 = SumMetric()
+    metric2 = SumMetric()
+    metric1.update(1)
+    metric2.update(2)
+    metric1.merge_state(metric2)
+    assert metric1.compute() == 3
+
+    metric = SumMetric()
+    metric.update(1)
+    metric.merge_state({"sum_value": torch.tensor(2)})
+    assert metric.compute() == 3
+
+
+def test_merge_state_feature_raises_errors():
+    """Check the merge_state method raises errors when expected."""
+
+    class TempMetric(SumMetric):
+        full_state_update = True
+
+    metric = TempMetric()
+    metric2 = SumMetric()
+    metric3 = MeanMetric()
+
+    with pytest.raises(ValueError, match="Expected incoming state to be a.*"):
+        metric.merge_state(2)
+
+    with pytest.raises(RuntimeError, match="``merge_state`` is not supported.*"):
+        metric.merge_state({"sum_value": torch.tensor(2)})
+
+    with pytest.raises(ValueError, match="Expected incoming state to be an.*"):
+        metric2.merge_state(metric3)
+
+
+@pytest.mark.parametrize(
+    ("metric_class", "preds", "target"),
+    [
+        (BinaryAccuracy, lambda: torch.randint(2, (100,)), lambda: torch.randint(2, (100,))),
+        (R2Score, lambda: torch.randn(100), lambda: torch.randn(100)),
+        (StructuralSimilarityIndexMeasure, lambda: torch.randn(1, 3, 25, 25), lambda: torch.randn(1, 3, 25, 25)),
+        (AdjustedRandScore, lambda: torch.randint(10, (100,)), lambda: torch.randint(10, (100,))),
+    ],
+)
+def test_merge_state_feature_for_different_metrics(metric_class, preds, target):
+    """Check the merge_state method works as expected for different metrics.
+
+    Merging two metrics that each saw half the data should match a single metric that saw all of the data.
+
+    """
+    metric1_1 = metric_class()
+    metric1_2 = metric_class()
+    metric2 = metric_class()
+
+    preds1, target1 = preds(), target()
+    preds2, target2 = preds(), target()
+
+    metric1_1.update(preds1, target1)
+    metric1_2.update(preds2, target2)
+    metric2.update(preds1, target1)
+    metric2.update(preds2, target2)
+    metric1_1.merge_state(metric1_2)
+
+    # the merged metric and the metric updated on all the data should give the same result
+    res1 = metric1_1.compute()
+    res2 = metric2.compute()
+    assert torch.allclose(res1, res2)
+
+    # the unmerged metric has only seen half the data, so its result should differ
+    res3 = metric1_2.compute()
+    assert not torch.allclose(res3, res2)
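For reference, a minimal usage sketch (not part of the patch itself) of the `merge_state` API introduced above: two metrics are updated independently, for example on two shards of a dataset, and their states are then combined. The shard split and tensor shapes are made up for illustration; the example only relies on `BinaryAccuracy` (which has `full_state_update=False`) and the `update`/`compute`/`merge_state` calls exercised in the tests above.

```python
import torch

from torchmetrics.classification import BinaryAccuracy

# Two metrics updated independently, e.g. on two shards of a dataset
# (shard split and tensor shapes are hypothetical, for illustration only).
metric_shard_a = BinaryAccuracy()
metric_shard_b = BinaryAccuracy()

preds_a, target_a = torch.rand(50), torch.randint(2, (50,))
preds_b, target_b = torch.rand(50), torch.randint(2, (50,))

metric_shard_a.update(preds_a, target_a)
metric_shard_b.update(preds_b, target_b)

# Fold shard B's state into shard A; a plain state dict would also be accepted.
metric_shard_a.merge_state(metric_shard_b)

# The merged result matches a single metric that saw all of the data.
reference = BinaryAccuracy()
reference.update(preds_a, target_a)
reference.update(preds_b, target_b)
assert torch.allclose(metric_shard_a.compute(), reference.compute())
```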