Bug in "dim_zero_cat" reduction "Metric._reduce_states(...)" #2225

Closed
@jankng

Description

🐛 Bug

The function Metric._reduce_states(...) is supposed to concatenate states when told to use "cat", but instead it tries to add them.

To Reproduce

Try to update metric states that were initialized with dist_reduce_fx='cat'.

First, define the metric states in the __init__ method of a Metric subclass:

# base class for metrics requiring an optimal threshold (like EER, BPCER@APCER1%)
class OptimalThresholdMetrics(Metric):
    # .............

    def __init__(self, iter_max=100, tolerance=0.01) -> None:
        super().__init__()

        # ................

        self.add_state("preds", default=torch.tensor([]), dist_reduce_fx='cat')
        self.add_state("targets", default=torch.tensor([]), dist_reduce_fx='cat')

Then, when updating the metric with

metric_output = self.train_metrics[0](preds, targets)

I get a dimensionality/shape-mismatch error (I don't recall the exact message), because Metric._reduce_states(...) tries to add the states instead of concatenating them.
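To illustrate the failure mode, here is a dependency-free sketch (plain Python, no torch; elementwise_add and the list-based dim_zero_cat are toy stand-ins, not the library's actual implementations). Adding two 1-D states requires equal lengths, which growing preds/targets buffers generally do not have, while concatenation always succeeds:

```python
def elementwise_add(a, b):
    # Toy stand-in for tensor addition: shapes must match, as in torch.
    if len(a) != len(b):
        raise RuntimeError(f"size mismatch: {len(a)} vs {len(b)}")
    return [x + y for x, y in zip(a, b)]

def dim_zero_cat(states):
    # Toy stand-in for dim_zero_cat: join the states along dim 0.
    return [x for s in states for x in s]

global_state = [0.1, 0.9]        # e.g. preds accumulated so far
local_state = [0.4, 0.6, 0.8]    # a new batch of a different size

print(dim_zero_cat([global_state, local_state]))  # [0.1, 0.9, 0.4, 0.6, 0.8]

try:
    elementwise_add(global_state, local_state)    # what the bug effectively does
except RuntimeError as e:
    print("addition fails:", e)
```

This mirrors the reported behavior: the 'cat' reduction should never depend on the two states having the same length.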

Expected behavior

When passing dist_reduce_fx='cat', metric states should be concatenated on update, not added.

Environment

I encountered the bug in torchmetrics v1.2.0 installed via pip, but it is also present on the current master branch.

Additional context

How to fix

In src/torchmetrics/metric.py, the method Metric._reduce_states(...) currently reads

elif reduce_fn == dim_zero_cat:
    reduced = global_state + local_state

Instead, it should read something like

elif reduce_fn == dim_zero_cat:
    reduced = torch.cat([global_state, local_state])
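To show the intended dispatch behavior, here is a dependency-free sketch (plain Python; reduce_states and the list-based dim_zero_cat are hypothetical simplifications of Metric._reduce_states, not the library's code):

```python
def dim_zero_cat(states):
    # Toy stand-in for dim_zero_cat: concatenate states along dim 0.
    return [x for s in states for x in s]

def reduce_states(reduce_fn, global_state, local_state):
    # Simplified dispatch mirroring the proposed fix: 'cat' states are
    # concatenated; any other reduction is applied to the pair of states.
    if reduce_fn is dim_zero_cat:
        return dim_zero_cat([global_state, local_state])
    return reduce_fn(global_state, local_state)

merged = reduce_states(dim_zero_cat, [0.1, 0.9], [0.4, 0.6, 0.8])
print(merged)  # [0.1, 0.9, 0.4, 0.6, 0.8]
```

With this dispatch, states of different lengths merge cleanly, which is exactly what 'cat' states need.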
