Commit a3f7935

data augmentation uses less RAM
1 parent 852b303 commit a3f7935

File tree

2 files changed: +27 −34 lines changed

nnunetv2/training/dataloading/data_loader_2d.py

Lines changed: 13 additions & 17 deletions
@@ -88,23 +88,19 @@ def generate_train_batch(self):
             seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1)
 
         if self.transforms is not None:
-            if torch is not None:
-                torch_nthreads = torch.get_num_threads()
-                torch.set_num_threads(1)
-            with threadpool_limits(limits=1, user_api=None):
-                data_all = torch.from_numpy(data_all).float()
-                seg_all = torch.from_numpy(seg_all).to(torch.int16)
-                images = []
-                segs = []
-                for b in range(self.batch_size):
-                    tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
-                    images.append(tmp['image'])
-                    segs.append(tmp['segmentation'])
-                data_all = torch.stack(images)
-                seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
-                del segs, images
-            if torch is not None:
-                torch.set_num_threads(torch_nthreads)
+            with torch.no_grad():
+                with threadpool_limits(limits=1, user_api=None):
+                    data_all = torch.from_numpy(data_all).float()
+                    seg_all = torch.from_numpy(seg_all).to(torch.int16)
+                    images = []
+                    segs = []
+                    for b in range(self.batch_size):
+                        tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
+                        images.append(tmp['image'])
+                        segs.append(tmp['segmentation'])
+                    data_all = torch.stack(images)
+                    seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    del segs, images
 
         return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
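Both loaders get the same treatment, so the 2D hunk above is representative: the manual torch.get_num_threads() / torch.set_num_threads(1) save-and-restore is dropped in favour of a torch.no_grad() context around the whole augmentation step. The commit message attributes a RAM saving to this; one mechanism no_grad rules out is a transform output with requires_grad=True keeping an autograd graph alive for the lifetime of the batch. A minimal standalone sketch of that effect (tensor shapes and names are illustrative, not taken from nnU-Net):

import torch

x = torch.randn(2, 1, 64, 64, requires_grad=True)  # stand-in for a transform output that tracks grad

y = (x * 2 + 1).mean()
print(y.grad_fn is not None)   # True: an autograd graph is kept alive alongside the result

with torch.no_grad():
    z = (x * 2 + 1).mean()
print(z.grad_fn is None)       # True: no graph recorded, nothing extra held in RAM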

nnunetv2/training/dataloading/data_loader_3d.py

Lines changed: 14 additions & 17 deletions
@@ -51,23 +51,20 @@ def generate_train_batch(self):
             seg_all[j] = np.pad(seg, padding, 'constant', constant_values=-1)
 
         if self.transforms is not None:
-            if torch is not None:
-                torch_nthreads = torch.get_num_threads()
-                torch.set_num_threads(1)
-            with threadpool_limits(limits=1, user_api=None):
-                data_all = torch.from_numpy(data_all).float()
-                seg_all = torch.from_numpy(seg_all).to(torch.int16)
-                images = []
-                segs = []
-                for b in range(self.batch_size):
-                    tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
-                    images.append(tmp['image'])
-                    segs.append(tmp['segmentation'])
-                data_all = torch.stack(images)
-                seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
-                del segs, images
-            if torch is not None:
-                torch.set_num_threads(torch_nthreads)
+            with torch.no_grad():
+                with threadpool_limits(limits=1, user_api=None):
+                    data_all = torch.from_numpy(data_all).float()
+                    seg_all = torch.from_numpy(seg_all).to(torch.int16)
+                    images = []
+                    segs = []
+                    for b in range(self.batch_size):
+                        tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
+                        images.append(tmp['image'])
+                        segs.append(tmp['segmentation'])
+                    data_all = torch.stack(images)
+                    seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    del segs, images
+
             return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
 
         return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
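Apart from the early return inside the if block, which the 3D loader already had, the change is identical to the 2D one. Note that threadpool_limits(limits=1, user_api=None) from threadpoolctl survives in both versions; with user_api=None it caps every thread pool the library recognizes (BLAS and OpenMP alike) rather than a single backend, which may be why the explicit PyTorch thread bookkeeping was considered redundant. A small self-contained sketch of that call, assuming threadpoolctl is installed (the matrix below is illustrative only):

import numpy as np
from threadpoolctl import threadpool_limits

a = np.random.rand(256, 256)
# user_api=None applies the limit to all pools threadpoolctl recognizes
with threadpool_limits(limits=1, user_api=None):
    b = a @ a  # the BLAS matmul runs single-threaded inside this block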
