Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect caching of MetricCollection #2571

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2571](https://github.com/Lightning-AI/torchmetrics/pull/2571))


- Fixed class order of `panoptic_quality(..., return_per_class=True)` output ([#2548](https://github.com/Lightning-AI/torchmetrics/pull/2548))


Expand Down
11 changes: 10 additions & 1 deletion src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def __init__(

self.add_metrics(metrics, *additional_metrics)

@property
def metric_state(self) -> Dict[str, Dict[str, Any]]:
"""Get the current state of the metric."""
return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)}

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Call forward for each metric sequentially.
Expand All @@ -206,6 +211,11 @@ def update(self, *args: Any, **kwargs: Any) -> None:
"""
# Use compute groups if already initialized and checked
if self._groups_checked:
# Delete the cache of all metrics to invalidate the cache and therefore recent compute calls, forcing new
# compute calls to recompute
for k in self.keys(keep_base=True):
mi = getattr(self, str(k))
mi._computed = None
for cg in self._groups.values():
# only update the first member
m0 = getattr(self, cg[0])
Expand Down Expand Up @@ -304,7 +314,6 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
# Determine if we just should set a reference or a full copy
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
mi._computed = deepcopy(m0._computed) if copy else m0._computed
self._state_is_copy = copy

def compute(self) -> Dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
should_unsync=self._should_unsync,
):
value = _squeeze_if_scalar(compute(*args, **kwargs))
# clone tensor to avoid in-place operations after compute, altering already computed results
value = apply_to_collection(value, Tensor, lambda x: x.clone())

if self.compute_with_cache:
self._computed = value
Expand Down
19 changes: 19 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,29 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target
for key in res_cg:
assert torch.allclose(res_cg[key], res_without_cg[key])

# Check if second compute is the same
res_cg2 = m.compute()
for key in res_cg2:
assert torch.allclose(res_cg[key], res_cg2[key])

if with_reset:
m.reset()
m2.reset()

# Test if a second compute without a reset is the same
m.reset()
m.update(preds, target)
res_cg = m.compute()
# Simulate different preds by simply inversing them
m.update(1 - preds, target)
res_cg2 = m.compute()
# Now check if the results from the first compute are different from the second
for key in res_cg:
# A different shape is okay, therefore skip (this happens for multidim_average="samplewise")
if res_cg[key].shape != res_cg2[key].shape:
continue
assert not torch.all(res_cg[key] == res_cg2[key])

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
"""Check states are copied instead of passed by ref when a single metric in the collection is access."""
Expand Down
Loading