-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathgetDataLoader.py
29 lines (22 loc) · 881 Bytes
/
getDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from __future__ import print_function, division, absolute_import, unicode_literals
import torch
class BaseDataProvider(object):
    """Base provider that serves one sample at a time as torch tensors.

    This class does not define ``_next_data()`` or
    ``_augment_data(data, label)``; both are called on ``self`` and must be
    supplied elsewhere (presumably by a subclass).
    """

    def _load_data_and_label(self):
        """Fetch one sample, augment it, and lay it out as 5-D arrays.

        Returns:
            A ``(path, data, labels)`` tuple. ``path`` is passed through
            unchanged from ``_next_data()``. ``data`` and ``labels`` are
            arrays cast with ``astype(float)`` and reshaped to
            ``(1, 1, d, w, h)``, where ``(d, w, h)`` is the shape of the data
            array after its last axis has been moved to the front.
        """
        raw_data, raw_label, path = self._next_data()
        aug_data, aug_label = self._augment_data(raw_data, raw_label)

        # Axis order (0, 1, 2) -> (2, 0, 1): the last axis becomes the first.
        volume = aug_data.transpose(2, 0, 1).astype(float)
        target = aug_label.transpose(2, 0, 1).astype(float)

        # The label is reshaped with the *data's* extents, so it has to hold
        # the same number of elements as the data array.
        depth, width, height = volume.shape
        out_shape = (1, 1, depth, width, height)
        return path, volume.reshape(out_shape), target.reshape(out_shape)

    def _toTorchFloatTensor(self, img):
        """Return a torch tensor built from a copy of the array ``img``.

        The array is copied first, so the tensor does not share memory with
        ``img``. NOTE(review): despite the method name, no cast is performed
        here; ``torch.from_numpy`` keeps the array's dtype.
        """
        return torch.from_numpy(img.copy())

    def __call__(self, n):
        """Load a single sample and return it as tensors.

        Args:
            n: Accepted but not used by this method.

        Returns:
            A ``(X, Y, P)`` tuple: the data tensor, the label tensor, and a
            one-element list holding the sample's path.
        """
        path, data, labels = self._load_data_and_label()
        data_tensor = self._toTorchFloatTensor(data)
        label_tensor = self._toTorchFloatTensor(labels)
        return data_tensor, label_tensor, [path]