
Memory Leak in FrechetInceptionDistance if used in training_step #1959

Open
@nicoloesch

Description

🐛 Bug

If one uses the FrechetInceptionDistance in the training_step of a LightningModule, one can observe an increase in memory consumption caused by the backbone InceptionV3, which is not freed afterwards. In the example below, it amounts to approximately 1.2 GB of additional memory (roughly 50% more), which never gets freed.
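For context, the metric itself registers the full InceptionV3 backbone as a submodule (exposed as .inception, which the workarounds under "Additional context" refer to), so its weights move to the GPU together with the LightningModule. A quick sketch to see the backbone's static footprint (the parameter count is computed at runtime, nothing else is assumed):

import torch
from torchmetrics.image import FrechetInceptionDistance

# Sketch only: inspect the InceptionV3 backbone that the metric carries as a
# submodule. Its parameters follow the LightningModule onto the GPU.
fid = FrechetInceptionDistance(feature=2048, normalize=True)
n_params = sum(p.numel() for p in fid.inception.parameters())
print(f"Backbone parameters: {n_params / 1e6:.1f} M "
      f"(~{n_params * 4 / 1e6:.0f} MB in float32)")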

To Reproduce

Run the following code and monitor the memory consumption with nvidia-smi (not the most precise monitoring tool, but it gives a good general direction and is consistent with the CUDA OOM errors encountered; a programmatic alternative is sketched after the script).

from torchmetrics.image import FrechetInceptionDistance
from torchmetrics import MetricCollection
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda

class SimpleDataset(Dataset):
    def __init__(self, num_samples) -> None:
        super().__init__()

        self.x = torch.randint(0, 255, (num_samples, 3, 32, 32))
        self.y = torch.randint(0,10, (num_samples,))

        self.transform = Lambda(lambda x: (x / 127.5) - 1)  # scale uint8 values to [-1, 1]

    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.transform(self.x[idx]), self.y[idx]


class DummyModel(LightningModule):
    def __init__(self, sample_every_n_epoch: int):
        super().__init__()
        self.train_metrics = MetricCollection([FrechetInceptionDistance(feature=2048, normalize=True)])

        self.backbone = nn.Sequential(
            nn.Conv2d(3, 128, 3, 1, 1),
            nn.GroupNorm(32, 128),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.SiLU(),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, 3, 1, 1))
        
        self.criterion = nn.MSELoss()

        self._sample_every_n_epoch = sample_every_n_epoch

    def forward(self, x):
        return self.backbone(x)
    
    def should_sample(self) -> bool:
        return self.trainer is not None and (self.trainer.current_epoch + 1) % self._sample_every_n_epoch == 0 and self.trainer.is_last_batch
    
    def train_dataloader(self):
        dataset = SimpleDataset(num_samples=50000)
        return DataLoader(dataset, batch_size=256, num_workers=0, shuffle=True)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = self.criterion(output, x)
        # On the last batch of every n-th epoch, update the FID metric with real
        # and (here randomly generated) fake images -- this is where the memory
        # consumption grows and is never released.
        if self.should_sample():
            self.train_metrics.update(x.add(1).mul_(0.5), real=True)  # rescale [-1, 1] -> [0, 1]
            self.sample_and_update_metric(x)
            self.log_dict(self.train_metrics, on_epoch=True)
        return loss
    
    def sample_and_update_metric(self, x):
        gen = torch.randn_like(x)
        self.train_metrics.update(gen.add(1).mul_(0.5), real=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(params=self.parameters(), lr=1e-4)
        return {"optimizer": optimizer}

if __name__ == "__main__":
    model = DummyModel(sample_every_n_epoch=2)
    trainer = Trainer(max_epochs=10, logger=False, accelerator="gpu")
    trainer.fit(model)
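As an alternative to nvidia-smi, the allocation can also be tracked from inside the script via PyTorch's own allocator statistics, e.g. with a small callback. A sketch assuming pytorch-lightning >= 2.0; MemoryReport is a made-up helper name, not part of the repro:

import torch
from pytorch_lightning import Callback

class MemoryReport(Callback):
    """Hypothetical helper: print CUDA allocator statistics after every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated() / 1e9
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"epoch {trainer.current_epoch}: allocated={allocated:.2f} GB, peak={peak:.2f} GB")

# usage: Trainer(max_epochs=10, logger=False, accelerator="gpu", callbacks=[MemoryReport()])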

Expected behavior

No or only a minimal, temporary increase in memory consumption, which is freed as soon as self.inception has finished its forward pass. The model should be in eval mode, so no activations or other intermediate results should be retained; the only thing that changes is the internal state of the FrechetInceptionDistance ({fake,real}_features_cov_sum, etc.). This internal state is, to my understanding, pre-allocated in __init__ with add_state and should therefore already have its memory allocated.
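For reference, those pre-allocated states can be inspected right after constructing the metric. A sketch using the attribute names as I read them from the torchmetrics 1.0 source (the cov_sum names appear above; real_features_sum and real_features_num_samples are my reading and may differ):

import torch
from torchmetrics.image import FrechetInceptionDistance

# Sketch: print the shape and dtype of the pre-allocated internal states.
fid = FrechetInceptionDistance(feature=2048, normalize=True)
for name in ("real_features_sum", "real_features_cov_sum", "real_features_num_samples"):
    state = getattr(fid, name)
    print(name, tuple(state.shape), state.dtype)
# The largest states are the two 2048 x 2048 covariance sums in float64,
# i.e. about 32 MB each -- far from the ~1.2 GB growth observed above.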

Environment

  • TorchMetrics version: 1.0.1
  • Python: 3.10
  • PyTorch: 2.0.1 (with CUDA 11.7)
  • pytorch-lightning: 2.0.5
  • Any other relevant information such as OS (e.g., Linux): Ubuntu 22.04, conda 22.11.1, packages installed with pip

Additional context

I recently changed from sampling in the validation_step to sampling at the end of a training epoch in the last training_step. Since then, I have observed the increase in memory consumption (a CUDA OOM error with the same batch_size that previously worked).
Furthermore, using the torch.no_grad() decorator or context manager and/or manually setting the metric's self.inception to eval() did not change anything.
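For reference, the attempt looked roughly like the following. This is a reconstruction adapted from the repro above, not verbatim code; the MetricCollection key is assumed to default to the metric's class name:

# Drop-in variant of DummyModel.sample_and_update_metric from the repro above.
@torch.no_grad()  # the same was also tried as a `with torch.no_grad():` block
def sample_and_update_metric(self, x):
    # force the backbone into eval mode (the metric should already do this internally)
    self.train_metrics["FrechetInceptionDistance"].inception.eval()
    gen = torch.randn_like(x)
    self.train_metrics.update(gen.add(1).mul_(0.5), real=False)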

The strangest part is that sometimes the extra memory gets consumed (typically after debugging for a while) and sometimes it does not.
