Lightning-AI/pytorch-lightning #21507 (Closed)
Labels: bug / fix, help wanted, v1.2.x
Description
🐛 Bug
If I use a ClasswiseWrapper inside a MetricCollection and try to log the collection with LightningModule.log_dict(mymetric), I get an error.
Is this expected behaviour, or am I doing something wrong? (And am I opening this issue in the right repository?)
To Reproduce
Here is code that reproduces the error:
```python
from lightning import LightningModule, Trainer
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import ClasswiseWrapper
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
import torch


class DummyDataset(Dataset):
    def __init__(self, num_classes):
        # Integer class labels: MulticlassAccuracy expects an int/long target,
        # and torch.Tensor(long_tensor) would fail on the dtype mismatch.
        self.y = torch.randint(0, num_classes, (100,))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"y": self.y[idx]}


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.classes = ["nothing", "something"]
        # ClasswiseWrapper turns the per-class accuracy tensor into a dict
        # keyed by class label.
        self.metrics = MetricCollection(
            {
                "acc": ClasswiseWrapper(
                    MulticlassAccuracy(num_classes=len(self.classes), average="none"),
                    labels=self.classes,
                )
            }
        )

    def val_dataloader(self):
        # batch_size exceeds the dataset size, so all 100 samples form one batch.
        return DataLoader(DummyDataset(len(self.classes)), batch_size=1000)

    def validation_step(self, batch, batch_idx):
        # Random class probabilities for the 100 samples in the batch.
        preds = F.softmax(torch.rand((100, len(self.classes))), dim=-1)
        target = batch["y"]
        self.metrics(preds, target)
        # Logging the collection this way leads to the ValueError shown below.
        self.log_dict(self.metrics, on_step=False, on_epoch=True)


model = MyModel()
trainer = Trainer(max_epochs=1)
trainer.validate(model)
```

Here is the error (full stacktrace in attached file):
```
File .../lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:445, in _ResultCollection._get_cache(result_metric, on_step)
ValueError: The `.compute()` return of the metric logged as 'acc' must be a tensor. Found {'multiclassaccuracy_nothing': tensor(0.4000), 'multiclassaccuracy_something': tensor(0.4400)}
```
If more ClasswiseWrapper instances are added, only the metrics from the first one appear in the error message ("Found {...}").
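Only the first wrapper shows up because, as the error message suggests, each logged entry's `.compute()` result is checked in turn and the first non-tensor result raises immediately. The stdlib-only sketch below is a hypothetical, simplified model of that behaviour, not Lightning's actual code: `check_loggable` is an illustrative name, and plain floats stand in for scalar tensors.

```python
from collections.abc import Mapping
from typing import Any


def check_loggable(computed: Mapping[str, Any]) -> None:
    """Hypothetical, simplified model of the logger's per-entry check.

    Plain floats stand in for scalar tensors. Entries are checked in
    insertion order, so the first dict-valued entry raises and any later
    dict-valued entries are never reported.
    """
    for name, value in computed.items():
        if isinstance(value, Mapping):
            raise ValueError(
                f"The `.compute()` return of the metric logged as {name!r} "
                f"must be a tensor. Found {dict(value)}"
            )


# Two wrapped metrics whose compute() results are dicts (illustrative values).
computed = {
    "acc": {"multiclassaccuracy_nothing": 0.4, "multiclassaccuracy_something": 0.44},
    "prec": {"multiclassprecision_nothing": 0.5, "multiclassprecision_something": 0.3},
}

try:
    check_loggable(computed)
except ValueError as err:
    print(err)  # mentions only 'acc'; the 'prec' entry is never reached
```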
Expected behavior
All metrics generated by the wrapper should be logged as usual (e.g. "multiclassaccuracy_nothing", "multiclassaccuracy_something", etc.).
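One possible workaround, sketched here under assumptions rather than taken from the Lightning or TorchMetrics docs, is to compute the collection yourself and pass `log_dict` a flat `{name: scalar}` dict. In recent TorchMetrics releases, `MetricCollection.compute()` should already merge a `ClasswiseWrapper`'s per-class dict into its result. By default the keys are built from the lowercased class name of the wrapped metric, an underscore, and the label, which gives names like `multiclassaccuracy_nothing`. The hypothetical helper `flatten_metrics` makes that flattening explicit. Plain floats stand in for tensors so the sketch runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_metrics(computed: Mapping[str, Any]) -> Dict[str, Any]:
    """Flatten {collection_key: scalar_or_dict} into {metric_name: scalar}.

    Hypothetical helper: a dict-valued entry (such as the per-class output of a
    ClasswiseWrapper) contributes its own keys, while a scalar entry keeps its
    collection key. Already-flat input comes back as an equal dict.
    """
    flat: Dict[str, Any] = {}
    for key, value in computed.items():
        items = value.items() if isinstance(value, Mapping) else [(key, value)]
        for name, scalar in items:
            # Refuse to silently overwrite a metric that has the same name.
            if name in flat:
                raise KeyError(f"duplicate metric name: {name!r}")
            flat[name] = scalar
    return flat


# Illustrative values: one ClasswiseWrapper entry plus one plain scalar metric.
computed = {
    "acc": {"multiclassaccuracy_nothing": 0.4, "multiclassaccuracy_something": 0.44},
    "macro_acc": 0.42,
}
print(flatten_metrics(computed))
# {'multiclassaccuracy_nothing': 0.4, 'multiclassaccuracy_something': 0.44, 'macro_acc': 0.42}
```

In the reproduction, this would mean removing the `self.log_dict(self.metrics, ...)` call from `validation_step`. You would then add an `on_validation_epoch_end` hook that calls `self.log_dict(flatten_metrics(self.metrics.compute()))` followed by `self.metrics.reset()`. Because the helper leaves already-flat input as is, it is harmless if the installed TorchMetrics version flattens the collection itself.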
Environment
- TorchMetrics fe5f46f, installed with pip as an editable project
- Python 3.9.16, PyTorch v2.0.0, lightning v2.0.2
- OS: Ubuntu 20.04