
ClasswiseWrapper & MetricCollection error with LightningModule.log_dict() #2091

@FrenchKrab


🐛 Bug

If I put a ClasswiseWrapper inside a MetricCollection and then log the collection with LightningModule.log_dict(mymetric), I get an error.
Is this expected behaviour, or am I doing something wrong? (And am I opening this issue in the right repository?)

To Reproduce

Here is code that reproduces the error:

from lightning import LightningModule, Trainer
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import ClasswiseWrapper
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
import torch

class DummyDataset(Dataset):
    def __init__(self, num_classes):
        # 100 random integer class labels
        self.y = torch.randint(0, num_classes, (100,))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"y": self.y[idx]}


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.classes = ["nothing", "something"]
        self.metrics = MetricCollection(
            {
                "acc": ClasswiseWrapper(
                    MulticlassAccuracy(num_classes=len(self.classes), average="none"),
                    labels=self.classes,
                )
            }
        )

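    def val_dataloader(self):
        # a single batch holding all 100 samples
        return DataLoader(DummyDataset(len(self.classes)), batch_size=1000)

    def validation_step(self, batch, batch_idx):
        # random class probabilities for the 100 samples, standing in for model outputs
        preds = F.softmax(torch.rand((100, len(self.classes))), dim=-1)
        target = batch["y"]
        self.metrics(preds, target)
        # logging the collection this way triggers the ValueError shown below
        self.log_dict(self.metrics, on_step=False, on_epoch=True)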

model = MyModel()
trainer = Trainer(max_epochs=1)
trainer.validate(model)

Here is the error (full stack trace in the attached file):

File .../lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:445, in _ResultCollection._get_cache(result_metric, on_step)

ValueError: The `.compute()` return of the metric logged as 'acc' must be a tensor. Found {'multiclassaccuracy_nothing': tensor(0.4000), 'multiclassaccuracy_something': tensor(0.4400)}

If more ClasswiseWrapper instances are added to the collection, only the metrics from the first one appear in the error message ("Found {...}").
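To make the failure mode concrete, here is a plain-Python sketch (no torch needed) under the assumption that Lightning checks each logged value in turn and stops at the first one that is not a tensor. This is not Lightning's actual code: check_loggable is a hypothetical name and plain numbers stand in for tensors. If the assumption holds, it would also explain why only the first wrapper's metrics show up in the message.

```python
from numbers import Number
from typing import Any, Mapping


def check_loggable(results: Mapping[str, Any]) -> None:
    """Hypothetical simplification of the per-entry check behind the error."""
    for name, value in results.items():
        # Lightning expects a tensor here; a plain number stands in for one.
        if not isinstance(value, Number):
            # Raising on the first bad entry means later wrappers never
            # appear in the message.
            raise ValueError(
                f"The `.compute()` return of the metric logged as {name!r} "
                f"must be a tensor. Found {value}"
            )
```

Feeding it a dict shaped like the collection output above ({"acc": {"multiclassaccuracy_nothing": 0.4, ...}}) raises a ValueError that names 'acc' and shows only that wrapper's per-class values.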

Expected behavior

All metrics generated by the wrapper should be logged as usual (e.g. "multiclassaccuracy_nothing", "multiclassaccuracy_something", etc.).
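The expected behaviour amounts to flattening the collection's nested output before logging. A minimal sketch, assuming the output is a dict whose wrapper entries are themselves dicts of per-class values (as in the error above); flatten_metrics is a hypothetical helper, not a torchmetrics or Lightning API, and plain floats stand in for tensors.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_metrics(results: Mapping) -> Dict[str, Any]:
    """Hypothetical helper: turn {"acc": {per-class values}} into flat names."""
    flat: Dict[str, Any] = {}
    for name, value in results.items():
        # A ClasswiseWrapper-style entry is itself a dict of per-class values;
        # its inner keys (e.g. "multiclassaccuracy_nothing") become the log names.
        items = value.items() if isinstance(value, Mapping) else [(name, value)]
        for key, scalar in items:
            if key in flat:
                # Two wrappers around the same metric class would produce the
                # same per-class names, so refuse to overwrite silently.
                raise ValueError(f"duplicate metric name {key!r}")
            flat[key] = scalar
    return flat
```

For example, flatten_metrics({"acc": {"multiclassaccuracy_nothing": 0.4, "multiclassaccuracy_something": 0.44}}) returns {"multiclassaccuracy_nothing": 0.4, "multiclassaccuracy_something": 0.44}. An untested workaround along these lines would be to call self.metrics.compute() in on_validation_epoch_end, flatten the result, pass it to self.log_dict(...), and then call self.metrics.reset().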

Environment

  • TorchMetrics fe5f46f, installed with pip as an editable project
  • Python 3.9.16, PyTorch v2.0.0, lightning v2.0.2
  • OS: Ubuntu 20.04

Full stacktrace

torchmetrics_classwisewrapper_error.txt