Skip to content

Commit

Permalink
Fix MetricCollection with repeated compute calls (#2211)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
(cherry picked from commit 58ffb01)
  • Loading branch information
SkafteNicki authored and Borda committed Nov 30, 2023
1 parent 12c2344 commit 52183c3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226))


- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2211](https://github.com/Lightning-AI/torchmetrics/pull/2211))

## [1.2.0] - 2023-09-22

### Added
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
# Determine if we just should set a reference or a full copy
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
mi._computed = deepcopy(m0._computed) if copy else m0._computed
self._state_is_copy = copy

def compute(self) -> Dict[str, Any]:
Expand Down
8 changes: 5 additions & 3 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ class TestComputeGroups:
("prefix_", "_postfix"),
],
)
def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix):
@pytest.mark.parametrize("with_reset", [True, False])
def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix, with_reset):
"""Check that compute groups are formed after initialization and that metrics are correctly computed."""
if isinstance(metrics, MetricCollection):
prefix, postfix = None, None # disable for nested collections
Expand Down Expand Up @@ -445,8 +446,9 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target
for key in res_cg:
assert torch.allclose(res_cg[key], res_without_cg[key])

m.reset()
m2.reset()
if with_reset:
m.reset()
m2.reset()

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
Expand Down

0 comments on commit 52183c3

Please sign in to comment.