Description
🐛 Bug
If one uses the `FrechetInceptionDistance` metric in the `training_step` of a `LightningModule`, one can observe an increase in memory consumption due to the `InceptionV3` backbone that is not freed afterwards. In the example below, it amounts to approx. 1.2 GB of memory (about 50% more), which does not get freed.
To Reproduce
Run the following code and monitor the memory consumption with `nvidia-smi` (not the most precise monitoring tool, but it gives a good general picture and is consistent with the CUDA OOM errors encountered; see the helper sketched after the script for a `torch.cuda`-based alternative).
```python
from torchmetrics.image import FrechetInceptionDistance
from torchmetrics import MetricCollection
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Lambda


class SimpleDataset(Dataset):
    def __init__(self, num_samples) -> None:
        super().__init__()
        self.x = torch.randint(0, 255, (num_samples, 3, 32, 32))
        self.y = torch.randint(0, 10, (num_samples,))
        self.transform = Lambda(lambda x: (x / 127.5) - 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.transform(self.x[idx]), self.y[idx]


class DummyModel(LightningModule):
    def __init__(self, sample_every_n_epoch: int):
        super().__init__()
        self.train_metrics = MetricCollection([FrechetInceptionDistance(feature=2048, normalize=True)])
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 128, 3, 1, 1),
            nn.GroupNorm(32, 128),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.SiLU(),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, 3, 1, 1),
        )
        self.criterion = nn.MSELoss()
        self._sample_every_n_epoch = sample_every_n_epoch

    def forward(self, x):
        return self.backbone(x)

    def should_sample(self) -> bool:
        return self.trainer is not None and (self.trainer.current_epoch + 1) % self._sample_every_n_epoch == 0 and self.trainer.is_last_batch

    def train_dataloader(self):
        dataset = SimpleDataset(num_samples=50000)
        return DataLoader(dataset, batch_size=256, num_workers=0, shuffle=True)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = self.criterion(output, x)
        if self.should_sample():
            # FID is only updated on the last batch of every n-th epoch
            self.train_metrics.update(x.add(1).mul_(0.5), real=True)
            self.sample_and_update_metric(x)
            self.log_dict(self.train_metrics, on_epoch=True)
        return loss

    def sample_and_update_metric(self, x):
        gen = torch.randn_like(x)
        self.train_metrics.update(gen.add(1).mul_(0.5), real=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(params=self.parameters(), lr=1e-4)
        return {"optimizer": optimizer}


if __name__ == "__main__":
    model = DummyModel(sample_every_n_epoch=2)
    trainer = Trainer(max_epochs=10, logger=False, accelerator="gpu")
    trainer.fit(model)
```
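As an alternative to `nvidia-smi`, a small helper along these lines (hypothetical, not part of the reproduction script above) can be called before and after the metric updates in `training_step` to see the allocation from PyTorch's point of view:

```python
# Hypothetical helper: report CUDA memory as seen by the PyTorch allocator, which is
# more fine-grained than nvidia-smi and excludes memory held by other processes.
import torch


def log_cuda_memory(tag: str) -> None:
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"[{tag}] allocated={allocated:.2f} GB, reserved={reserved:.2f} GB")


# e.g. call log_cuda_memory("before FID update") / log_cuda_memory("after FID update")
# around self.train_metrics.update(...) in training_step.
```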
Expected behavior
No/minimal temporary increase in memory consumption, which is freed as soon as `self.inception` has finished its forward pass: the model should be in `eval` mode, so no activations or intermediate results are saved; only the internal state of the `FrechetInceptionDistance` (`{fake,real}_features_cov_sum`, etc.) is updated. This internal state is, to my understanding, pre-allocated in `__init__` with `add_state` and should therefore already have its memory allocated.
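For illustration, a much-simplified sketch of the pattern I am assuming `FrechetInceptionDistance` follows (this is not the actual torchmetrics implementation):

```python
# Simplified sketch: fixed-size running sums are registered once in __init__ via
# add_state, and update() only accumulates into them in place, so no additional
# memory should be needed per update beyond these tensors.
import torch
from torchmetrics import Metric


class FeatureSumSketch(Metric):
    full_state_update = False

    def __init__(self, num_features: int = 2048) -> None:
        super().__init__()
        # analogous to real_features_sum / real_features_cov_sum / real_features_num_samples
        self.add_state("features_sum", torch.zeros(num_features), dist_reduce_fx="sum")
        self.add_state("features_cov_sum", torch.zeros(num_features, num_features), dist_reduce_fx="sum")
        self.add_state("num_samples", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, features: torch.Tensor) -> None:
        # features: (batch, num_features); accumulate in place, keep no per-sample data
        self.features_sum += features.sum(dim=0)
        self.features_cov_sum += features.t().mm(features)
        self.num_samples += features.shape[0]

    def compute(self) -> torch.Tensor:
        # placeholder: the real FID derives mean/covariance from such sums
        return self.features_sum / self.num_samples
```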
Environment
- TorchMetrics version: 1.0.1
- Python: 3.10
- PyTorch: 2.0.1 (with CUDA 11.7)
- pytorch-lightning: 2.0.5
- Any other relevant information such as OS (e.g., Linux): Ubuntu 22.04, conda 22.11.1, packages installed with pip
Additional context
I recently changed from sampling in the `validation_step` to sampling at the end of a training epoch, i.e. in the last `training_step`. Since then, I have observed the increase in memory consumption (a CUDA OOM error with the same `batch_size` that used to work).
Furthermore, using the `torch.no_grad()` decorator or context manager and/or manually setting the metric's `self.inception` to `eval()` did not change anything.
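For completeness, roughly what was tried (reconstructed here as standalone helpers, not verbatim from my training code):

```python
# Reconstruction of the attempted workarounds; neither released the extra memory.
import torch


def update_fid_without_grad(metric, real_images, fake_images):
    # variant 1: make sure no autograd graph can be built around the metric updates
    with torch.no_grad():
        metric.update(real_images, real=True)
        metric.update(fake_images, real=False)


def force_backbone_eval(fid_metric):
    # variant 2: explicitly put the metric's InceptionV3 backbone into eval mode
    fid_metric.inception.eval()
```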
The absolute weirdest part is that sometimes the memory gets consumed (typically after debugging for a while) and sometimes it does not.