
State Changes from List to Tensor After Calling Compute() #2005

Closed
donglihe-hub opened this issue Aug 17, 2023 · 4 comments · Fixed by #2061
Labels
bug / fix Something isn't working documentation Improvements or additions to documentation help wanted Extra attention is needed v0.10.x

donglihe-hub commented Aug 17, 2023

🐛 Bug

I was logging a custom Metric result using the logger provided by LightningModule. The code runs on 2 GPUs with DDP. When a list state whose dist_reduce_fx is "cat" was synchronized by calling compute(), I noticed that the behavior in training_step() and validation_step() differed.

In training_step(), where on_step=True and on_epoch=False (the default for training), the state was always a list after calling compute(). However, in validation_step(), where on_step=False and on_epoch=True (the default for validation), the state became a Tensor. This behavior is not explained in the docs.

To Reproduce

Metric and Logging Results

Code

import logging

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional.retrieval import retrieval_normalized_dcg


class MyMetric(Metric):
    def __init__(self, top_k):
        super().__init__()
        self.top_k = top_k
        self.add_state("state_1", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor):
        assert preds.shape == target.shape
        self.state_1 += [self._metric(p, t) for p, t in zip(preds, target)]

    def compute(self):
        if isinstance(self.state_1, list):
            logging.warning(f"self.state_1.size(): {len(self.state_1)}")
            return torch.stack(self.state_1).mean()
        logging.warning(f"self.state_1.size(): {self.state_1.size()}")
        return self.state_1.mean()

    def _metric(self, preds: Tensor, target: Tensor):
        return retrieval_normalized_dcg(preds, target, k=self.top_k).float()

Log Outputs
Training:
WARNING:self.state_1.size(): 40

Validation:
WARNING:self.state_1.size(): torch.Size([80])
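The 40-vs-80 discrepancy above can be mimicked without torch or DDP. A minimal sketch, where the hypothetical helper `cat_reduce` stands in for the "cat" reduction that torchmetrics applies to list states during synchronization:

```python
def cat_reduce(per_process_states):
    """Stand-in for dist_reduce_fx="cat": concatenate each process's
    list state into one flat sequence (analogous to a 1-D tensor)."""
    merged = []
    for state in per_process_states:
        merged.extend(state)
    return merged

# Two DDP processes, each holding 40 per-sample scores in a list state.
rank0_state = [0.5] * 40
rank1_state = [0.7] * 40

# on_step logging: compute() sees only the local, unsynced list state.
print(len(rank0_state))                             # 40

# on_epoch logging: sync concatenates both lists into one flat state.
print(len(cat_reduce([rank0_state, rank1_state])))  # 80
```

This matches the logs: 40 local list entries during training-step logging, and a flat state of 80 elements (40 per process, concatenated) after epoch-end synchronization.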

Expected behavior

The state should remain what it was defined as in add_state(), i.e., a list.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source):
    0.10.3 (conda-forge)
  • Python & PyTorch Version (e.g., 1.0):
    3.10.12
  • Any other relevant information such as OS (e.g., Linux):
    pytorch: 1.13.1
    pytorch-lightning: 1.9.4
    Linux: 5.15.0-78-generic


@donglihe-hub donglihe-hub added bug / fix Something isn't working help wanted Extra attention is needed labels Aug 17, 2023
@github-actions

Hi! Thanks for your contribution, great first issue!

@donglihe-hub donglihe-hub changed the title State Change from List to Tensor During Synchronization State Change from List to Tensor After Calling Compute() Aug 17, 2023
@donglihe-hub donglihe-hub changed the title State Change from List to Tensor After Calling Compute() State Changes from List to Tensor After Calling Compute() Aug 17, 2023
@SkafteNicki
Member

@donglihe-hub, thanks for raising this issue.
I definitely see the problem that you have pointed out; the core problem with changing the behaviour is backwards compatibility. This behaviour has been present since the start of torchmetrics, and internally we have not really considered it a problem. The solution in our own metrics is simply to call the dim_zero_cat function in compute:

from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
class MyMetric(Metric):
    ...
    def compute(self):
        state_1 = dim_zero_cat(self.state_1)  # the returned value will always be a tensor

which ensures that everything is a tensor before doing any computation.

cc @justusschock: opinions on this issue? Is it something we should change?
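The normalize-then-reduce pattern suggested above can be illustrated without torch. A minimal sketch, using a hypothetical `compute_mean` helper in place of a real compute() method:

```python
def compute_mean(state_1):
    """Sketch of the dim_zero_cat-then-reduce pattern: accept either the
    raw list state (before sync) or the flat sequence that a "cat"
    reduction produces (after sync), and reduce both the same way."""
    flat = state_1 if isinstance(state_1, list) else list(state_1)
    return sum(flat) / len(flat)

print(compute_mean([1.0, 2.0, 3.0]))       # before sync: list state -> 2.0
print(compute_mean((1.0, 2.0, 3.0, 4.0)))  # after sync: flat state -> 2.5
```

The point of the pattern is that compute() never branches on whether synchronization has already happened: the state is first coerced into one canonical form, and the reduction is written once against that form.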


donglihe-hub commented Aug 24, 2023


Thank you for your answer. I think this should be mentioned in IMPLEMENTING A METRIC? Otherwise users have no way of knowing it before they actually run their code and hit the "bug".

@Borda Borda added the v0.10.x label Aug 25, 2023
@SkafteNicki
Member

@donglihe-hub you are completely right.
That page has not been updated in a long time and definitely needs an overhaul. Let me try to figure out some more examples, and hopefully we can also fix the issue you reported in #2022.

@SkafteNicki SkafteNicki added the documentation Improvements or additions to documentation label Aug 26, 2023
@SkafteNicki SkafteNicki added this to the v1.2.0 milestone Aug 26, 2023
@SkafteNicki SkafteNicki self-assigned this Aug 26, 2023