diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index f574320dc44..c968208faff 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -409,7 +409,7 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: elif reduce_fn == dim_zero_min: reduced = torch.min(global_state, local_state) elif reduce_fn == dim_zero_cat: - reduced = global_state + local_state + reduced = torch.cat([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, Tensor): reduced = torch.stack([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, list):