Skip to content

Commit 2eaa371

Browse files
committed
nerge
2 parents a3f7935 + ed88855 commit 2eaa371

File tree

4 files changed

+28
-7
lines changed

4 files changed

+28
-7
lines changed

nnunetv2/imageio/base_reader_writer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
class BaseReaderWriter(ABC):
2222
@staticmethod
2323
def _check_all_same(input_list):
24-
# compare all entries to the first
25-
for i in input_list[1:]:
26-
if i != input_list[0]:
27-
return False
28-
return True
24+
if len(input_list) == 1:
25+
return True
26+
else:
27+
# compare all entries to the first
28+
return np.allclose(input_list[0], input_list[1:])
2929

3030
@staticmethod
3131
def _check_all_same_array(input_list):

nnunetv2/training/dataloading/data_loader_2d.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def generate_train_batch(self):
9090
if self.transforms is not None:
9191
with torch.no_grad():
9292
with threadpool_limits(limits=1, user_api=None):
93+
9394
data_all = torch.from_numpy(data_all).float()
9495
seg_all = torch.from_numpy(seg_all).to(torch.int16)
9596
images = []
@@ -99,7 +100,10 @@ def generate_train_batch(self):
99100
images.append(tmp['image'])
100101
segs.append(tmp['segmentation'])
101102
data_all = torch.stack(images)
102-
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
103+
if isinstance(segs[0], list):
104+
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
105+
else:
106+
seg_all = torch.stack(segs)
103107
del segs, images
104108

105109
return {'data': data_all, 'target': seg_all, 'keys': selected_keys}

nnunetv2/training/dataloading/data_loader_3d.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def generate_train_batch(self):
6262
images.append(tmp['image'])
6363
segs.append(tmp['segmentation'])
6464
data_all = torch.stack(images)
65-
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
65+
if isinstance(segs[0], list):
66+
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
67+
else:
68+
seg_all = torch.stack(segs)
6669
del segs, images
6770

6871
return {'data': data_all, 'target': seg_all, 'keys': selected_keys}

nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
5555
self.num_epochs = 250
5656

5757

58+
class nnUNetTrainer_500epochs(nnUNetTrainer):
59+
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
60+
device: torch.device = torch.device('cuda')):
61+
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
62+
self.num_epochs = 500
63+
64+
65+
class nnUNetTrainer_750epochs(nnUNetTrainer):
66+
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
67+
device: torch.device = torch.device('cuda')):
68+
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
69+
self.num_epochs = 750
70+
71+
5872
class nnUNetTrainer_2000epochs(nnUNetTrainer):
5973
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
6074
device: torch.device = torch.device('cuda')):

0 commit comments

Comments
 (0)