From 960c95145843a9c232add836306c89a5697a0d1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20K=C3=B6nig?= Date: Mon, 20 Nov 2023 10:47:55 +0100 Subject: [PATCH] Fix "dim_zero_cat" reduction #2225 Fixes #2225 by changing "reduced = global_state + local_state" to "reduced = torch.cat([global_state, local_state] --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index f574320dc44..c968208faff 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -409,7 +409,7 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: elif reduce_fn == dim_zero_min: reduced = torch.min(global_state, local_state) elif reduce_fn == dim_zero_cat: - reduced = global_state + local_state + reduced = torch.cat([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, Tensor): reduced = torch.stack([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, list):