Skip to content

Commit 23afd25

Browse files
committed
merge
2 parents 24b5e48 + 8b6adc2 commit 23afd25

File tree

8 files changed

+206
-70
lines changed

8 files changed

+206
-70
lines changed

nnunetv2/evaluation/find_best_configuration.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from copy import deepcopy
44
from typing import Union, List, Tuple
55

6-
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json
7-
6+
from batchgenerators.utilities.file_and_folder_operations import (
7+
load_json, join, isdir, listdir, save_json
8+
)
89
from nnunetv2.configuration import default_num_processes
910
from nnunetv2.ensembling.ensemble import ensemble_crossvalidations
1011
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
@@ -320,6 +321,11 @@ def accumulate_crossval_results_entry_point():
320321
merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')
321322
else:
322323
merged_output_folder = args.o
324+
if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0:
325+
raise FileExistsError(
326+
f"Output folder {merged_output_folder} exists and is not empty. "
327+
f"To avoid data loss, nnUNet requires an empty output folder."
328+
)
323329

324330
accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)
325331

nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py

Whitespace-only changes.
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from typing import Union, List, Tuple
2+
3+
from nnunetv2.configuration import ANISO_THRESHOLD
4+
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
5+
from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
6+
nnUNetPlannerResEncL
7+
from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet
8+
9+
10+
class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
11+
def __init__(self, dataset_name_or_id: Union[str, int],
12+
gpu_memory_target_in_gb: float = 24,
13+
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
14+
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
15+
suppress_transpose: bool = False):
16+
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
17+
overwrite_target_spacing, suppress_transpose)
18+
19+
def generate_data_identifier(self, configuration_name: str) -> str:
20+
"""
21+
configurations are unique within each plans file but different plans file can have configurations with the
22+
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
23+
config but also the plans it originates from
24+
"""
25+
return self.plans_identifier + '_' + configuration_name
26+
27+
def determine_resampling(self, *args, **kwargs):
28+
"""
29+
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
30+
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
31+
32+
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
33+
configuration
34+
"""
35+
resampling_data = resample_torch_fornnunet
36+
resampling_data_kwargs = {
37+
"is_seg": False,
38+
'force_separate_z': False,
39+
'memefficient_seg_resampling': False
40+
}
41+
resampling_seg = resample_torch_fornnunet
42+
resampling_seg_kwargs = {
43+
"is_seg": True,
44+
'force_separate_z': False,
45+
'memefficient_seg_resampling': False
46+
}
47+
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
48+
49+
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
50+
"""
51+
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
52+
used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
53+
54+
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
55+
functions for each configuration
56+
57+
"""
58+
resampling_fn = resample_torch_fornnunet
59+
resampling_fn_kwargs = {
60+
"is_seg": False,
61+
'force_separate_z': False,
62+
'memefficient_seg_resampling': False
63+
}
64+
return resampling_fn, resampling_fn_kwargs
65+
66+
67+
class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):
68+
def __init__(self, dataset_name_or_id: Union[str, int],
69+
gpu_memory_target_in_gb: float = 24,
70+
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',
71+
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
72+
suppress_transpose: bool = False):
73+
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
74+
overwrite_target_spacing, suppress_transpose)
75+
76+
def generate_data_identifier(self, configuration_name: str) -> str:
77+
"""
78+
configurations are unique within each plans file but different plans file can have configurations with the
79+
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
80+
config but also the plans it originates from
81+
"""
82+
return self.plans_identifier + '_' + configuration_name
83+
84+
def determine_resampling(self, *args, **kwargs):
85+
"""
86+
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
87+
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
88+
89+
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
90+
configuration
91+
"""
92+
resampling_data = resample_torch_fornnunet
93+
resampling_data_kwargs = {
94+
"is_seg": False,
95+
'force_separate_z': None,
96+
'memefficient_seg_resampling': False,
97+
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
98+
}
99+
resampling_seg = resample_torch_fornnunet
100+
resampling_seg_kwargs = {
101+
"is_seg": True,
102+
'force_separate_z': None,
103+
'memefficient_seg_resampling': False,
104+
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
105+
}
106+
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
107+
108+
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
109+
"""
110+
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
111+
used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
112+
113+
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
114+
functions for each configuration
115+
116+
"""
117+
resampling_fn = resample_torch_fornnunet
118+
resampling_fn_kwargs = {
119+
"is_seg": False,
120+
'force_separate_z': None,
121+
'memefficient_seg_resampling': False,
122+
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
123+
}
124+
return resampling_fn, resampling_fn_kwargs
125+
126+
127+
class nnUNetPlanner_torchres(ExperimentPlanner):
128+
def __init__(self, dataset_name_or_id: Union[str, int],
129+
gpu_memory_target_in_gb: float = 8,
130+
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',
131+
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
132+
suppress_transpose: bool = False):
133+
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
134+
overwrite_target_spacing, suppress_transpose)
135+
136+
def generate_data_identifier(self, configuration_name: str) -> str:
137+
"""
138+
configurations are unique within each plans file but different plans file can have configurations with the
139+
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
140+
config but also the plans it originates from
141+
"""
142+
return self.plans_identifier + '_' + configuration_name
143+
144+
def determine_resampling(self, *args, **kwargs):
145+
"""
146+
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
147+
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
148+
149+
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
150+
configuration
151+
"""
152+
resampling_data = resample_torch_fornnunet
153+
resampling_data_kwargs = {
154+
"is_seg": False,
155+
'force_separate_z': False,
156+
'memefficient_seg_resampling': False
157+
}
158+
resampling_seg = resample_torch_fornnunet
159+
resampling_seg_kwargs = {
160+
"is_seg": True,
161+
'force_separate_z': False,
162+
'memefficient_seg_resampling': False
163+
}
164+
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
165+
166+
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
167+
"""
168+
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
169+
used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
170+
171+
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
172+
functions for each configuration
173+
174+
"""
175+
resampling_fn = resample_torch_fornnunet
176+
resampling_fn_kwargs = {
177+
"is_seg": False,
178+
'force_separate_z': False,
179+
'memefficient_seg_resampling': False
180+
}
181+
return resampling_fn, resampling_fn_kwargs

nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -294,63 +294,6 @@ def __init__(self, dataset_name_or_id: Union[str, int],
294294
self.max_dataset_covered = 1
295295

296296

297-
class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
298-
def __init__(self, dataset_name_or_id: Union[str, int],
299-
gpu_memory_target_in_gb: float = 24,
300-
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
301-
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
302-
suppress_transpose: bool = False):
303-
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
304-
overwrite_target_spacing, suppress_transpose)
305-
306-
def generate_data_identifier(self, configuration_name: str) -> str:
307-
"""
308-
configurations are unique within each plans file but different plans file can have configurations with the
309-
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
310-
config but also the plans it originates from
311-
"""
312-
return self.plans_identifier + '_' + configuration_name
313-
314-
def determine_resampling(self, *args, **kwargs):
315-
"""
316-
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
317-
resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
318-
319-
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
320-
configuration
321-
"""
322-
resampling_data = resample_torch_fornnunet
323-
resampling_data_kwargs = {
324-
"is_seg": False,
325-
'force_separate_z': False,
326-
'memefficient_seg_resampling': False
327-
}
328-
resampling_seg = resample_torch_fornnunet
329-
resampling_seg_kwargs = {
330-
"is_seg": True,
331-
'force_separate_z': False,
332-
'memefficient_seg_resampling': False
333-
}
334-
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
335-
336-
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
337-
"""
338-
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
339-
used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
340-
341-
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
342-
functions for each configuration
343-
344-
"""
345-
resampling_fn = resample_torch_fornnunet
346-
resampling_fn_kwargs = {
347-
"is_seg": False,
348-
'force_separate_z': False,
349-
'memefficient_seg_resampling': False
350-
}
351-
return resampling_fn, resampling_fn_kwargs
352-
353-
354297
if __name__ == '__main__':
355298
# we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
356299
net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),

nnunetv2/imageio/nibabel_reader_writer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ class NibabelIO(BaseReaderWriter):
3131
supported_file_endings = [
3232
'.nii',
3333
'.nii.gz',
34-
'.nrrd',
35-
'.mha'
3634
]
3735

3836
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
@@ -110,8 +108,6 @@ class NibabelIOWithReorient(BaseReaderWriter):
110108
supported_file_endings = [
111109
'.nii',
112110
'.nii.gz',
113-
'.nrrd',
114-
'.mha'
115111
]
116112

117113
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:

nnunetv2/preprocessing/preprocessors/default_preprocessor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,17 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
230230
# multiprocessing magic.
231231
r = []
232232
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
233+
remaining = list(range(len(dataset)))
234+
# p is pretty nifti. If we kill workers they just respawn but don't do any work.
235+
# So we need to store the original pool of workers.
236+
workers = [j for j in p._pool]
237+
233238
for k in dataset.keys():
234239
r.append(p.starmap_async(self.run_case_save,
235240
((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],
236241
plans_manager, configuration_manager,
237242
dataset_json),)))
238-
remaining = list(range(len(dataset)))
239-
# p is pretty nifti. If we kill workers they just respawn but don't do any work.
240-
# So we need to store the original pool of workers.
241-
workers = [j for j in p._pool]
243+
242244
with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar:
243245
while len(remaining) > 0:
244246
all_alive = all([j.is_alive() for j in workers])
@@ -251,6 +253,8 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
251253
'an error message, out of RAM is likely the problem. In that case '
252254
'reducing the number of workers might help')
253255
done = [i for i in remaining if r[i].ready()]
256+
# get done so that errors can be raised
257+
_ = [r[i].get() for i in done]
254258
for _ in done:
255259
r[_].get() # allows triggering errors
256260
pbar.update()

nnunetv2/preprocessing/resampling/default_resampling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
3131
return new_shape
3232

3333

34-
def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
35-
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, int]:
34+
35+
def determine_do_sep_z_and_axis(
36+
force_separate_z: bool,
37+
current_spacing,
38+
new_spacing,
39+
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
3640
if force_separate_z is not None:
3741
do_separate_z = force_separate_z
3842
if force_separate_z:

nnunetv2/utilities/plans_handling/plans_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __init__(self, configuration_dict: dict):
5454
conv_op = convert_dim_to_conv_op(dim)
5555
instnorm = get_matching_instancenorm(dimension=dim)
5656

57+
convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"
58+
5759
arch_dict = {
5860
'network_class_name': network_class_name,
5961
'arch_kwargs': {
@@ -64,7 +66,7 @@ def __init__(self, configuration_dict: dict):
6466
"conv_op": conv_op.__module__ + '.' + conv_op.__name__,
6567
"kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
6668
"strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
67-
"n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]),
69+
convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
6870
"n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
6971
"conv_bias": True,
7072
"norm_op": instnorm.__module__ + '.' + instnorm.__name__,

0 commit comments

Comments
 (0)