Bug in "dim_zero_cat" reduction "Metric._reduce_states(...)" #2225

Closed
jankng opened this issue Nov 20, 2023 · 2 comments · Fixed by #2226
Labels
bug / fix Something isn't working help wanted Extra attention is needed v1.2.x

Comments

jankng commented Nov 20, 2023

🐛 Bug

The method Metric._reduce_states(...) is supposed to concatenate states whose reduction function is "cat", but instead it tries to add them.

To Reproduce

Try to update metric states that were initialized with dist_reduce_fx='cat'.

First, define the metric states in the __init__ method of the Metric subclass:

```python
import torch
from torchmetrics import Metric

# base class for metrics requiring an optimal threshold (like EER, BPCER@APCER1%)
class OptimalThresholdMetrics(Metric):
    # .............

    def __init__(self, iter_max=100, tolerance=0.01) -> None:
        super().__init__()

        # ................

        self.add_state("preds", default=torch.tensor([]), dist_reduce_fx='cat')
        self.add_state("targets", default=torch.tensor([]), dist_reduce_fx='cat')
```

Then update the metric by calling it:

```python
metric_output = self.train_metrics[0](preds, targets)
```

This raises a dimension-mismatch error (I don't remember the exact message) because Metric._reduce_states(...) tries to add the states instead of concatenating them.
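The root cause can be seen with plain PyTorch, without torchmetrics. On tensors, + is element-wise addition with broadcasting, so a length-3 state and a length-2 batch raise a RuntimeError, whereas torch.cat joins them along dim 0. On Python lists, + is list concatenation, which is presumably why the existing branch works for list-valued states. The tensor values below are made up for illustration.

```python
import torch

global_state = torch.tensor([0.1, 0.9, 0.4])  # state accumulated so far
local_state = torch.tensor([0.7, 0.2])        # state from the current batch

# Element-wise addition cannot combine lengths 3 and 2.
try:
    global_state + local_state
except RuntimeError as err:
    print("addition failed:", err)

# Concatenation along dim 0 is what a "cat" reduction needs.
merged = torch.cat([global_state, local_state])
print(merged.shape)  # torch.Size([5])

# For list states, "+" is list concatenation, so it already behaves like "cat".
print(len([global_state] + [local_state]))  # 2
```

Note that when both tensors happen to have the same length, + does not fail at all: it silently sums the two states element-wise, corrupting the accumulated values instead of raising an error.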

Expected behavior

When dist_reduce_fx='cat' is passed, metric states should be concatenated on update, not added.

Environment

I encountered the bug in torchmetrics v1.2.0 installed via pip, but it is also present in the current master branch.

Additional context

How to fix

In src/torchmetrics/metric.py, the method Metric._reduce_states(...) currently reads:

```python
elif reduce_fn == dim_zero_cat:
    reduced = global_state + local_state
```

Instead, it should read something like:

```python
elif reduce_fn == dim_zero_cat:
    reduced = torch.cat([global_state, local_state])
```
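One caveat with this suggestion: "cat" states can also be declared with a list default (default=[]), and for those torch.cat([global_state, local_state]) would receive two lists rather than tensors and fail, while the existing + is correct. A sketch of a branch that handles both kinds of state is below; merge_cat_states is a hypothetical standalone helper, not part of the torchmetrics API, and this is not necessarily how the fix in #2226 is written.

```python
from typing import List, Union

import torch
from torch import Tensor

State = Union[Tensor, List[Tensor]]


def merge_cat_states(global_state: State, local_state: State) -> State:
    """Hypothetical helper: merge two "cat" states of the same kind."""
    if isinstance(global_state, Tensor) and isinstance(local_state, Tensor):
        # Tensor states: join along dim 0 instead of adding element-wise.
        return torch.cat([global_state, local_state])
    if isinstance(global_state, list) and isinstance(local_state, list):
        # List states: "+" is already list concatenation.
        return global_state + local_state
    raise TypeError(
        f"Cannot merge states of types {type(global_state).__name__} "
        f"and {type(local_state).__name__}"
    )


print(merge_cat_states(torch.tensor([1.0, 2.0]), torch.tensor([3.0])))  # tensor([1., 2., 3.])
print(len(merge_cat_states([torch.tensor([1.0])], [torch.tensor([2.0, 3.0])])))  # 2
```

Dispatching on the state type keeps the list behaviour unchanged while fixing the tensor case; mixed list/tensor inputs are rejected explicitly rather than guessed at.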
jankng added the bug / fix and help wanted labels Nov 20, 2023

Hi! Thanks for your contribution, great first issue!

jankng added a commit to jankng/torchmetrics that referenced this issue Nov 20, 2023
Fixes Lightning-AI#2225 by changing "reduced = global_state + local_state" to "reduced = torch.cat([global_state, local_state]
@jankng jankng mentioned this issue Nov 20, 2023
jankng commented Nov 20, 2023

I created a pull request with my fix. I'm not sure whether I followed the contribution guidelines correctly, though (I opened the pull request against the master branch; I hope that's okay).
