inference entry points, train entry point now uses nolightning version
FabianIsensee committed Aug 5, 2022
1 parent 277c0ff commit 9fc338f
Showing 4 changed files with 148 additions and 9 deletions.
142 changes: 139 additions & 3 deletions nnunetv2/inference/predict_from_raw_data.py
@@ -8,14 +8,15 @@
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.utility_transforms import NumpyToTensor
-from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name
from nnunetv2.inference.export_prediction import export_prediction
from nnunetv2.inference.sliding_window_prediction import predict_sliding_window_return_logits, compute_gaussian
from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name
from nnunetv2.preprocessing.utils import get_preprocessor_class_from_plans
+from nnunetv2.utilities.file_path_utilities import get_output_folder
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.helpers import softmax_helper_dim0
from nnunetv2.utilities.label_handling import determine_num_input_channels, LabelManager, convert_labelmap_to_one_hot
@@ -247,6 +248,143 @@ def predict_from_raw_data(list_of_lists_or_source_folder: Union[str, List[List[s
export_pool.join()


def predict_entry_point_modelfolder():
import argparse
parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
'you want to manually specify a folder containing a trained nnU-Net '
'model. This is useful when the nnunet environment variables '
'(nnUNet_results) are not set.')
parser.add_argument('-i', type=str, required=True,
                        help='Input folder. Remember to use the correct suffixes for your files (_0000 etc). '
                             'File endings must be the same as the training dataset!')
parser.add_argument('-o', type=str, required=True,
help='Output folder. If it does not exist it will be created. Predicted segmentations will '
'have the same name as their source images.')
    parser.add_argument('-m', type=str, required=True,
                        help='Folder containing the trained model. Must have subfolders fold_X for the different '
                             'folds you trained')
parser.add_argument('-f', nargs='+', type=int, required=False, default=(0, 1, 2, 3, 4),
help='Specify the folds of the trained model that should be used for prediction. '
'Default: (0, 1, 2, 3, 4)')
    parser.add_argument('-step_size', type=float, required=False, default=0.5,
                        help='Step size for sliding window prediction. Larger values are faster but less accurate. '
                             'Default: 0.5. Cannot be larger than 1. We recommend the default.')
parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
'but less accurate inference. Not recommended.')
# todo all in gpu as default
parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
"to be a good listener/reader.")
parser.add_argument('--save_probabilities', action='store_true',
help='Set this to export predicted class "probabilities". Required if you want to ensemble '
'multiple configurations.')
parser.add_argument('--continue_prediction', '--c', action='store_true',
help='Continue an aborted previous prediction (will not overwrite existing files)')
parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
parser.add_argument('-npp', type=int, required=False, default=3,
help='Number of processes used for preprocessing. More is not always better. Beware of '
'out-of-RAM issues. Default: 3')
parser.add_argument('-nps', type=int, required=False, default=3,
help='Number of processes used for segmentation export. More is not always better. Beware of '
'out-of-RAM issues. Default: 3')
parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
help='Folder containing the predictions of the previous stage. Required for cascaded models.')
args = parser.parse_args()

if not isdir(args.o):
maybe_mkdir_p(args.o)

predict_from_raw_data(args.i,
args.o,
args.m,
args.f,
args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_gpu=False,
verbose=args.verbose,
save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
checkpoint_name=args.chk,
num_processes_preprocessing=args.npp,
num_processes_segmentation_export=args.nps,
folder_with_segs_from_prev_stage=args.prev_stage_predictions)
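# Hypothetical usage sketch, not part of this commit: the function above backs
# the new nnUNetv2_predict_from_modelfolder console script (see setup.py below).
# Assuming a trained model folder with fold_X subdirectories, an equivalent
# programmatic call could look like this (all paths are made up):
#
# import sys
# from nnunetv2.inference.predict_from_raw_data import predict_entry_point_modelfolder
#
# sys.argv = ['nnUNetv2_predict_from_modelfolder',
#             '-i', '/data/Dataset003_Liver/imagesTs',
#             '-o', '/data/Dataset003_Liver/imagesTs_pred',
#             '-m', '/results/Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_fullres',
#             '-f', '0', '1', '2', '3', '4']
# predict_entry_point_modelfolder()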


def predict_entry_point():
import argparse
    parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
                                                 'you want to predict with a trained model from your nnUNet_results '
                                                 'folder by specifying dataset (-d), trainer (-tr), plans (-p) and '
                                                 'configuration (-c).')
parser.add_argument('-i', type=str, required=True,
                        help='Input folder. Remember to use the correct suffixes for your files (_0000 etc). '
                             'File endings must be the same as the training dataset!')
parser.add_argument('-o', type=str, required=True,
help='Output folder. If it does not exist it will be created. Predicted segmentations will '
'have the same name as their source images.')
parser.add_argument('-d', type=str, required=True,
help='Dataset with which you would like to predict. You can specify either dataset name or id')
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='Plans identifier. Specify the plans in which the desired configuration is located. '
'Default: nnUNetPlans')
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')
parser.add_argument('-c', type=str, required=True,
help='nnU-Net configuration that should be used for prediction. Config must be located '
'in the plans specified with -p')
parser.add_argument('-f', nargs='+', type=int, required=False, default=(0, 1, 2, 3, 4),
help='Specify the folds of the trained model that should be used for prediction. '
'Default: (0, 1, 2, 3, 4)')
    parser.add_argument('-step_size', type=float, required=False, default=0.5,
                        help='Step size for sliding window prediction. Larger values are faster but less accurate. '
                             'Default: 0.5. Cannot be larger than 1. We recommend the default.')
parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
'but less accurate inference. Not recommended.')
# todo all in gpu as default
parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
"to be a good listener/reader.")
parser.add_argument('--save_probabilities', action='store_true',
help='Set this to export predicted class "probabilities". Required if you want to ensemble '
'multiple configurations.')
parser.add_argument('--continue_prediction', action='store_true',
help='Continue an aborted previous prediction (will not overwrite existing files)')
parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
parser.add_argument('-npp', type=int, required=False, default=3,
help='Number of processes used for preprocessing. More is not always better. Beware of '
'out-of-RAM issues. Default: 3')
parser.add_argument('-nps', type=int, required=False, default=3,
help='Number of processes used for segmentation export. More is not always better. Beware of '
'out-of-RAM issues. Default: 3')
parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
help='Folder containing the predictions of the previous stage. Required for cascaded models.')
args = parser.parse_args()

model_folder = get_output_folder(args.d, args.tr, args.p, args.c)

if not isdir(args.o):
maybe_mkdir_p(args.o)

predict_from_raw_data(args.i,
args.o,
model_folder,
args.f,
args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_gpu=False,
verbose=args.verbose,
save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
checkpoint_name=args.chk,
num_processes_preprocessing=args.npp,
num_processes_segmentation_export=args.nps,
folder_with_segs_from_prev_stage=args.prev_stage_predictions)
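# Hypothetical usage sketch, not part of this commit: the function above backs
# the new nnUNetv2_predict console script and resolves the model folder from
# nnUNet_results via get_output_folder. An equivalent programmatic call
# (dataset and paths are made up):
#
# import sys
# from nnunetv2.inference.predict_from_raw_data import predict_entry_point
#
# sys.argv = ['nnUNetv2_predict',
#             '-i', '/data/Dataset003_Liver/imagesTs',
#             '-o', '/data/Dataset003_Liver/imagesTs_pred',
#             '-d', 'Dataset003_Liver',
#             '-c', '3d_fullres',
#             '-f', '0']
# predict_entry_point()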


if __name__ == '__main__':
predict_from_raw_data('/media/fabian/data/nnUNet_raw/Dataset003_Liver/imagesTs',
'/media/fabian/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres',
@@ -278,5 +416,3 @@ def predict_from_raw_data(list_of_lists_or_source_folder: Union[str, List[List[s
num_processes_preprocessing=2,
num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage='/media/fabian/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres')


5 changes: 3 additions & 2 deletions nnunetv2/preprocessing/preprocessors/default_preprocessor.py
@@ -117,7 +117,8 @@ def run_case(self, image_files: List[str], seg_file: Union[str, None], plans: Un

# when using the ignore label we want to sample only from annotated regions. Therefore we also need to
# collect samples uniformly from all classes (incl background)
-collect_for_this.append(label_manager.all_labels)
+if label_manager.has_ignore_label:
+    collect_for_this.append(label_manager.all_labels)

# no need to filter background in regions because it is already filtered in handle_labels
# print(all_labels, regions)
@@ -137,7 +138,7 @@ def run_case_save(self, output_filename_truncated: str, image_files: List[str],
@staticmethod
def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[tuple[int, ...]]],
seed: int = 1234):
-num_samples = 25000
+num_samples = 10000
min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too
# sparse
rndst = np.random.RandomState(seed)
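# Hypothetical sketch, not part of this diff, of what the seeded sampling shown
# above might amount to for a single class: draw at most num_samples voxel
# locations (10000 after this commit), but never less than min_percent_coverage
# of the class voxels so sparse classes stay represented.
#
# import numpy as np
#
# def sample_locations_for_class(seg: np.ndarray, class_id: int,
#                                num_samples: int = 10000,
#                                min_percent_coverage: float = 0.01,
#                                seed: int = 1234) -> np.ndarray:
#     rndst = np.random.RandomState(seed)
#     all_locs = np.argwhere(seg == class_id)  # (n, ndim) voxel coordinates
#     if len(all_locs) == 0:
#         return all_locs
#     target = min(num_samples, len(all_locs))
#     target = max(target, int(np.ceil(len(all_locs) * min_percent_coverage)))
#     return all_locs[rndst.choice(len(all_locs), target, replace=False)]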
4 changes: 2 additions & 2 deletions nnunetv2/utilities/file_path_utilities.py
@@ -5,11 +5,11 @@
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name


-def get_output_folder(dataset_name_or_id: Union[str, int], trainer_module: str = 'nnUNetModule',
+def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres',
fold: Union[str, int] = None) -> str:
tmp = join(nnUNet_results, maybe_convert_to_dataset_name(dataset_name_or_id),
-               f'{trainer_module}__{plans_identifier}__{configuration}')
+               f'{trainer_name}__{plans_identifier}__{configuration}')
if fold is not None:
tmp = join(tmp, f'fold_{fold}')
return tmp
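# Example, not part of this diff: with the renamed trainer_name parameter and
# assuming nnUNet_results='/results', the folder built above would be e.g.
#
# get_output_folder(3, 'nnUNetTrainer', 'nnUNetPlans', '3d_fullres', fold=0)
# -> '/results/Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_fullres/fold_0'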
6 changes: 4 additions & 2 deletions setup.py
@@ -5,7 +5,7 @@
version='2',
description='nnU-Net. Framework for out-of-the box biomedical image segmentation.',
url='https://github.com/MIC-DKFZ/nnUNet',
-      author='HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center',
+      author='Helmholtz Imaging Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center',
author_email='f.isensee@dkfz-heidelberg.de',
license='Apache License Version 2.0, January 2004',
install_requires=[
@@ -32,7 +32,9 @@
'nnUNetv2_extract_fingerprint = nnunetv2.experiment_planning.plan_and_preprocess:extract_fingerprint',
'nnUNetv2_plan_experiment = nnunetv2.experiment_planning.plan_and_preprocess:plan_experiment',
'nnUNetv2_preprocess = nnunetv2.experiment_planning.plan_and_preprocess:preprocess',
-            'nnUNetv2_train = nnunetv2.run.train:nnUNet_train_from_args'
+            'nnUNetv2_train = nnunetv2.run.train_nolightning:nnUNet_train_from_args',
+            'nnUNetv2_predict_from_modelfolder = nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder',
+            'nnUNetv2_predict = nnunetv2.inference.predict_from_raw_data:predict_entry_point',
],
},
keywords=['deep learning', 'image segmentation', 'medical image analysis',
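# Hypothetical check, not part of this commit: after `pip install -e .`, the
# console scripts registered above can be inspected with importlib.metadata
# (the group= keyword assumes Python >= 3.10):
#
# from importlib.metadata import entry_points
#
# for ep in entry_points(group='console_scripts'):
#     if ep.name.startswith('nnUNetv2'):
#         print(ep.name, '->', ep.value)
# # e.g. nnUNetv2_predict -> nnunetv2.inference.predict_from_raw_data:predict_entry_point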
