Skip to content

Commit

Permalink
Adding new parameter for num_processes of val_dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Mar 29, 2024
1 parent c7f85b7 commit 8df9840
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 79 deletions.
4 changes: 0 additions & 4 deletions nnunetv2/configuration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import os

from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA

default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])

ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low
# resolution axis must be 3x as large as the next largest spacing)

default_n_proc_DA = get_allowed_n_proc_DA()
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme
from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
Expand Down Expand Up @@ -100,14 +99,10 @@ def static_estimate_VRAM_usage(patch_size: Tuple[int],
"""
Works for PlainConvUNet, ResidualEncoderUNet
"""
a = torch.get_num_threads()
torch.set_num_threads(get_allowed_n_proc_DA())
# print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}')
net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels,
output_channels,
allow_init=False)
ret = net.compute_conv_feature_map_size(patch_size)
torch.set_num_threads(a)
return ret

def determine_resampling(self, *args, **kwargs):
Expand Down
19 changes: 13 additions & 6 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.collate_outputs import collate_outputs
from nnunetv2.utilities.crossval_split import generate_crossval_split
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA, get_allowed_n_proc_DA_val
from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.helpers import empty_cache, dummy_context
Expand Down Expand Up @@ -635,16 +635,23 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)

def init_dataloaders(self, dl_tr, tr_transforms, dl_val, val_transforms):
num_processes_train = get_allowed_n_proc_DA()
if num_processes_train == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
num_processes=allowed_num_processes, num_cached=6, seeds=None,
num_processes=num_processes_train, num_cached=6, seeds=None,
pin_memory=self.device.type == 'cuda', wait_time=0.02)

num_processes_val = get_allowed_n_proc_DA_val()
if num_processes_val == 0:
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
transform=val_transforms, num_processes=num_processes_val,
num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
wait_time=0.02)
return mt_gen_train, mt_gen_val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import torch
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \
GammaTransform
Expand All @@ -21,15 +20,12 @@
ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
DownsampleSegForDSTransform2
from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
LimitedLenWrapper
from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
ConvertSegmentationToRegionsTransform
from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
Convert2DTo3DTransform
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA


class nnUNetTrainerDA5(nnUNetTrainer):
Expand Down Expand Up @@ -338,17 +334,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


def _brightnessadditive_localgamma_transform_scale(x, y):
Expand Down Expand Up @@ -399,17 +385,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter

from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
LimitedLenWrapper
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA


class nnUNetTrainerDAOrd0(nnUNetTrainer):
Expand Down Expand Up @@ -42,17 +37,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


class nnUNetTrainer_DASegOrd0(nnUNetTrainer):
Expand Down Expand Up @@ -91,17 +76,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer):
Expand Down Expand Up @@ -144,14 +119,4 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)
10 changes: 10 additions & 0 deletions nnunetv2/utilities/default_n_proc_DA.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
import os


def get_allowed_n_proc_DA_val():
    """
    Return the number of processes to use for the validation data loader.

    Reads the ``nnUNet_n_proc_DA_val`` environment variable when set; a value
    of 0 means the validation data is loaded sequentially in the main process.
    Otherwise defaults to half the number of training DA processes (floor
    division, so this may also come out as 0).
    """
    # Single lookup via .get() instead of a membership test on
    # os.environ.keys() followed by a second indexed lookup.
    env_value = os.environ.get('nnUNet_n_proc_DA_val')
    if env_value is not None:
        return int(env_value)
    return get_allowed_n_proc_DA() // 2


def get_allowed_n_proc_DA():
"""
This function is used to set the number of processes used on different Systems. It is specific to our cluster
Expand Down

0 comments on commit 8df9840

Please sign in to comment.