Description
🐛 Bug
The method `Metric._reduce_states(...)` is supposed to concatenate states registered with `dist_reduce_fx='cat'`, but instead it tries to add them.
To Reproduce
Try to update metric states initialized with `dist_reduce_fx='cat'`.
First, define the metric states in the `__init__` of a `Metric` subclass:
```python
import torch
from torchmetrics import Metric

# base class for metrics requiring an optimal threshold (like EER, BPCER@APCER1%)
class OptimalThresholdMetrics(Metric):
    # .............
    def __init__(self, iter_max=100, tolerance=0.01) -> None:
        super().__init__()
        # ................
        self.add_state("preds", default=torch.tensor([]), dist_reduce_fx='cat')
        self.add_state("targets", default=torch.tensor([]), dist_reduce_fx='cat')
```
And then, when trying to update the metric with

```python
metric_output = self.train_metrics[0](preds, targets)
```

I get a dimensionality error (I don't remember the exact message), because `Metric._reduce_states(...)` tries to add the states instead of concatenating them.
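For reference, here is a minimal self-contained sketch that should trigger the same error. The `CatStateMetric` class is hypothetical, and I'm assuming `full_state_update = False` so that `forward` takes the batch-state path that calls `_reduce_states`:

```python
import torch
from torchmetrics import Metric

class CatStateMetric(Metric):
    # hypothetical toy metric that only accumulates predictions
    full_state_update = False  # forward merges batch state via _reduce_states

    def __init__(self) -> None:
        super().__init__()
        self.add_state("preds", default=torch.tensor([]), dist_reduce_fx='cat')

    def update(self, preds: torch.Tensor) -> None:
        self.preds = torch.cat([self.preds, preds])

    def compute(self) -> torch.Tensor:
        return self.preds.mean()

metric = CatStateMetric()
# with the bug, this already raises a shape error: the cached global state
# (the empty default, shape (0,)) and the batch state (shape (3,)) are
# added element-wise instead of concatenated
out = metric(torch.ones(3))
```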
Expected behavior
When `dist_reduce_fx='cat'` is passed, metric states should be concatenated on update, not added.
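Using the hypothetical toy metric from the sketch above, the expected semantics can be written as a quick check:

```python
metric = CatStateMetric()
metric(torch.ones(3))
metric(torch.ones(5))
# states should grow by concatenation along dim 0: 3 + 5 = 8 elements
assert metric.preds.shape == (8,)
```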
Environment
I encountered the bug in torchmetrics v1.2.0 installed via pip, but it is also present on the current master branch.
Additional context
How to fix
In `src/torchmetrics/metric.py`, the method `Metric._reduce_states(...)` currently says

```python
elif reduce_fn == dim_zero_cat:
    reduced = global_state + local_state
```

Instead it should say something like

```python
elif reduce_fn == dim_zero_cat:
    reduced = torch.cat([global_state, local_state])
```
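One caveat (an assumption on my part, based on how cat states can also be registered): when `default=[]` is used, a cat state is a plain Python list of tensors, and `+` is then valid list concatenation. So the fix may need to special-case tensors, along these lines:

```python
elif reduce_fn == dim_zero_cat:
    if isinstance(global_state, torch.Tensor):
        # tensor states: "+" would add element-wise, so concatenate instead
        reduced = torch.cat([global_state, local_state])
    else:
        # list states (default=[]): "+" already concatenates correctly
        reduced = global_state + local_state
```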