Skip to content

Commit 40bbe29

Browse files
author
Lukas Arts
committed
Fixed validation for 1d signals and added customizable dice calculation
1 parent c4ce1eb commit 40bbe29

File tree

13 files changed

+175
-32
lines changed

13 files changed

+175
-32
lines changed

nnUNet_results

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/media/lukas/f476eab7-6c09-4db6-bfcd-2922b3c3502b/UU/ASRA/Segmentation/nnUNet_results

nnunetv2/evaluation/evaluate_predictions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def compute_metrics(reference_file: str, prediction_file: str, image_reader_writ
9393
seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file)
9494
seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file)
9595

96+
print(reference_file, prediction_file)
97+
9698
ignore_mask = seg_ref == ignore_label if ignore_label is not None else None
9799

98100
results = {}

nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(self, dataset_name_or_id: Union[str, int],
2727
# much as possible
2828
self.UNet_reference_val_3d = 680000000
2929
self.UNet_reference_val_2d = 135000000
30+
self.UNet_reference_val_1d = 135000000
3031
self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
3132
self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
3233

@@ -36,7 +37,7 @@ def generate_data_identifier(self, configuration_name: str) -> str:
3637
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
3738
config but also the plans it originates from
3839
"""
39-
if configuration_name == '2d' or configuration_name == '3d_fullres':
40+
if configuration_name == '1d' or configuration_name == '2d' or configuration_name == '3d_fullres':
4041
# we do not deviate from ExperimentPlanner so we can reuse its data
4142
return 'nnUNetPlans' + '_' + configuration_name
4243
else:
@@ -76,6 +77,9 @@ def _keygen(patch_size, strides):
7677
initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
7778
elif len(spacing) == 2:
7879
initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
80+
elif len(spacing) == 1:
81+
#initial patch size for 1d signals is the entire signal
82+
initial_patch_size = [round(median_shape[0])]
7983
else:
8084
raise RuntimeError()
8185

@@ -129,7 +133,7 @@ def _keygen(patch_size, strides):
129133

130134
# how large is the reference for us here (batch size etc)?
131135
# adapt for our vram target
132-
reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
136+
reference = (self.UNet_reference_val_1d if len(spacing) == 1 else (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d)) * \
133137
(self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)
134138

135139
while estimate > reference:
@@ -183,7 +187,7 @@ def _keygen(patch_size, strides):
183187

184188
# alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was
185189
# executed. If not, additional vram headroom is used to increase batch size
186-
ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
190+
ref_bs = self.UNet_reference_val_corresp_bs_1d if len(spacing) == 1 else (self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d)
187191
batch_size = round((reference / estimate) * ref_bs)
188192

189193
# we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot
@@ -241,6 +245,7 @@ def __init__(self, dataset_name_or_id: Union[str, int],
241245
# this is supposed to give the same GPU memory requirement as the default nnU-Net
242246
self.UNet_reference_val_3d = 680000000
243247
self.UNet_reference_val_2d = 135000000
248+
self.UNet_reference_val_1d = 135000000
244249
self.max_dataset_covered = 1
245250

246251

nnunetv2/imageio/numpy_reader_writer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@ class NumpyIO(BaseReaderWriter):
2727
'.npy'
2828
]
2929

30-
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
30+
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]], annotations=False) -> Tuple[np.ndarray, dict]:
3131
images = []
3232
for f in image_fnames:
3333
npy_img = np.load(f)
3434
assert npy_img.ndim == 1 or npy_img.ndim == 2, "Only 1D timeseries with one or more channels supported"
3535
if npy_img.ndim == 2:
3636
# channel to front, add additional dim so that we have shape (c, 1, 1, X)
37-
images.append(npy_img.transpose((1, 0))[:, None, None])
37+
if annotations:
38+
images.append(npy_img[None, None, :])
39+
else:
40+
images.append(npy_img.transpose((1, 0))[:, None, None])
3841
elif npy_img.ndim == 1:
3942
# grayscale image
4043
images.append(npy_img[None, None, None])
@@ -49,7 +52,7 @@ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[
4952
return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 999, 1)}
5053

5154
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
52-
return self.read_images((seg_fname, ))
55+
return self.read_images((seg_fname, ), annotations=True)
5356

5457
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
5558
np.save(output_fname, seg[0].astype(np.uint8, copy=False))

nnunetv2/inference/predict_from_raw_data.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,23 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
502502

503503
def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
504504
slicers = []
505-
if len(self.configuration_manager.patch_size) < len(image_size):
505+
dim = len(self.configuration_manager.patch_size)
506+
507+
if dim == 1:
508+
steps = compute_steps_for_sliding_window(image_size[2:], self.configuration_manager.patch_size,
509+
self.tile_step_size)
510+
511+
if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is'
512+
f' {image_size}, tile_size {self.configuration_manager.patch_size}, '
513+
f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
514+
515+
for d in range(image_size[0]):
516+
for sx in steps[0]:
517+
slicers.append(
518+
tuple([slice(None), d, 0, slice(sx, sx + self.configuration_manager.patch_size[0])]))
519+
520+
elif dim == 2:
521+
#if len(self.configuration_manager.patch_size) < len(image_size):
506522
assert len(self.configuration_manager.patch_size) == len(
507523
image_size) - 1, 'if tile_size has less entries than image_size, ' \
508524
'len(tile_size) ' \
@@ -520,7 +536,7 @@ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
520536
slicers.append(
521537
tuple([slice(None), d, *[slice(si, si + ti) for si, ti in
522538
zip((sx, sy), self.configuration_manager.patch_size)]]))
523-
else:
539+
elif dim == 3:
524540
steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size,
525541
self.tile_step_size)
526542
if self.verbose: print(
@@ -532,6 +548,10 @@ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
532548
slicers.append(
533549
tuple([slice(None), *[slice(si, si + ti) for si, ti in
534550
zip((sx, sy, sz), self.configuration_manager.patch_size)]]))
551+
552+
else:
553+
raise NotImplementedError('This function only supports 1D, 2D and 3D images')
554+
535555
return slicers
536556

537557
def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:

nnunetv2/paths.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP
1919
"""
2020

21-
nnUNet_raw = "/Users/lukasarts/Dropbox/UU/ASRA/nnUNet/nnUNet_raw"
22-
nnUNet_preprocessed = "/Users/lukasarts/Dropbox/UU/ASRA/nnUNet/nnUNet_preprocessed"
23-
nnUNet_results = "/Users/lukasarts/Dropbox/UU/ASRA/nnUNet/nnUNet_results"
21+
nnUNet_raw = "/home/lukas/UU/ASRA/Analysis/nnUNet/nnUNet_raw"
22+
nnUNet_preprocessed = "/home/lukas/UU/ASRA/Analysis/nnUNet/nnUNet_preprocessed"
23+
nnUNet_results = "/home/lukas/UU/ASRA/Analysis/nnUNet/nnUNet_results"
2424

2525
if nnUNet_raw is None:
2626
print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files "

nnunetv2/run/run_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,4 +284,4 @@ def run_training_entry():
284284
# multiprocessing.set_start_method("spawn")
285285
#run_training_entry()
286286

287-
run_training('11', '1d', 0, 'nnUNetTrainer', 'nnUNetPlans', None, 1, False, False, False, False, False, False, device=torch.device('cpu'))
287+
run_training('12', '1d', 1, 'nnUNetTrainer', 'nnUNetPlans', None, 1, False, False, False, False, False, False, device=torch.device('cuda'))

nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
147147
### Some hyperparameters for you to fiddle with
148148
self.initial_lr = 1e-2
149149
self.weight_decay = 3e-5
150-
self.oversample_foreground_percent = 0.33
151-
self.num_iterations_per_epoch = 250
152-
self.num_val_iterations_per_epoch = 50
153-
self.num_epochs = 1000
150+
self.oversample_foreground_percent = 0
151+
self.num_iterations_per_epoch = 100
152+
self.num_val_iterations_per_epoch = 20
153+
self.num_epochs = 64
154154
self.current_epoch = 0
155155
self.enable_deep_supervision = True
156156

@@ -203,6 +203,9 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
203203
"#######################################################################\n",
204204
also_print_to_console=True, add_timestamp=False)
205205

206+
def count_trainable_params(self, model) -> int:
207+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
208+
206209
def initialize(self):
207210
if not self.was_initialized:
208211
self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
@@ -216,10 +219,14 @@ def initialize(self):
216219
self.label_manager.num_segmentation_heads,
217220
self.enable_deep_supervision
218221
).to(self.device)
222+
223+
print("NUM PARAMS:", self.count_trainable_params(self.network))
224+
219225
# compile network for free speedup
220-
if self._do_i_compile():
221-
self.print_to_log_file('Using torch.compile...')
222-
self.network = torch.compile(self.network)
226+
# gives errors when compiling network
227+
# if self._do_i_compile():
228+
# self.print_to_log_file('Using torch.compile...')
229+
# self.network = torch.compile(self.network)
223230

224231
self.optimizer, self.lr_scheduler = self.configure_optimizers()
225232
# if ddp, wrap in DDP wrapper
@@ -959,8 +966,8 @@ def on_train_start(self):
959966

960967
self._save_debug_information()
961968

962-
# print(f"batch size: {self.batch_size}")
963-
# print(f"oversample: {self.oversample_foreground_percent}")
969+
print(f"batch size: {self.batch_size}")
970+
print(f"oversample: {self.oversample_foreground_percent}")
964971

965972
def on_train_end(self):
966973
# dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards.
@@ -1028,6 +1035,7 @@ def train_step(self, batch: dict) -> dict:
10281035
l.backward()
10291036
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
10301037
self.optimizer.step()
1038+
10311039
return {'loss': l.detach().cpu().numpy()}
10321040

10331041
def on_train_epoch_end(self, train_outputs: List[dict]):
@@ -1059,6 +1067,7 @@ def validation_step(self, batch: dict) -> dict:
10591067
# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
10601068
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
10611069
# So autocast will only be active if we have a cuda device.
1070+
10621071
with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
10631072
output = self.network(data)
10641073
del data
@@ -1101,14 +1110,17 @@ def validation_step(self, batch: dict) -> dict:
11011110
tp_hard = tp.detach().cpu().numpy()
11021111
fp_hard = fp.detach().cpu().numpy()
11031112
fn_hard = fn.detach().cpu().numpy()
1104-
if not self.label_manager.has_regions:
1105-
# if we train with regions all segmentation heads predict some kind of foreground. In conventional
1106-
# (softmax training) there needs tobe one output for the background. We are not interested in the
1107-
# background Dice
1108-
# [1:] in order to remove background
1109-
tp_hard = tp_hard[1:]
1110-
fp_hard = fp_hard[1:]
1111-
fn_hard = fn_hard[1:]
1113+
1114+
# we now handle the removal of the background dice in the labelmanager
1115+
1116+
# if not self.label_manager.has_regions:
1117+
# # if we train with regions all segmentation heads predict some kind of foreground. In conventional
1118+
# # (softmax training) there needs tobe one output for the background. We are not interested in the
1119+
# # background Dice
1120+
# # [1:] in order to remove background
1121+
# tp_hard = tp_hard[1:]
1122+
# fp_hard = fp_hard[1:]
1123+
# fn_hard = fn_hard[1:]
11121124

11131125
return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
11141126

@@ -1118,6 +1130,10 @@ def on_validation_epoch_end(self, val_outputs: List[dict]):
11181130
fp = np.sum(outputs_collated['fp_hard'], 0)
11191131
fn = np.sum(outputs_collated['fn_hard'], 0)
11201132

1133+
tp = tp[self.label_manager._get_indices_to_calc_dice()]
1134+
fp = fp[self.label_manager._get_indices_to_calc_dice()]
1135+
fn = fn[self.label_manager._get_indices_to_calc_dice()]
1136+
11211137
if self.is_ddp:
11221138
world_size = dist.get_world_size()
11231139

nnunetv2/utilities/generate_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def simulate_ecg(length, fs):
3333

3434
def generate_dataset(n_files=10, length=2048, fs=200, folder=""):
3535
dataset = {}
36-
basedir = '/Users/lukasarts/Dropbox/UU/ASRA/nnUNet/nnUNet_raw/'
36+
basedir = 'nnUNet_raw/'
3737
if folder != "" and not os.path.exists(os.path.join(basedir, folder)):
3838
os.makedirs(os.path.join(basedir, folder))
3939

@@ -52,7 +52,7 @@ def generate_dataset(n_files=10, length=2048, fs=200, folder=""):
5252
n_files = 25
5353
length = 10
5454
fs = 200
55-
folder = 'Dataset0011_test'
55+
folder = 'Dataset011_test'
5656
generate_dataset(n_files, length, fs, folder=folder)
5757

5858

nnunetv2/utilities/label_handling/label_handling.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919

2020

2121
class LabelManager(object):
22-
def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False,
22+
def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], use_for_validation: Union[dict, None], force_use_labels: bool = False,
2323
inference_nonlin=None):
2424
self._sanity_check(label_dict)
2525
self.label_dict = label_dict
26+
self.use_for_validation = use_for_validation
2627
self.regions_class_order = regions_class_order
2728
self._force_use_labels = force_use_labels
2829

@@ -74,6 +75,13 @@ def _get_all_labels(self) -> List[int]:
7475
all_labels.sort()
7576
return all_labels
7677

78+
def _get_indices_to_calc_dice(self) -> List[int]:
79+
indices = []
80+
for l, b in self.use_for_validation.items():
81+
if b:
82+
indices.append(self.label_dict[l])
83+
return indices
84+
7785
def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:
7886
if not self._has_regions or self._force_use_labels:
7987
return None

0 commit comments

Comments
 (0)