-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
76 lines (60 loc) · 2.21 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import sys
import cv2
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# Absolute directory that contains this script.
curr_path = os.path.abspath(os.path.dirname(__file__))

# Make the sibling "../python" directory importable. Nothing visible in this
# file imports from it — NOTE(review): confirm it is still needed.
_python_dir = os.path.join(curr_path, "../python")
sys.path.append(_python_dir)

print("Current path", curr_path)
def image_loader(image_path):
    """Load an image from disk as a 3-channel numpy array.

    ``cv2.imread`` does not raise on a missing, unreadable, or undecodable
    file; it returns ``None``. That case is turned into an explicit
    ``IOError`` naming the path. Otherwise the failure would surface later as
    an unrelated-looking error (``None.shape``, ``cv2.flip`` or the tensor
    transform).

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded image as returned by ``cv2.imread`` (OpenCV's default
        colour load, i.e. BGR channel order). A 2-D (single-channel) result
        is replicated to three identical channels along the last axis.

    Raises:
        IOError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise IOError('fail to load image:' + image_path)
    if image.ndim == 2:
        # Replicate the single channel so downstream code always sees H x W x 3.
        image = np.stack([image] * 3, 2)
    return image
# Per-image preprocessing used by AgeDB30.__getitem__:
#   1. ToTensor: H x W x C ndarray -> C x H x W float tensor (scaled to [0, 1]
#      for uint8 input).
#   2. Normalize: (x - 0.5) / 0.5 on each of the three channels, mapping
#      [0, 1] onto [-1, 1].
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])
class AgeDB30(Dataset):
    """Face-verification pairs read from an AgeDB-30 style pair list.

    Each non-blank line of ``file_list`` must contain at least three
    whitespace-separated fields ``<left image> <right image> <label>``. The
    image paths are relative to ``root``, and ``label`` is an integer
    (presumably 1 = same identity, 0 = different — confirm against the list).
    Extra trailing fields are ignored and blank lines are skipped.

    Pairs are grouped into cross-validation folds in file order:
    ``pairs_per_fold`` consecutive pairs form one fold (600 by default).

    Args:
        root: Directory that the image paths in ``file_list`` are relative to.
        file_list: Path of the pair-list text file.
        pairs_per_fold: Number of consecutive pairs assigned to each fold.

    Raises:
        ValueError: If ``pairs_per_fold`` is not positive, a non-blank line has
            fewer than three fields, or its label is not an integer.
    """

    def __init__(self, root, file_list, pairs_per_fold=600):
        if pairs_per_fold <= 0:
            raise ValueError('pairs_per_fold must be positive, got %r' % (pairs_per_fold,))
        self.root = root
        self.file_list = file_list
        self.nameLs = []   # left image path of each pair (relative to root)
        self.nameRs = []   # right image path of each pair (relative to root)
        self.folds = []    # fold index of each pair
        self.labels = []   # integer label of each pair

        with open(file_list) as f:
            lines = f.read().splitlines()

        for line_no, line in enumerate(lines, 1):
            # split() with no argument tolerates repeated spaces, tabs and
            # trailing whitespace, unlike split(' ').
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 3:
                raise ValueError('%s:%d: expected "<left> <right> <label>", got %r'
                                 % (file_list, line_no, line))
            # Parse the label before appending anything so a bad line cannot
            # leave the per-pair lists with different lengths.
            label = int(fields[2])
            # Fold is based on the position among parsed pairs, so skipped
            # blank lines do not shift fold boundaries.
            fold = len(self.nameLs) // pairs_per_fold
            self.nameLs.append(fields[0])
            self.nameRs.append(fields[1])
            self.folds.append(fold)
            self.labels.append(label)

    def __getitem__(self, index):
        """Return ``[left, flipped left, right, flipped right]`` as tensors.

        Each image is loaded with ``image_loader``, mirrored horizontally with
        ``cv2.flip(..., 1)`` (presumably for flip test-time augmentation), and
        passed through the module-level ``transform``.
        """
        img_l = image_loader(os.path.join(self.root, self.nameLs[index]))
        img_r = image_loader(os.path.join(self.root, self.nameRs[index]))
        views = (img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1))
        return [transform(view) for view in views]

    def __len__(self):
        """Number of parsed pairs."""
        return len(self.nameLs)
if __name__ == '__main__':
    # Smoke test: build the dataset and loader, then iterate once over all
    # batches. Both paths are machine-specific and resolved relative to the
    # current working directory.
    # NOTE(review): the image root is the CASIA-WebFace folder while the pair
    # list is AgeDB-30 — confirm the AgeDB images actually live under it.
    data_root = '../../../../../../media/data1/masked_dataset/casiawebface/casiawebface_without_masked'
    pair_file = '../../../../../../media/data1/masked_dataset/Test/agedb_30.txt'

    dataset = AgeDB30(data_root, pair_file)
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        drop_last=False,
    )

    print(len(dataset))  # number of pairs
    print(len(loader))   # number of batches
    for batch in loader:
        # Each sample is a list of 4 tensors, so each collated batch is
        # presumably a list of 4 batched tensors.
        print(len(batch))