Skip to content

Commit

Permalink
allow more than 127 classes
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Mar 16, 2023
1 parent 0f70690 commit 619d0ab
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
16 changes: 5 additions & 11 deletions nnunetv2/preprocessing/preprocessors/default_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def __init__(self, verbose: bool = True):
self.verbose = verbose
"""
Everything we need is in the plans. Those are given when run() is called
CAREFUL! WE USE INT8 FOR SAVING SEGMENTATIONS (NOT UINT8) SO 127 IS THE MAXIMUM LABEL!
"""

def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager,
Expand Down Expand Up @@ -119,8 +117,11 @@ def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_man
data_properites['class_locations'] = self._sample_foreground_locations(seg, collect_for_this,
verbose=self.verbose)
seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager)

return data, seg.astype(np.int8), data_properites
if np.max(seg) > 127:
seg = seg.astype(np.int16)
else:
seg = seg.astype(np.int8)
return data, seg, data_properites

def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str,
plans_manager: PlansManager, configuration_manager: ConfigurationManager,
Expand Down Expand Up @@ -197,13 +198,6 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json')
dataset_json = load_json(dataset_json_file)

label_manager = plans_manager.get_label_manager(dataset_json)
classes = label_manager.all_labels

if max(classes) > 127:
raise RuntimeError('WE USE INT8 FOR SAVING SEGMENTATIONS (NOT UINT8) SO 127 IS THE MAXIMUM LABEL! '
'Your labels go larger than that')

identifiers = get_identifiers_from_splitted_dataset_folder(join(nnUNet_raw, dataset_name, 'imagesTr'),
dataset_json['file_ending'])
output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier)
Expand Down
3 changes: 1 addition & 2 deletions nnunetv2/training/dataloading/data_loader_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ def generate_train_batch(self):
selected_keys = self.get_indices()
# preallocate memory for data and seg
data_all = np.zeros(self.data_shape, dtype=np.float32)
# again, no support for more than 127 labels! We gotta save memory, yo
seg_all = np.zeros(self.seg_shape, dtype=np.int8)
seg_all = np.zeros(self.seg_shape, dtype=np.int16)
case_properties = []

for j, current_key in enumerate(selected_keys):
Expand Down
3 changes: 1 addition & 2 deletions nnunetv2/training/dataloading/data_loader_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ def generate_train_batch(self):
selected_keys = self.get_indices()
# preallocate memory for data and seg
data_all = np.zeros(self.data_shape, dtype=np.float32)
# again, no support for more than 127 labels! We gotta save memory, yo
seg_all = np.zeros(self.seg_shape, dtype=np.int8)
seg_all = np.zeros(self.seg_shape, dtype=np.int16)
case_properties = []

for j, i in enumerate(selected_keys):
Expand Down

0 comments on commit 619d0ab

Please sign in to comment.