Skip to content

Commit

Permalink
Fix "dim_zero_cat" reduction Lightning-AI#2225
Browse files Browse the repository at this point in the history
Fixes Lightning-AI#2225 by changing "reduced = global_state + local_state" to "reduced = torch.cat([global_state, local_state]
  • Loading branch information
jankng authored Nov 20, 2023
1 parent 894de4c commit 960c951
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
elif reduce_fn == dim_zero_min:
reduced = torch.min(global_state, local_state)
elif reduce_fn == dim_zero_cat:
reduced = global_state + local_state
reduced = torch.cat([global_state, local_state])
elif reduce_fn is None and isinstance(global_state, Tensor):
reduced = torch.stack([global_state, local_state])
elif reduce_fn is None and isinstance(global_state, list):
Expand Down

0 comments on commit 960c951

Please sign in to comment.