diff --git a/CHANGELOG.md b/CHANGELOG.md index b248b4e3d49..74ab706d6a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226)) +- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2211](https://github.com/Lightning-AI/torchmetrics/pull/2211)) + ## [1.2.0] - 2023-09-22 ### Added diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 3a5650b0cae..06b9b7b4c4e 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -304,6 +304,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None: # Determine if we just should set a reference or a full copy setattr(mi, state, deepcopy(m0_state) if copy else m0_state) mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count + mi._computed = deepcopy(m0._computed) if copy else m0._computed self._state_is_copy = copy def compute(self) -> Dict[str, Any]: diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index ce4bb1ba8c8..caea6091c28 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -411,7 +411,8 @@ class TestComputeGroups: ("prefix_", "_postfix"), ], ) - def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix): + @pytest.mark.parametrize("with_reset", [True, False]) + def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix, with_reset): """Check that compute groups are formed after initialization and that metrics are correctly computed.""" if isinstance(metrics, MetricCollection): prefix, postfix = None, None # disable for nested collections @@ -445,8 +446,9 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target for key in res_cg: assert torch.allclose(res_cg[key], res_without_cg[key]) - m.reset() - m2.reset() + if with_reset: + m.reset() + m2.reset() @pytest.mark.parametrize("method", ["items", "values", "keys"]) def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):