-
Notifications
You must be signed in to change notification settings - Fork 61
/
cuhk03_dataset.py
93 lines (84 loc) · 3.29 KB
/
cuhk03_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import h5py
import os
import cv2
import random
import sys
def _read_reference_rows(mat, name):
    """Return the rows of the reference table behind the first cell of ``mat[name]``."""
    # Look the table up once instead of re-reading mat[name][0] per element.
    table = mat[mat[name][0][0]]
    return [table[i] for i in range(len(table))]


def prepare_data(path):
    """Export the images stored in ``<path>/cuhk-03.mat`` as JPEG files.

    For each of the ``labeled`` and ``detected`` groups, only the first cell
    of the group is read. It is treated as a table of HDF5 object references
    indexed as ``table[image_index][identity]``. Every entry that can be
    decoded is written to one of::

        <path>/<group>/train/<identity>_<image_index>.jpg
        <path>/<group>/val/<identity>_<image_index>.jpg

    The last 100 identities of the table go to ``val`` and are renumbered
    starting at 0; all earlier identities go to ``train`` and keep their
    position in the table as identity number.

    Args:
        path: Directory that contains ``cuhk-03.mat``; the output directories
            are created below it if they do not exist yet.

    Returns:
        None. The function only writes files.

    Raises:
        IOError/OSError: If the .mat file cannot be opened or an image file
            cannot be written.
    """
    num_val_ids = 100

    # Open read-only and close the file deterministically when done.
    with h5py.File('%s/cuhk-03.mat' % path, 'r') as mat:
        datasets = [(name, _read_reference_rows(mat, name))
                    for name in ('labeled', 'detected')]

        for name, rows in datasets:
            train_dir = '%s/%s/train' % (path, name)
            val_dir = '%s/%s/val' % (path, name)
            for directory in (train_dir, val_dir):
                if not os.path.exists(directory):
                    os.makedirs(directory)

            if not rows:
                continue

            num_ids = len(rows[0])
            # The validation ids are derived from the table size alone, so
            # they no longer depend on which training images happened to be
            # written successfully, and never become negative when the table
            # holds num_val_ids identities or fewer.
            first_val_id = max(num_ids - num_val_ids, 0)

            for image_index, row in enumerate(rows):
                for identity in range(num_ids):
                    try:
                        image = np.array(mat[row[identity]]).transpose((2, 1, 0))
                        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                        encoded_ok, encoded = cv2.imencode('.jpg', image)
                    except Exception:
                        # Deliberate best-effort: entries that cannot be
                        # turned into an image are skipped (presumably empty
                        # placeholder cells -- TODO confirm against the data).
                        continue
                    if not encoded_ok:
                        continue

                    if identity >= first_val_id:
                        filepath = '%s/%04d_%02d.jpg' % (
                            val_dir, identity - first_val_id, image_index)
                    else:
                        filepath = '%s/%04d_%02d.jpg' % (
                            train_dir, identity, image_index)

                    # Written outside the try block so that real I/O failures
                    # (disk full, permissions) are reported instead of hidden.
                    with open(filepath, 'wb') as image_file:
                        image_file.write(encoded.tobytes())
def get_pair(path, set, num_id, positive):
    """Pick two image files that form a positive or a negative pair.

    Files are looked up as ``<path>/labeled/<set>/<id>_<index>.jpg`` with a
    four-digit identity and a two-digit image index in ``0..9``.

    Args:
        path: Dataset root directory.
        set: Sub-directory below ``labeled`` to sample from. The name shadows
            the builtin but is kept because it is part of the signature.
        num_id: Number of identities; identities are drawn from
            ``0 .. num_id - 1``.
        positive: If true both files belong to the same identity, otherwise
            to two different identities.

    Returns:
        A list with two file paths. The two images are drawn independently,
        so a positive pair may contain the same file twice.

    Raises:
        ValueError: If a negative pair is requested with fewer than two
            identities (the original code looped forever in that case).
        IOError: If a chosen identity has no image file at any index (the
            original code looped forever in that case).
    """
    num_indices = 10

    if positive:
        value = int(random.random() * num_id)
        ids = [value, value]
    else:
        if num_id < 2:
            raise ValueError(
                'a negative pair needs at least two identities, got %r' % (num_id,))
        while True:
            ids = [int(random.random() * num_id), int(random.random() * num_id)]
            if ids[0] != ids[1]:
                break

    pair = []
    for identity in ids:
        filepath = None
        # Indices already known to be absent. A list is used because the
        # ``set`` parameter hides the builtin type inside this function.
        missing = []
        # Same rejection sampling as before (identical sequence of random
        # draws), but it stops once every index has been ruled out.
        while len(missing) < num_indices:
            index = int(random.random() * num_indices)
            if index in missing:
                continue
            candidate = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, identity, index)
            if os.path.exists(candidate):
                filepath = candidate
                break
            missing.append(index)
        if filepath is None:
            raise IOError('no image found for identity %d in %s/labeled/%s'
                          % (identity, path, set))
        pair.append(filepath)
    return pair
def get_num_id(path, set):
    """Return the size of the identity range found in ``<path>/labeled/<set>``.

    File names are expected to look like ``<id>_<index>.jpg``. The result is
    ``max(id) - min(id) + 1``, i.e. the number of identities when the ids are
    contiguous.

    The ids are compared numerically rather than by file-name order, so the
    result stays correct when ids are not all zero-padded to the same width.
    Entries whose name does not start with ``<integer>_`` (for example
    ``.DS_Store``) are ignored instead of aborting the call.

    Args:
        path: Dataset root directory.
        set: Sub-directory below ``labeled`` to inspect. The name shadows the
            builtin but is kept because it is part of the signature.

    Returns:
        The identity count as an ``int``.

    Raises:
        ValueError: If the directory contains no ``<id>_...`` file.
        OSError: If the directory cannot be listed.
    """
    directory = '%s/labeled/%s' % (path, set)
    ids = []
    for name in os.listdir(directory):
        prefix, separator, _ = name.partition('_')
        if not separator:
            continue
        try:
            ids.append(int(prefix))
        except ValueError:
            # Not an image of the dataset; skip it.
            continue
    if not ids:
        raise ValueError('no "<id>_<index>" image files found in %s' % directory)
    return max(ids) - min(ids) + 1
def read_data(path, set, num_id, image_width, image_height, batch_size):
    """Build one batch of image pairs with their same/different labels.

    ``batch_size // 2`` times, one positive and one negative pair are drawn
    with ``get_pair``; every image is loaded with OpenCV, resized to
    ``(image_width, image_height)`` and converted from BGR to RGB.

    Args:
        path: Dataset root directory, forwarded to ``get_pair``.
        set: Sub-directory to sample from, forwarded to ``get_pair``. The
            name shadows the builtin but is part of the signature.
        num_id: Number of identities, forwarded to ``get_pair``.
        image_width: Target width passed to ``cv2.resize``.
        image_height: Target height passed to ``cv2.resize``.
        batch_size: Requested number of pairs; ``2 * (batch_size // 2)``
            pairs are produced, so an odd value is rounded down.

    Returns:
        A tuple ``(images, labels)``. ``images`` is the stacked pair list
        with its first two axes swapped, so axis 0 selects the first or the
        second image of each pair and axis 1 selects the pair. ``labels`` has
        one row per pair: ``[1., 0.]`` for a positive pair and ``[0., 1.]``
        for a negative pair, in that alternating order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 2, which would yield an
            empty batch.
        IOError: If an image file cannot be read.
    """
    if batch_size < 2:
        # Without this check the empty batch fails later inside np.transpose
        # with an unrelated "axes don't match" ValueError.
        raise ValueError('batch_size must be at least 2, got %r' % (batch_size,))

    batch_images = []
    labels = []
    for _ in range(batch_size // 2):
        # The positive pair is drawn first so the random sequence is unchanged.
        pairs = [get_pair(path, set, num_id, True),
                 get_pair(path, set, num_id, False)]
        for pair in pairs:
            images = []
            for filepath in pair:
                image = cv2.imread(filepath)
                if image is None:
                    # cv2.imread signals failure by returning None.
                    raise IOError('could not read image: %s' % filepath)
                image = cv2.resize(image, (image_width, image_height))
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                images.append(image)
            batch_images.append(images)
        labels.append([1., 0.])
        labels.append([0., 1.])

    return np.transpose(batch_images, (1, 0, 2, 3, 4)), np.array(labels)
# Script entry point: convert the dataset found in the directory given as the
# first command-line argument.
if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit('usage: %s <directory containing cuhk-03.mat>' % sys.argv[0])
    prepare_data(sys.argv[1])