State Changes from List to Tensor After Calling Compute() #2005
Comments
Hi! Thanks for your contribution! Great first issue!
@donglihe-hub, thanks for raising this issue.

from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat

class MyMetric(Metric):
    ...

    def compute(self):
        # the returned value will always be a tensor, which ensures that
        # everything is a tensor before doing any computation
        state_1 = dim_zero_cat(self.state_1)

cc: @justusschock, opinions on this issue? Is it something we should change?
Thank you for your answer. I think it should perhaps be mentioned in IMPLEMENTING A METRIC? Otherwise, users have no way of knowing this before they actually run their code and hit the "bug".
@donglihe-hub, you are completely right.
🐛 Bug
I was logging a custom Metric result using the logger provided by LightningModule. The code runs on 2 GPUs with DDP. While a list state whose dist_reduce_fx is "cat" was synchronized by calling compute(), I noticed that the behavior in training_step() and validation_step() was different.
In training_step(), where on_step=True and on_epoch=False (the default setting for training), the state was always a list after calling compute(). However, in validation_step(), where on_step=False and on_epoch=True (the default setting for validation), the state became a Tensor. This behavior is not explained in the docs.
To Reproduce
Metric and Logging Results
Code
Log Outputs
Training:
WARNING:self.state_1.size(): 40
Validation:
WARNING:self.state_1.size(): torch.Size([80])
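The gap between the two log lines can be illustrated with a toy, pure-Python model of a "cat" reduction. This is a simplified sketch, not torchmetrics code: `zero_cat` and `toy_cat_sync` are hypothetical names, plain lists stand in for tensors, and two "ranks" stand in for the two GPUs. It roughly follows what a DDP sync does to a list state with dist_reduce_fx="cat": each process concatenates its own list, the results are gathered across processes, and the gathered pieces are concatenated into one flat state. The numbers are chosen to mirror the 40 vs. 80 in the log outputs.

```python
from typing import List, Sequence, Union

# Hypothetical stand-ins: a "chunk" is the data from one update() call; an
# unsynchronized list state holds many chunks, a synced state is one flat list.
Chunk = List[float]
State = Union[List[Chunk], Chunk]


def zero_cat(state: State) -> Chunk:
    """Toy analogue of dim_zero_cat: flatten a list of chunks, copy a flat state."""
    if state and isinstance(state[0], list):
        return [value for chunk in state for value in chunk]
    return list(state)


def toy_cat_sync(per_rank_states: Sequence[List[Chunk]]) -> Chunk:
    """Toy "cat" sync: concatenate on each rank, gather all ranks, concatenate again."""
    gathered = [zero_cat(state) for state in per_rank_states]
    return zero_cat(gathered)


# Two ranks, each holding 4 updates of 10 values (40 values per rank).
rank_state = [[1.0] * 10 for _ in range(4)]
synced = toy_cat_sync([rank_state, rank_state])
print(len(rank_state), len(zero_cat(rank_state)))  # → 4 40
print(len(synced))  # → 80
```

A state that has not been synchronized keeps its list-of-chunks form, while a synchronized one is a single flat collection; normalizing with dim_zero_cat inside compute(), as suggested in the comments, makes compute() indifferent to which form it receives.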
Expected behavior
The state always keeps the type it was defined with in add_state(), i.e., a list.
Environment
TorchMetrics (installed via conda, pip, or build from source): 0.10.3 (conda-forge)
Python: 3.10.12
pytorch: 1.13.1
pytorch-lightning: 1.9.4
Linux: 5.15.0-78-generic
Additional context