-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathcustom.py
64 lines (52 loc) · 2.34 KB
/
custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

import numpy as np
import pandas as pd
import torch
from skimage import io
from skimage import io
from torch.utils.data.dataset import Dataset
from torchvision import transforms
class MyDataset(Dataset):
    """Patch dataset over preloaded multi-date image arrays and label maps.

    Every image (one ``.npy`` file per date) and every label map is loaded into
    memory once in ``__init__``; ``__getitem__`` then cuts a square window out of
    those arrays according to one row of the index CSV.

    Each CSV data row is read positionally as
    ``(x, y, image_index, transformation_id)``: ``x``/``y`` are the window's
    start offsets along the first two array axes, ``image_index`` indexes into
    the ``image_ids`` passed to ``__init__``, and ``transformation_id`` selects a
    rotation (see ``_transform``).
    """

    # Channel indices (last axis) kept from each date's array. The original
    # author describes them as "the 4 highest resolution channels"; which bands
    # they are depends on how the .npy files were written -- confirm against data.
    CHANNELS = (1, 2, 3, 7)

    def __init__(self, csv_path, image_ids, image_folder, label_folder, nb_dates, patch_size):
        """Read the patch index and load all images and labels into memory.

        Args:
            csv_path: Path (or file-like object) of the patch-index CSV. It is
                read with ``pandas.read_csv``, so its first line is the header.
            image_ids: Sequence of image id strings; the CSV's third column
                indexes into this sequence.
            image_folder: Root folder containing ``<id>/<id>_<date>.npy``.
            label_folder: Root folder containing ``<id>/cm/<id>-cm.tif``.
            nb_dates: Sequence of date identifiers; one ``.npy`` file is loaded
                for every (date, image id) pair.
            patch_size: Side length, in pixels, of the windows returned.
        """
        self.data_info = pd.read_csv(csv_path)
        self.patch_size = patch_size
        self.nb_dates = nb_dates

        # all_imgs[d][i] holds date nb_dates[d] of image image_ids[i].
        self.all_imgs = [
            [np.load(os.path.join(image_folder, img_id, f"{img_id}_{nd}.npy"))
             for img_id in image_ids]
            for nd in self.nb_dates
        ]

        self.all_labels = []
        for img_id in image_ids:
            label = io.imread(os.path.join(label_folder, img_id, 'cm', img_id + '-cm.tif'))
            # Remap 1 -> 0, then 2 -> 1. The order matters: doing 2 -> 1 first
            # would make the following 1 -> 0 step erase those pixels too.
            label[label == 1] = 0
            label[label == 2] = 1
            self.all_labels.append(label)

        # read_csv has already consumed the header line, so every remaining row
        # is one sample; subtracting 1 here would silently drop the last row.
        self.data_len = self.data_info.shape[0]

    @staticmethod
    def _transform(patch, tr_id):
        """Rotate ``patch`` by ``tr_id`` quarter turns in the plane of axes 0 and 1.

        ``tr_id`` of 1, 2 or 3 applies ``np.rot90(patch, k=tr_id)``; any other
        value (including 0) returns ``patch`` unchanged.
        """
        if tr_id in (1, 2, 3):
            return np.rot90(patch, k=tr_id)
        # NOTE(review): ids outside 0-3 pass through unrotated (as before);
        # confirm the CSV never contains other transformation ids.
        return patch

    def __getitem__(self, index):
        """Return ``(image_patch, label_patch)`` for CSV data row ``index``.

        For each date, the window ``[x:x+patch_size, y:y+patch_size]`` is cut
        from the image, reduced to ``CHANNELS``, rotated, and moved to
        channels-first. The per-date results are stacked, so a window lying fully
        inside the image yields shape
        ``(len(nb_dates), len(CHANNELS), patch_size, patch_size)``.
        ``label_patch`` is the matching label window with the same rotation.
        Both arrays are returned C-contiguous.
        """
        info = self.data_info
        x = int(info.iat[index, 0])
        y = int(info.iat[index, 1])
        image_id = int(info.iat[index, 2])
        transformation_id = int(info.iat[index, 3])

        size = self.patch_size
        channels = list(self.CHANNELS)

        image_patch = []
        for date_imgs in self.all_imgs:
            window = date_imgs[image_id][x:x + size, y:y + size, :][:, :, channels]
            rotated = self._transform(window, transformation_id)
            image_patch.append(np.transpose(rotated, (2, 0, 1)))

        label_window = self.all_labels[image_id][x:x + size, y:y + size]
        label_patch = self._transform(label_window, transformation_id)
        return np.ascontiguousarray(image_patch), np.ascontiguousarray(label_patch)

    def __len__(self):
        """Return the number of samples (data rows in the index CSV)."""
        return self.data_len