Feature: merge_state method (#2786)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Oct 30, 2024
1 parent 76c502b commit 70dabf1
Showing 6 changed files with 172 additions and 26 deletions.
CHANGELOG.md (4 additions, 0 deletions)
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))


- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786))


- Added a new audio metric `NISQA` ([#2792](https://github.com/PyTorchLightning/metrics/pull/2792))


@@ -25,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Deprecated

- Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
src/torchmetrics/clustering/adjusted_rand_score.py (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ class AdjustedRandScore(Metric):

is_differentiable = True
higher_is_better = None
-full_state_update: bool = True
+full_state_update: bool = False
plot_lower_bound: float = -0.5
plot_upper_bound: float = 1.0
preds: List[Tensor]
src/torchmetrics/clustering/dunn_index.py (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ class DunnIndex(Metric):

is_differentiable: bool = True
higher_is_better: bool = True
-full_state_update: bool = True
+full_state_update: bool = False
plot_lower_bound: float = 0.0
data: List[Tensor]
labels: List[Tensor]
src/torchmetrics/clustering/rand_score.py (1 addition, 1 deletion)
@@ -63,7 +63,7 @@ class RandScore(Metric):

is_differentiable = True
higher_is_better = None
-full_state_update: bool = True
+full_state_update: bool = False
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
src/torchmetrics/metric.py (88 additions, 22 deletions)
@@ -52,31 +52,34 @@ class Metric(Module, ABC):
"""Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
-1. Handles the transfer of metric states to correct device
-2. Handles the synchronization of metric states across processes
-The three core methods of the base class are
-* ``add_state()``
-* ``forward()``
-* ``reset()``
-which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
-* ``update()``
-* ``compute()``
+1. Handles the transfer of metric states to the correct device.
+2. Handles the synchronization of metric states across processes.
+3. Provides properties and methods to control the overall behavior of the metric and its states.
+The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
+never be overwritten by child classes. Instead, the following methods should be overwritten: ``update()`` and
+``compute()``.
Args:
kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.
-- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
-- dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
-- process_group: The process group on which the synchronization is called. Default is the world.
-- dist_sync_fn: Function that performs the allgather option on the metric state. Default is an custom
-  implementation that calls ``torch.distributed.all_gather`` internally.
-- distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a
-  check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
-- sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
-- compute_with_cache: If results from ``compute`` should be cached. Default is ``True``
+- **compute_on_cpu**:
+  If metric state should be stored on CPU during computations. Only works for list states.
+- **dist_sync_on_step**:
+  If metric state should synchronize on ``forward()``. Default is ``False``.
+- **process_group**:
+  The process group on which the synchronization is called. Default is the world.
+- **dist_sync_fn**:
+  Function that performs the allgather operation on the metric state. Default is a custom
+  implementation that calls ``torch.distributed.all_gather`` internally.
+- **distributed_available_fn**:
+  Function that checks if the distributed backend is available. Defaults to a
+  check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
+- **sync_on_compute**:
+  If metric state should synchronize when ``compute`` is called. Default is ``True``.
+- **compute_with_cache**:
+  If results from ``compute`` should be cached. Default is ``True``.
"""

@@ -222,7 +225,7 @@ def add_state(
persistent (Optional): whether the state will be saved as part of the module's ``state_dict``.
Default is ``False``.
-Note:
+.. note::
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
However, there won't be any reduction function applied to the synchronized metric state.
@@ -236,11 +239,11 @@
- If the metric state is a ``list``, the synced value will be a ``list`` containing the
combined elements from all processes.
-Note:
+.. note::
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
the format discussed in the above note.
-Note:
+.. note::
The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
device memory to be automatically reallocated, but may produce unexpected effects when referencing list
states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
@@ -398,6 +401,67 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:

        return batch_val
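As a concrete illustration of the list-state note in ``add_state`` above: values appended to a list state are dropped on ``reset()``, so they must be copied out first if they are needed afterwards. ``CollectPreds`` is a hypothetical metric, not part of this commit:

import torch
from torchmetrics import Metric

class CollectPreds(Metric):
    # Hypothetical metric with a list state, like the clustering metrics touched in this commit.
    full_state_update: bool = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("preds", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor) -> None:
        self.preds.append(preds)

    def compute(self) -> torch.Tensor:
        return torch.cat(self.preds)

metric = CollectPreds()
metric.update(torch.tensor([1.0, 2.0]))
kept = [p.clone() for p in metric.preds]  # copy list values before reset
metric.reset()  # clears the list state; `kept` still holds the tensors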

    def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None:
        """Merge incoming metric state to the current state of the metric.

        Args:
            incoming_state: either a dict containing a metric state similar to the metric itself or an instance of
                the metric class.

        Raises:
            ValueError:
                If the incoming state is neither a dict nor an instance of the metric class.
            RuntimeError:
                If the metric has ``full_state_update=True`` or ``dist_sync_on_step=True``. In these cases, the
                metric cannot be merged with another metric state in a simple way. The user should overwrite the
                method in the metric class to handle the merge operation.
            ValueError:
                If the incoming state is a metric instance but the class is different from the current metric class.

        Example with a metric instance:
            >>> from torchmetrics.aggregation import SumMetric
            >>> metric1 = SumMetric()
            >>> metric2 = SumMetric()
            >>> metric1.update(1)
            >>> metric2.update(2)
            >>> metric1.merge_state(metric2)
            >>> metric1.compute()
            tensor(3.)

        Example with a dict:
            >>> import torch
            >>> from torchmetrics.aggregation import SumMetric
            >>> metric = SumMetric()
            >>> metric.update(1)
            >>> # SumMetric has one state variable called `sum_value`
            >>> metric.merge_state({"sum_value": torch.tensor(2)})
            >>> metric.compute()
            tensor(3.)

        """
        if not isinstance(incoming_state, (dict, Metric)):
            raise ValueError(
                f"Expected incoming state to be a dict or an instance of Metric but got {type(incoming_state)}"
            )

        if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
            raise RuntimeError(
                "``merge_state`` is not supported for metrics with ``full_state_update=True`` or "
                "``dist_sync_on_step=True``. Please overwrite the merge_state method in the metric class."
            )

        if isinstance(incoming_state, Metric):
            this_class = self.__class__
            if not isinstance(incoming_state, this_class):
                raise ValueError(
                    f"Expected incoming state to be an instance of {this_class.__name__} but got {type(incoming_state)}"
                )
            incoming_state = incoming_state.metric_state

        self._reduce_states(incoming_state)
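Beyond the doctests above, the intended use of ``merge_state`` is map-reduce style evaluation: each worker updates its own metric instance on a shard of the data, and the states are folded together at the end. A sketch with in-process shards standing in for separate workers:

import torch
from torchmetrics.aggregation import SumMetric

# Hypothetical shards; in practice each would live in its own process or job.
shards = [torch.arange(5), torch.arange(5, 10)]

shard_metrics = []
for shard in shards:
    m = SumMetric()
    m.update(shard)
    shard_metrics.append(m)

# Reduce: fold every shard state into the first metric, then compute once.
total = shard_metrics[0]
for other in shard_metrics[1:]:
    total.merge_state(other)
assert total.compute() == torch.arange(10).sum()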

    def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
        """Add an incoming metric state to the current state of the metric.
@@ -407,6 +471,8 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
"""
for attr in self._defaults:
local_state = getattr(self, attr)
if attr not in incoming_state:
raise ValueError(f"Expected state variable {attr} to be present in incoming state {incoming_state}")
global_state = incoming_state[attr]
reduce_fn = self._reductions[attr]
if reduce_fn == dim_zero_sum:
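The hunk cuts off at the first branch, but the shape of ``_reduce_states`` is visible: each registered state is paired with its reduction, and merging dispatches on that reduction. A rough sketch of the idea (assumed branches, not the actual torchmetrics code):

import torch

def reduce_pair(local, incoming, kind: str):
    # Assumed dispatch: sum-reduced states add, list states concatenate,
    # min/max states compare elementwise.
    if kind == "sum":
        return local + incoming
    if kind == "cat":
        return local + incoming  # both are Python lists of tensors
    if kind == "max":
        return torch.max(local, incoming)
    if kind == "min":
        return torch.min(local, incoming)
    raise ValueError(f"Unsupported reduction: {kind}")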
tests/unittests/bases/test_metric.py (77 additions, 1 deletion)
@@ -24,8 +24,11 @@
import torch
from torch import Tensor, tensor
from torch.nn import Module, Parameter
+from torchmetrics.aggregation import MeanMetric, SumMetric
from torchmetrics.classification import BinaryAccuracy
-from torchmetrics.regression import PearsonCorrCoef
+from torchmetrics.clustering import AdjustedRandScore
+from torchmetrics.image import StructuralSimilarityIndexMeasure
+from torchmetrics.regression import PearsonCorrCoef, R2Score

from unittests._helpers import seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -609,3 +612,76 @@ def test_dtype_property():
    assert metric.dtype == torch.float64  # should not change after initialization
    metric.set_dtype(torch.float32)
    assert metric.dtype == torch.float32


def test_merge_state_feature_basic():
    """Check the merge_state method works as expected for a basic metric."""
    metric1 = SumMetric()
    metric2 = SumMetric()
    metric1.update(1)
    metric2.update(2)
    metric1.merge_state(metric2)
    assert metric1.compute() == 3

    metric = SumMetric()
    metric.update(1)
    metric.merge_state({"sum_value": torch.tensor(2)})
    assert metric.compute() == 3


def test_merge_state_feature_raises_errors():
    """Check the merge_state method raises errors when expected."""

    class TempMetric(SumMetric):
        full_state_update = True

    metric = TempMetric()
    metric2 = SumMetric()
    metric3 = MeanMetric()

    with pytest.raises(ValueError, match="Expected incoming state to be a.*"):
        metric.merge_state(2)

    with pytest.raises(RuntimeError, match="``merge_state`` is not supported.*"):
        metric.merge_state({"sum_value": torch.tensor(2)})

    with pytest.raises(ValueError, match="Expected incoming state to be an.*"):
        metric2.merge_state(metric3)


@pytest.mark.parametrize(
    ("metric_class", "preds", "target"),
    [
        (BinaryAccuracy, lambda: torch.randint(2, (100,)), lambda: torch.randint(2, (100,))),
        (R2Score, lambda: torch.randn(100), lambda: torch.randn(100)),
        (StructuralSimilarityIndexMeasure, lambda: torch.randn(1, 3, 25, 25), lambda: torch.randn(1, 3, 25, 25)),
        (AdjustedRandScore, lambda: torch.randint(10, (100,)), lambda: torch.randint(10, (100,))),
    ],
)
def test_merge_state_feature_for_different_metrics(metric_class, preds, target):
    """Check the merge_state method works as expected for different metrics.

    Two metrics that each saw half the data and were then merged should match a single metric that saw all the data.
    """
    metric1_1 = metric_class()
    metric1_2 = metric_class()
    metric2 = metric_class()

    preds1, target1 = preds(), target()
    preds2, target2 = preds(), target()

    metric1_1.update(preds1, target1)
    metric1_2.update(preds2, target2)
    metric2.update(preds1, target1)
    metric2.update(preds2, target2)
    metric1_1.merge_state(metric1_2)

    # same result: the merged metric has seen the same data, just split across two instances
    res1 = metric1_1.compute()
    res2 = metric2.compute()
    assert torch.allclose(res1, res2)

    # different result: this metric has only seen half the data
    res3 = metric1_2.compute()
    assert not torch.allclose(res3, res2)
