Skip to content

Commit

Permalink
Fix dim_zero_cat reduction (#2226)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
(cherry picked from commit 9652899)
  • Loading branch information
jankng authored and Borda committed Dec 1, 2023
1 parent b76911d commit c665da1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234))


- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226))


## [1.2.0] - 2023-09-22

### Added
Expand Down
5 changes: 4 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,10 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
elif reduce_fn == dim_zero_min:
reduced = torch.min(global_state, local_state)
elif reduce_fn == dim_zero_cat:
reduced = global_state + local_state
if isinstance(global_state, Tensor):
reduced = torch.cat([global_state, local_state])
else:
reduced = global_state + local_state
elif reduce_fn is None and isinstance(global_state, Tensor):
reduced = torch.stack([global_state, local_state])
elif reduce_fn is None and isinstance(global_state, list):
Expand Down

0 comments on commit c665da1

Please sign in to comment.