add predict_from_files_sequential to nnUNetPredictor
FabianIsensee committed Dec 3, 2024
1 parent ac79a61 commit 43349fa
Showing 1 changed file with 106 additions and 34 deletions.
140 changes: 106 additions & 34 deletions nnunetv2/inference/predict_from_raw_data.py
@@ -658,6 +658,92 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
             predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
         return predicted_logits
 
+    def predict_from_files_sequential(self,
+                                      list_of_lists_or_source_folder: Union[str, List[List[str]]],
+                                      output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
+                                      save_probabilities: bool = False,
+                                      overwrite: bool = True,
+                                      folder_with_segs_from_prev_stage: str = None):
+        """
+        Just like predict_from_files but doesn't use any multiprocessing. Slow, but sometimes necessary
+        """
+        if isinstance(output_folder_or_list_of_truncated_output_files, str):
+            output_folder = output_folder_or_list_of_truncated_output_files
+        elif isinstance(output_folder_or_list_of_truncated_output_files, list):
+            output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
+        else:
+            output_folder = None
+
+        ########################
+        # let's store the input arguments so that it's clear what was used to generate the prediction
+        if output_folder is not None:
+            my_init_kwargs = {}
+            for k in inspect.signature(self.predict_from_files_sequential).parameters.keys():
+                my_init_kwargs[k] = locals()[k]
+            # deepcopy so that we don't unintentionally change anything in-place; take this as a safety measure
+            my_init_kwargs = deepcopy(my_init_kwargs)
+            recursive_fix_for_json_export(my_init_kwargs)
+            maybe_mkdir_p(output_folder)
+            save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))
+
+            # we need these two if we want to do things with the predictions, such as applying postprocessing
+            save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
+            save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
+        #######################
+
+        # check if we need a prediction from the previous stage
+        if self.configuration_manager.previous_stage_name is not None:
+            assert folder_with_segs_from_prev_stage is not None, \
+                f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
+                f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
+                f' they are located via folder_with_segs_from_prev_stage'
+
+        # sort out input and output filenames
+        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
+            self._manage_input_and_output_lists(list_of_lists_or_source_folder,
+                                                output_folder_or_list_of_truncated_output_files,
+                                                folder_with_segs_from_prev_stage, overwrite, 0, 1,
+                                                save_probabilities)
+        if len(list_of_lists_or_source_folder) == 0:
+            return
+
+        label_manager = self.plans_manager.get_label_manager(self.dataset_json)
+        preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)
+
+        if output_filename_truncated is None:
+            output_filename_truncated = [None] * len(list_of_lists_or_source_folder)
+        if seg_from_prev_stage_files is None:
+            seg_from_prev_stage_files = [None] * len(list_of_lists_or_source_folder)
+
+        ret = []
+        for li, of, sps in zip(list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files):
+            data, seg, data_properties = preprocessor.run_case(
+                li,
+                sps,
+                self.plans_manager,
+                self.configuration_manager,
+                self.dataset_json
+            )
+
+            print(f'perform_everything_on_device: {self.perform_everything_on_device}')
+
+            prediction = self.predict_logits_from_preprocessed_data(torch.from_numpy(data)).cpu()
+
+            if of is not None:
+                export_prediction_from_logits(prediction, data_properties, self.configuration_manager,
+                                              self.plans_manager, self.dataset_json, of, save_probabilities)
+            else:
+                ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
+                    prediction, self.plans_manager, self.configuration_manager, label_manager,
+                    data_properties, save_probabilities))
+
+        # clear lru cache
+        compute_gaussian.cache_clear()
+        # clear device cache
+        empty_cache(self.device)
+        return ret
+

 def predict_entry_point_modelfolder():
     import argparse
@@ -891,7 +977,7 @@ def predict_entry_point():
 
 
 if __name__ == '__main__':
-    # predict a bunch of files
+    ########################## predict a bunch of files
     from nnunetv2.paths import nnUNet_results, nnUNet_raw
 
     predictor = nnUNetPredictor(
@@ -905,42 +991,28 @@ def predict_entry_point():
         allow_tqdm=True
     )
     predictor.initialize_from_trained_model_folder(
-        join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
+        join(nnUNet_results, 'Dataset004_Hippocampus/nnUNetTrainer_5epochs__nnUNetPlans__3d_fullres'),
         use_folds=(0,),
         checkpoint_name='checkpoint_final.pth',
     )
-    predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
-                                 join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
-                                 save_probabilities=False, overwrite=False,
-                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
-                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
-
-    # predict a numpy array
-    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
-
-    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
-    ret = predictor.predict_single_npy_array(img, props, None, None, False)
-
-    iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1)
-    ret = predictor.predict_from_data_iterator(iterator, False, 1)
-
-    # predictor = nnUNetPredictor(
-    #     tile_step_size=0.5,
-    #     use_gaussian=True,
-    #     use_mirroring=True,
-    #     perform_everything_on_device=True,
-    #     device=torch.device('cuda', 0),
-    #     verbose=False,
-    #     allow_tqdm=True
-    # )
-    # predictor.initialize_from_trained_model_folder(
-    #     join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_cascade_fullres'),
-    #     use_folds=(0,),
-    #     checkpoint_name='checkpoint_final.pth',
-    # )
     # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
-    #                              join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predCascade'),
+    #                              join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
     #                              save_probabilities=False, overwrite=False,
     #                              num_processes_preprocessing=2, num_processes_segmentation_export=2,
-    #                              folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres',
-    #                              num_parts=1, part_id=0)
+    #                              folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+    #
+    # # predict a numpy array
+    # from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+    #
+    # img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
+    # ret = predictor.predict_single_npy_array(img, props, None, None, False)
+    #
+    # iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1)
+    # ret = predictor.predict_from_data_iterator(iterator, False, 1)
+
+    ret = predictor.predict_from_files_sequential(
+        [['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_002_0000.nii.gz'], ['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_005_0000.nii.gz']],
+        '/home/isensee/temp/tmp', False, True, None
+    )


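For reference, here is a minimal usage sketch of the new method. It is not part of the commit: the model folder and file paths are placeholders, and it assumes a trained nnU-Net model plus the usual nnUNet_results environment setup. Passing an output folder writes one segmentation per case to disk; passing None as the output target returns the segmentations in memory, via the else branch shown in the diff.

import torch
from os.path import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.paths import nnUNet_results

# same predictor settings as the __main__ example in the diff
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=True,
    device=torch.device('cuda', 0),
    verbose=False,
    allow_tqdm=True
)
predictor.initialize_from_trained_model_folder(
    join(nnUNet_results, 'Dataset004_Hippocampus/nnUNetTrainer_5epochs__nnUNetPlans__3d_fullres'),  # placeholder
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)

# write results to disk, one case at a time, entirely in the calling process
predictor.predict_from_files_sequential(
    [['/path/to/imagesTs/case_0000.nii.gz']],   # placeholder input (one inner list per case)
    '/path/to/output_folder',                   # placeholder output folder
    save_probabilities=False, overwrite=True, folder_with_segs_from_prev_stage=None
)

# or pass None as the output target to get the segmentations back in memory
segs = predictor.predict_from_files_sequential(
    [['/path/to/imagesTs/case_0000.nii.gz']],
    None, save_probabilities=False, overwrite=True, folder_with_segs_from_prev_stage=None
)

Because no worker pools are spawned, this path trades the speed of predict_from_files for robustness in environments where multiprocessing is problematic, for example when debugging or on restricted compute nodes.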