GPU out of memory when using SupervisedTrainer in a loop #3423

Closed
@tom-osika

Description

Describe the bug
Constructing and running a SupervisedTrainer in a loop eventually leads to GPU out of memory. See the example below.

To Reproduce

import segmentation_models_pytorch as smp
import torch
from torch import optim, nn
from monai.engines import SupervisedTrainer
from monai.data import DataLoader, ArrayDataset
import gc


NETWORK_INPUT_SHAPE = (1, 256, 256)
NUM_IMAGES = 50

def get_xy():
    xs = [256 * torch.rand(NETWORK_INPUT_SHAPE) for _ in range(NUM_IMAGES)]
    ys = [torch.rand(NETWORK_INPUT_SHAPE) for _ in range(NUM_IMAGES)]
    return xs, ys


def get_data_loader():
    x, y = get_xy()
    dataset = ArrayDataset(x, seg=y)
    loader = DataLoader(dataset, batch_size=16)
    return loader


def get_model():
    return smp.Unet(
        encoder_weights="imagenet", in_channels=1, classes=2, activation=None
    )

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = get_data_loader()
    model = get_model()

    # Re-create the optimizer and trainer on each iteration, as in a k-fold cross-validation setup.
    for i in range(50):
        print(f"On iteration {i}")

        model.to(device)
        optimizer = optim.Adam(model.parameters())

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=10,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=nn.CrossEntropyLoss(),
            # Move the image to the device and convert the segmentation to integer class labels.
            prepare_batch=lambda batchdata, device, non_blocking: (
                batchdata[0].to(device),
                batchdata[1].squeeze(1).to(device, dtype=torch.long),
            ),
        )

        trainer.run()
        # gc.collect()  # uncommenting this prevents the OOM (see below)

Around the 4th iteration or so, I get RuntimeError: CUDA out of memory. If this doesn't happen when you try the example, increase the NUM_IMAGES variable or the number of loop iterations. I know there are a few common causes of out-of-memory issues in PyTorch, outlined here, but I can't find where I'm doing any of those things. I've tried calling del trainer, and I've tried moving the initialization of the model inside the loop and deleting it afterwards. Calling gc.collect() does work, which makes me think some kind of circular reference is holding up garbage collection. I'm not convinced this isn't user error, though.
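
For reference, a minimal sketch of the workaround, applied to the end of the loop body above. Only gc.collect() is confirmed to make a difference; the del statements and torch.cuda.empty_cache() are extra precautions I haven't isolated.

        trainer.run()

        # Workaround: force a full garbage collection so the (apparently
        # cyclically referenced) trainer and its GPU tensors are released
        # before the next iteration constructs a new one.
        del trainer
        del optimizer
        gc.collect()
        torch.cuda.empty_cache()  # optional: returns cached blocks to the driver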

Environment
Ubuntu 18.04, Python 3.8

Ensuring you use the relevant python executable, please paste the output of:

python -c 'import monai; monai.config.print_debug_info()'

================================
Printing MONAI config...

MONAI version: 0.8.0
Numpy version: 1.21.2
Pytorch version: 1.10.0+cu102
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 714d00d

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.18.3
Pillow version: 8.4.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.10.1+cu102
tqdm version: 4.62.3
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: 1.3.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...

psutil required for print_system_info

================================
Printing GPU config...

Num GPUs: 1
Has CUDA: True
CUDA version: 10.2
cuDNN enabled: True
cuDNN version: 7605
Current device: 0
Library compiled for CUDA architectures: ['sm_37', 'sm_50', 'sm_60', 'sm_70']
GPU 0 Name: Quadro T2000
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 16
GPU 0 Total memory (GB): 3.8
GPU 0 CUDA capability (maj.min): 7.5

Additional context
This loop was originally part of a k-fold cross-validation setup; a rough sketch is below.
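
The cross-validation loop looked roughly like the sketch below. sklearn.model_selection.KFold and the 5-fold split are illustrative assumptions rather than the exact project code; it reuses get_xy, get_model, device, and the prepare_batch lambda from the repro script above.

from sklearn.model_selection import KFold

x, y = get_xy()
kf = KFold(n_splits=5)

for fold, (train_idx, _val_idx) in enumerate(kf.split(x)):
    print(f"Fold {fold}")

    # Build a fresh training loader from this fold's indices.
    train_ds = ArrayDataset([x[i] for i in train_idx], seg=[y[i] for i in train_idx])
    train_loader = DataLoader(train_ds, batch_size=16)

    model = get_model().to(device)
    optimizer = optim.Adam(model.parameters())

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=10,
        train_data_loader=train_loader,
        network=model,
        optimizer=optimizer,
        loss_function=nn.CrossEntropyLoss(),
        prepare_batch=lambda batchdata, device, non_blocking: (
            batchdata[0].to(device),
            batchdata[1].squeeze(1).to(device, dtype=torch.long),
        ),
    )
    trainer.run()

    # Without an explicit collection here, GPU memory grows fold over fold.
    del trainer, optimizer, model
    gc.collect()
    torch.cuda.empty_cache()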
