Feature: merge_state method #2786

Merged: 33 commits from feature/merge_state into master on Oct 30, 2024.

Commits:
ec741fd
merge state method
SkafteNicki Oct 15, 2024
6eab70a
small change
SkafteNicki Oct 15, 2024
e71b041
refactor: rename merge_states method to merge_state and update relate…
SkafteNicki Oct 18, 2024
293841a
fix: set full_state_update to False in clustering metrics
SkafteNicki Oct 18, 2024
17880a3
tests
SkafteNicki Oct 18, 2024
9954207
changelog
SkafteNicki Oct 18, 2024
2724904
Merge branch 'master' into feature/merge_state
SkafteNicki Oct 18, 2024
858b599
Update src/torchmetrics/metric.py
SkafteNicki Oct 18, 2024
0556b92
Update src/torchmetrics/metric.py
SkafteNicki Oct 18, 2024
94cf072
Merge branch 'master' into feature/merge_state
Borda Oct 18, 2024
0a668f2
Merge branch 'master' into feature/merge_state
SkafteNicki Oct 19, 2024
22acbcc
Merge branch 'master' into feature/merge_state
Borda Oct 21, 2024
fef70c2
Merge branch 'master' into feature/merge_state
Borda Oct 22, 2024
55997dc
Merge branch 'master' into feature/merge_state
SkafteNicki Oct 22, 2024
a0f0e0b
try fixing the docs
SkafteNicki Oct 22, 2024
1edfece
try fixing docs
SkafteNicki Oct 22, 2024
0384984
try fixing docs
SkafteNicki Oct 22, 2024
79c959e
try fixing docs
SkafteNicki Oct 22, 2024
b763ea9
try fixing doc error
SkafteNicki Oct 22, 2024
91e372d
Merge branch 'master' into feature/merge_state
Borda Oct 23, 2024
9ce304b
try fixing docs
SkafteNicki Oct 24, 2024
4dfc3d5
Merge branch 'master' into feature/merge_state
SkafteNicki Oct 24, 2024
9b9178c
fixes
SkafteNicki Oct 24, 2024
253704b
Merge branch 'master' into feature/merge_state
Borda Oct 29, 2024
b8f3d59
Apply suggestions from code review
Borda Oct 29, 2024
b0dc3db
cleaning
Borda Oct 29, 2024
95c804c
lint
Borda Oct 29, 2024
37e1bd6
example
Borda Oct 29, 2024
bf3f787
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
889801a
Merge branch 'master' into feature/merge_state
Borda Oct 29, 2024
093b561
Merge branch 'master' into feature/merge_state
Borda Oct 29, 2024
f4bdd32
Merge branch 'master' into feature/merge_state
mergify[bot] Oct 30, 2024
5fefa58
Merge branch 'master' into feature/merge_state
Borda Oct 30, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))


+- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786))
+
+
 - Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/adjusted_rand_score.py
@@ -64,7 +64,7 @@ class AdjustedRandScore(Metric):

     is_differentiable = True
     higher_is_better = None
-    full_state_update: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = -0.5
     plot_upper_bound: float = 1.0
     preds: List[Tensor]
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/dunn_index.py
@@ -64,7 +64,7 @@ class DunnIndex(Metric):

     is_differentiable: bool = True
     higher_is_better: bool = True
-    full_state_update: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
     data: List[Tensor]
     labels: List[Tensor]
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/rand_score.py
@@ -63,7 +63,7 @@ class RandScore(Metric):

     is_differentiable = True
     higher_is_better = None
-    full_state_update: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
     preds: List[Tensor]
     target: List[Tensor]
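
The flipped flag above is both a fix and a prerequisite: these clustering metrics keep raw inputs in list states, so `forward()` can score a batch without the full accumulated state, and the new `merge_state` method added in `metric.py` below raises a `RuntimeError` when `full_state_update=True`. A minimal sketch of the resulting behavior (not part of the diff), assuming random cluster assignments as inputs:

import torch
from torchmetrics.clustering import RandScore

metric = RandScore()  # full_state_update is False after this PR

# With full_state_update=False, forward() scores each batch in isolation with
# a single update() call and then folds the batch into the accumulated state.
for _ in range(2):
    preds = torch.randint(10, (100,))
    target = torch.randint(10, (100,))
    batch_score = metric(preds, target)

# compute() still reflects every batch seen so far.
total_score = metric.compute()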
110 changes: 88 additions & 22 deletions src/torchmetrics/metric.py
@@ -52,31 +52,34 @@ class Metric(Module, ABC):
"""Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:
1. Handles the transfer of metric states to correct device
2. Handles the synchronization of metric states across processes

The three core methods of the base class are
* ``add_state()``
* ``forward()``
* ``reset()``

which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
* ``update()``
* ``compute()``
1. Handles the transfer of metric states to the correct device.
2. Handles the synchronization of metric states across processes.
3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are: ``add_state()``, ``forward()`` and ``reset()`` which should almost
never be overwritten by child classes. Instead, the following methods should be overwritten ``update()`` and
``compute()``.

Args:
kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: Function that performs the allgather option on the metric state. Default is an custom
implementation that calls ``torch.distributed.all_gather`` internally.
- distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a
check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
- sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
- compute_with_cache: If results from ``compute`` should be cached. Default is ``True``
- **compute_on_cpu**:
If metric state should be stored on CPU during computations. Only works for list states.
- **dist_sync_on_step**:
If metric state should synchronize on ``forward()``. Default is ``False``.
- **process_group**:
The process group on which the synchronization is called. Default is the world.
- **dist_sync_fn**:
Function that performs the allgather option on the metric state. Default is a custom
implementation that calls ``torch.distributed.all_gather`` internally.
- **distributed_available_fn**:
Function that checks if the distributed backend is available. Defaults to a
check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
- **sync_on_compute**:
If metric state should synchronize when ``compute`` is called. Default is ``True``.
- **compute_with_cache**:
If results from ``compute`` should be cached. Default is ``True``.

"""

@@ -222,7 +225,7 @@ def add_state(
             persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
                 Default is ``False``.

-        Note:
+        .. note::
             Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
             However, there won't be any reduction function applied to the synchronized metric state.
@@ -236,11 +239,11 @@ def add_state(
             - If the metric state is a ``list``, the synced value will be a ``list`` containing the
               combined elements from all processes.

-        Note:
+        .. note::
             When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
             the format discussed in the above note.

-        Note:
+        .. note::
             The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
             device memory to be automatically reallocated, but may produce unexpected effects when referencing list
             states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
@@ -398,6 +401,67 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:

         return batch_val

+    def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None:
+        """Merge incoming metric state to the current state of the metric.
+
+        Args:
+            incoming_state:
+                either a dict containing a metric state similar to the metric itself or an instance of the
+                metric class.
+
+        Raises:
+            ValueError:
+                If the incoming state is neither a dict nor an instance of the metric class.
+            RuntimeError:
+                If the metric has ``full_state_update=True`` or ``dist_sync_on_step=True``. In these cases, the metric
+                cannot be merged with another metric state in a simple way. The user should overwrite the method in the
+                metric class to handle the merge operation.
+            ValueError:
+                If the incoming state is a metric instance but the class is different from the current metric class.
+
+        Example with a metric instance:
+
+            >>> from torchmetrics.aggregation import SumMetric
+            >>> metric1 = SumMetric()
+            >>> metric2 = SumMetric()
+            >>> metric1.update(1)
+            >>> metric2.update(2)
+            >>> metric1.merge_state(metric2)
+            >>> metric1.compute()
+            tensor(3.)
+
+        Example with a dict:
+
+            >>> from torchmetrics.aggregation import SumMetric
+            >>> metric = SumMetric()
+            >>> metric.update(1)
+            >>> # SumMetric has one state variable called `sum_value`
+            >>> metric.merge_state({"sum_value": torch.tensor(2)})
+            >>> metric.compute()
+            tensor(3.)
+
+        """
+        if not isinstance(incoming_state, (dict, Metric)):
+            raise ValueError(
+                f"Expected incoming state to be a dict or an instance of Metric but got {type(incoming_state)}"
+            )
+
+        if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
+            raise RuntimeError(
+                "``merge_state`` is not supported for metrics with ``full_state_update=True`` or "
+                "``dist_sync_on_step=True``. Please overwrite the merge_state method in the metric class."
+            )
+
+        if isinstance(incoming_state, Metric):
+            this_class = self.__class__
+            if not isinstance(incoming_state, this_class):
+                raise ValueError(
+                    f"Expected incoming state to be an instance of {this_class.__name__} but got {type(incoming_state)}"
+                )
+            incoming_state = incoming_state.metric_state
+
+        self._reduce_states(incoming_state)
+
     def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
         """Add an incoming metric state to the current state of the metric.

@@ -407,6 +471,8 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
"""
for attr in self._defaults:
local_state = getattr(self, attr)
if attr not in incoming_state:
raise ValueError(f"Expected state variable {attr} to be present in incoming state {incoming_state}")
global_state = incoming_state[attr]
reduce_fn = self._reductions[attr]
if reduce_fn == dim_zero_sum:
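
The `RuntimeError` branch above tells users of metrics with `full_state_update=True` or `dist_sync_on_step=True` to overwrite `merge_state` in the subclass. A sketch of what such an override could look like (`RunningMax` below is a hypothetical metric for illustration, not part of this PR; assume it is declared with `full_state_update=True`, so the default guard would reject it):

from typing import Any, Dict, Union

import torch
from torchmetrics import Metric


class RunningMax(Metric):
    """Hypothetical metric tracking a running maximum over all updates."""

    full_state_update = True

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.add_state("max_value", default=torch.tensor(float("-inf")), dist_reduce_fx="max")

    def update(self, value: torch.Tensor) -> None:
        self.max_value = torch.maximum(self.max_value, value.max())

    def compute(self) -> torch.Tensor:
        return self.max_value

    def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None:
        # Custom merge: bypass the base-class guard and combine states directly,
        # since a maximum of shard maxima is a valid merge for this state.
        if isinstance(incoming_state, RunningMax):
            incoming_state = incoming_state.metric_state
        self.max_value = torch.maximum(self.max_value, incoming_state["max_value"])

For a state like this, merging commutes with updating: the maximum of two shard maxima equals the maximum over all data, no matter how the updates were split.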
78 changes: 77 additions & 1 deletion tests/unittests/bases/test_metric.py
@@ -24,8 +24,11 @@
 import torch
 from torch import Tensor, tensor
 from torch.nn import Module, Parameter
+from torchmetrics.aggregation import MeanMetric, SumMetric
 from torchmetrics.classification import BinaryAccuracy
-from torchmetrics.regression import PearsonCorrCoef
+from torchmetrics.clustering import AdjustedRandScore
+from torchmetrics.image import StructuralSimilarityIndexMeasure
+from torchmetrics.regression import PearsonCorrCoef, R2Score

 from unittests._helpers import seed_all
 from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -609,3 +612,76 @@ def test_dtype_property():
     assert metric.dtype == torch.float64  # should not change after initialization
     metric.set_dtype(torch.float32)
     assert metric.dtype == torch.float32
+
+
+def test_merge_state_feature_basic():
+    """Check the merge_state method works as expected for a basic metric."""
+    metric1 = SumMetric()
+    metric2 = SumMetric()
+    metric1.update(1)
+    metric2.update(2)
+    metric1.merge_state(metric2)
+    assert metric1.compute() == 3
+
+    metric = SumMetric()
+    metric.update(1)
+    metric.merge_state({"sum_value": torch.tensor(2)})
+    assert metric.compute() == 3
+
+
+def test_merge_state_feature_raises_errors():
+    """Check the merge_state method raises errors when expected."""
+
+    class TempMetric(SumMetric):
+        full_state_update = True
+
+    metric = TempMetric()
+    metric2 = SumMetric()
+    metric3 = MeanMetric()
+
+    with pytest.raises(ValueError, match="Expected incoming state to be a.*"):
+        metric.merge_state(2)
+
+    with pytest.raises(RuntimeError, match="``merge_state`` is not supported.*"):
+        metric.merge_state({"sum_value": torch.tensor(2)})
+
+    with pytest.raises(ValueError, match="Expected incoming state to be an.*"):
+        metric2.merge_state(metric3)
+
+
+@pytest.mark.parametrize(
+    ("metric_class", "preds", "target"),
+    [
+        (BinaryAccuracy, lambda: torch.randint(2, (100,)), lambda: torch.randint(2, (100,))),
+        (R2Score, lambda: torch.randn(100), lambda: torch.randn(100)),
+        (StructuralSimilarityIndexMeasure, lambda: torch.randn(1, 3, 25, 25), lambda: torch.randn(1, 3, 25, 25)),
+        (AdjustedRandScore, lambda: torch.randint(10, (100,)), lambda: torch.randint(10, (100,))),
+    ],
+)
+def test_merge_state_feature_for_different_metrics(metric_class, preds, target):
+    """Check the merge_state method works as expected for different metrics.
+
+    It should work such that the metric is the same as if it had seen the data twice, but in different ways.
+
+    """
+    metric1_1 = metric_class()
+    metric1_2 = metric_class()
+    metric2 = metric_class()
+
+    preds1, target1 = preds(), target()
+    preds2, target2 = preds(), target()
+
+    metric1_1.update(preds1, target1)
+    metric1_2.update(preds2, target2)
+    metric2.update(preds1, target1)
+    metric2.update(preds2, target2)
+    metric1_1.merge_state(metric1_2)
+
+    # should be the same because it has seen the same data twice, but in different ways
+    res1 = metric1_1.compute()
+    res2 = metric2.compute()
+    assert torch.allclose(res1, res2)
+
+    # should not be the same because it has only seen half the data
+    res3 = metric1_2.compute()
+    assert not torch.allclose(res3, res2)
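
Beyond these tests, a workflow the new method enables is map-reduce style evaluation: build metric states over separate data shards (for example in separate jobs), then fold them together before a single `compute()`. A sketch (not part of the diff), using `BinaryAccuracy` and random data as stand-ins:

import torch
from torchmetrics.classification import BinaryAccuracy

# Pretend each shard was scored by a different worker or batch job.
shards = [(torch.randint(2, (100,)), torch.randint(2, (100,))) for _ in range(4)]

per_shard = []
for preds, target in shards:
    metric = BinaryAccuracy()
    metric.update(preds, target)
    per_shard.append(metric)

# Fold every shard's state into the first metric and compute once; the result
# matches feeding all four shards to a single metric instance.
merged = per_shard[0]
for other in per_shard[1:]:
    merged.merge_state(other)
print(merged.compute())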