diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45d79362517..34d6cf354e0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2571](https://github.com/Lightning-AI/torchmetrics/pull/2571))
+
+
 - Fixed class order of `panoptic_quality(..., return_per_class=True)` output ([#2548](https://github.com/Lightning-AI/torchmetrics/pull/2548))
 
 
diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index e4b0dbafd2a..fbcf6a6ac51 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -187,6 +187,11 @@ def __init__(
 
         self.add_metrics(metrics, *additional_metrics)
 
+    @property
+    def metric_state(self) -> Dict[str, Dict[str, Any]]:
+        """Get the current state of each metric in the collection."""
+        return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)}
+
     @torch.jit.unused
     def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """Call forward for each metric sequentially.
@@ -206,6 +211,11 @@ def update(self, *args: Any, **kwargs: Any) -> None:
         """
         # Use compute groups if already initialized and checked
        if self._groups_checked:
+            # Invalidate the cached result of every metric so that stale values from earlier `compute` calls
+            # are discarded, forcing subsequent `compute` calls to recompute from the updated state
+            for k in self.keys(keep_base=True):
+                mi = getattr(self, str(k))
+                mi._computed = None
             for cg in self._groups.values():
                 # only update the first member
                 m0 = getattr(self, cg[0])
@@ -304,7 +314,6 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
                     # Determine if we just should set a reference or a full copy
                     setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
                 mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
-                mi._computed = deepcopy(m0._computed) if copy else m0._computed
         self._state_is_copy = copy
 
     def compute(self) -> Dict[str, Any]:
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index a210995eeaf..3b4dfcdc9ed 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -630,6 +630,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
                 should_unsync=self._should_unsync,
             ):
                 value = _squeeze_if_scalar(compute(*args, **kwargs))
+                # clone tensors so that later in-place operations do not alter the already computed results
+                value = apply_to_collection(value, Tensor, lambda x: x.clone())
 
                 if self.compute_with_cache:
                     self._computed = value
diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py
index a677c92ddb1..55062ccbe29 100644
--- a/tests/unittests/bases/test_collections.py
+++ b/tests/unittests/bases/test_collections.py
@@ -452,10 +452,29 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target
         for key in res_cg:
             assert torch.allclose(res_cg[key], res_without_cg[key])
 
+        # Check that a second compute without new updates returns the same result
+        res_cg2 = m.compute()
+        for key in res_cg2:
+            assert torch.allclose(res_cg[key], res_cg2[key])
+
         if with_reset:
             m.reset()
             m2.reset()
 
+        # Check that a second compute after additional updates (without a reset) yields different results
+        m.reset()
+        m.update(preds, target)
+        res_cg = m.compute()
+        # Simulate different preds by simply inverting them
+        m.update(1 - preds, target)
+        res_cg2 = m.compute()
+        # The results of the first compute should now differ from those of the second
+        for key in res_cg:
+            # A different shape is okay, therefore skip (this happens for multidim_average="samplewise")
+            if res_cg[key].shape != res_cg2[key].shape:
+                continue
+            assert not torch.all(res_cg[key] == res_cg2[key])
+
     @pytest.mark.parametrize("method", ["items", "values", "keys"])
     def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
         """Check states are copied instead of passed by ref when a single metric in the collection is accessed."""
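For context, a minimal repro sketch of the stale-cache behavior this PR fixes (not part of the diff; the metric choice and values are illustrative, assuming two classification metrics that share a compute group):

# Minimal sketch (not part of the PR): before this fix, the second `compute`
# below could return the cached result of the first one, because the shared
# `_computed` cache of compute-group members was never invalidated on `update`.
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score

collection = MetricCollection([BinaryAccuracy(), BinaryF1Score()])  # compute groups are enabled by default

preds = torch.tensor([1, 0, 1, 1])
target = torch.tensor([1, 0, 0, 1])

collection.update(preds, target)
first = collection.compute()

collection.update(1 - preds, target)  # feed inverted preds, mirroring the new test
second = collection.compute()

# With the fix, the second compute reflects the new updates
assert any(not torch.equal(first[k], second[k]) for k in first)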
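The new `metric_state` property lets you inspect the raw states the compute groups share; a hypothetical peek, continuing the snippet above (the exact state keys depend on the metrics used):

collection = MetricCollection([BinaryAccuracy(), BinaryF1Score()])
collection.update(preds, target)
print(collection.metric_state)
# e.g. {"BinaryAccuracy": {"tp": tensor(2), "fp": tensor(1), "tn": tensor(1), "fn": tensor(0)},
#       "BinaryF1Score": {...}}

The `clone` added in `Metric.compute` complements the cache invalidation: because compute-group members share state tensors by reference, a computed result that aliases that state could otherwise be silently mutated by later in-place `update` calls after the caller has already received it.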