Skip to content

Commit 57f75de

Browse files
committed
add eval code
1 parent 504e640 commit 57f75de

File tree

2 files changed

+214
-4
lines changed

2 files changed

+214
-4
lines changed

track/DCFNet.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,66 @@ class TrackerConfig(object):
4343
cos_window = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda()
4444

4545

46-
def DCFNet_init(im, target_pos, target_sz, use_gpu=True):
47-
pass
46+
class DCFNetTraker(object):
47+
def __init__(self, im, init_rect, config=TrackerConfig(), gpu=True):
48+
self.gpu = gpu
49+
self.config = config
50+
self.net = DCFNet(config)
51+
self.net.load_param(config.feature_path)
52+
self.net.eval()
53+
if gpu:
54+
self.net.cuda()
4855

56+
# confine results
57+
target_pos, target_sz = rect1_2_cxy_wh(init_rect)
58+
self.min_sz = np.maximum(config.min_scale_factor * target_sz, 4)
59+
self.max_sz = np.minimum(im.shape[:2], config.max_scale_factor * target_sz)
60+
61+
# crop template
62+
window_sz = target_sz * (1 + config.padding)
63+
bbox = cxy_wh_2_bbox(target_pos, window_sz)
64+
patch = resample(im, bbox, config.net_input_size, [0, 0, 0])
65+
# cv2.imwrite('crop.jpg', np.transpose(patch[::-1,:,:], (1, 2, 0)))
4966

50-
def DCFNet_track(state, im):
51-
pass
67+
target = patch - config.net_average_image
68+
self.net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda())
69+
self.target_pos, self.target_sz = target_pos, target_sz
70+
self.patch_crop = np.zeros((config.num_scale, patch.shape[0], patch.shape[1], patch.shape[2]), np.float32) # buff
71+
72+
def track(self, im):
73+
for i in range(self.config.num_scale): # crop multi-scale search region
74+
window_sz = self.target_sz * (self.config.scale_factor[i] * (1 + self.config.padding))
75+
bbox = cxy_wh_2_bbox(self.target_pos, window_sz)
76+
self.patch_crop[i, :] = resample(im, bbox, self.config.net_input_size, [0, 0, 0])
77+
78+
search = self.patch_crop - self.config.net_average_image
79+
80+
if self.gpu:
81+
response = self.net(torch.Tensor(search).cuda()).cpu()
82+
else:
83+
response = self.net(torch.Tensor(search))
84+
peak, idx = torch.max(response.view(self.config.num_scale, -1), 1)
85+
peak = peak.data.numpy() * self.config.scale_factor
86+
best_scale = np.argmax(peak)
87+
r_max, c_max = np.unravel_index(idx[best_scale], self.config.net_input_size)
88+
89+
if r_max > self.config.net_input_size[0] / 2:
90+
r_max = r_max - self.config.net_input_size[0]
91+
if c_max > self.config.net_input_size[1] / 2:
92+
c_max = c_max - self.config.net_input_size[1]
93+
window_sz = self.target_sz * (self.config.scale_factor[best_scale] * (1 + self.config.padding))
94+
95+
self.target_pos = self.target_pos + np.array([c_max, r_max]) * window_sz / self.config.net_input_size
96+
self.target_sz = np.minimum(np.maximum(window_sz / (1 + self.config.padding), self.min_sz), self.max_sz)
97+
98+
# model update
99+
window_sz = self.target_sz * (1 + self.config.padding)
100+
bbox = cxy_wh_2_bbox(self.target_pos, window_sz)
101+
patch = resample(im, bbox, self.config.net_input_size, [0, 0, 0])
102+
target = patch - self.config.net_average_image
103+
self.net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda(), lr=self.config.interp_factor)
104+
105+
return cxy_wh_2_rect1(self.target_pos, self.target_sz) # 1-index
52106

53107

54108
if __name__ == '__main__':

track/eval_otb.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import sys
2+
import json
3+
import os
4+
import glob
5+
from os.path import join as fullfile
6+
import numpy as np
7+
8+
9+
def overlap_ratio(rect1, rect2):
10+
'''
11+
Compute overlap ratio between two rects
12+
- rect: 1d array of [x,y,w,h] or
13+
2d array of N x [x,y,w,h]
14+
'''
15+
16+
if rect1.ndim==1:
17+
rect1 = rect1[None,:]
18+
if rect2.ndim==1:
19+
rect2 = rect2[None,:]
20+
21+
left = np.maximum(rect1[:,0], rect2[:,0])
22+
right = np.minimum(rect1[:,0]+rect1[:,2], rect2[:,0]+rect2[:,2])
23+
top = np.maximum(rect1[:,1], rect2[:,1])
24+
bottom = np.minimum(rect1[:,1]+rect1[:,3], rect2[:,1]+rect2[:,3])
25+
26+
intersect = np.maximum(0,right - left) * np.maximum(0,bottom - top)
27+
union = rect1[:,2]*rect1[:,3] + rect2[:,2]*rect2[:,3] - intersect
28+
iou = np.clip(intersect / union, 0, 1)
29+
return iou
30+
31+
32+
def compute_success_overlap(gt_bb, result_bb):
33+
thresholds_overlap = np.arange(0, 1.05, 0.05)
34+
n_frame = len(gt_bb)
35+
success = np.zeros(len(thresholds_overlap))
36+
iou = overlap_ratio(gt_bb, result_bb)
37+
for i in range(len(thresholds_overlap)):
38+
success[i] = sum(iou > thresholds_overlap[i]) / float(n_frame)
39+
return success
40+
41+
42+
def compute_success_error(gt_center, result_center):
43+
thresholds_error = np.arange(0, 51, 1)
44+
n_frame = len(gt_center)
45+
success = np.zeros(len(thresholds_error))
46+
dist = np.sqrt(np.sum(np.power(gt_center - result_center, 2), axis=1))
47+
for i in range(len(thresholds_error)):
48+
success[i] = sum(dist <= thresholds_error[i]) / float(n_frame)
49+
return success
50+
51+
52+
def get_result_bb(arch, seq):
53+
result_path = fullfile(arch, seq + '.txt')
54+
temp = np.loadtxt(result_path, delimiter=',').astype(np.float)
55+
return np.array(temp)
56+
57+
58+
def convert_bb_to_center(bboxes):
59+
return np.array([(bboxes[:, 0] + (bboxes[:, 2] - 1) / 2),
60+
(bboxes[:, 1] + (bboxes[:, 3] - 1) / 2)]).T
61+
62+
63+
def eval_auc(dataset='OTB2015', tracker_reg='S*', start=0, end=1e6):
64+
list_path = os.path.join('dataset', dataset + '.json')
65+
annos = json.load(open(list_path, 'r'))
66+
seqs = annos.keys()
67+
68+
OTB2013 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
69+
'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
70+
'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
71+
'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
72+
'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
73+
'basketball', 'football', 'subway', 'tiger1', 'tiger2']
74+
75+
OTB2015 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
76+
'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
77+
'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
78+
'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
79+
'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
80+
'basketball', 'football', 'subway', 'tiger1', 'tiger2', 'clifBar', 'biker', 'bird1', 'blurBody',
81+
'blurCar2', 'blurFace', 'blurOwl', 'box', 'car1', 'crowds', 'diving', 'dragonBaby', 'human3', 'human4_2',
82+
'human6', 'human9', 'jump', 'panda', 'redTeam', 'skating2_1', 'skating2_2', 'surfer', 'bird2',
83+
'blurCar1', 'blurCar3', 'blurCar4', 'board', 'bolt2', 'car2', 'car24', 'coupon', 'dancer', 'dancer2',
84+
'dog', 'girl2', 'gym', 'human2', 'human5', 'human7', 'human8', 'kiteSurf', 'man', 'rubik', 'skater',
85+
'skater2', 'toy', 'trans', 'twinnings', 'vase']
86+
87+
trackers = glob.glob(fullfile('result', dataset, tracker_reg))
88+
trackers = trackers[start:min(end, len(trackers))]
89+
90+
n_seq = len(seqs)
91+
thresholds_overlap = np.arange(0, 1.05, 0.05)
92+
thresholds_error = np.arange(0, 51, 1)
93+
94+
success_overlap = np.zeros((n_seq, len(trackers), len(thresholds_overlap)))
95+
success_error = np.zeros((n_seq, len(trackers), len(thresholds_error)))
96+
for i in range(n_seq):
97+
seq = seqs[i]
98+
gt_rect = np.array(annos[seq]['gt_rect']).astype(np.float)
99+
gt_center = convert_bb_to_center(gt_rect)
100+
for j in range(len(trackers)):
101+
tracker = trackers[j]
102+
print('{:d} processing:{} tracker: {}'.format(i, seq, tracker))
103+
bb = get_result_bb(tracker, seq)
104+
center = convert_bb_to_center(bb)
105+
success_overlap[i][j] = compute_success_overlap(gt_rect, bb)
106+
# success_error[i][j] = compute_success_error(gt_center, center)
107+
108+
print('Success Overlap')
109+
110+
if 'OTB2015' == dataset:
111+
OTB2013_id = []
112+
for i in range(n_seq):
113+
if seqs[i] in OTB2013:
114+
OTB2013_id.append(i)
115+
max_auc_OTB2013 = 0.
116+
max_name_OTB2013 = ''
117+
for i in range(len(trackers)):
118+
auc = success_overlap[OTB2013_id, i, :].mean()
119+
if auc > max_auc_OTB2013:
120+
max_auc_OTB2013 = auc
121+
max_name_OTB2013 = trackers[i]
122+
print('%s(%.4f)' % (trackers[i], auc))
123+
124+
max_auc = 0.
125+
max_name = ''
126+
for i in range(len(trackers)):
127+
auc = success_overlap[:, i, :].mean()
128+
if auc > max_auc:
129+
max_auc = auc
130+
max_name = trackers[i]
131+
print('%s(%.4f)' % (trackers[i], auc))
132+
133+
print('\nOTB2013 Best: %s(%.4f)' % (max_name_OTB2013, max_auc_OTB2013))
134+
print('\nOTB2015 Best: %s(%.4f)' % (max_name, max_auc))
135+
elif 'TC128' == dataset:
136+
max_auc = 0.
137+
max_name = ''
138+
for i in range(len(trackers)):
139+
auc = success_overlap[:, i, :].mean()
140+
if auc > max_auc:
141+
max_auc = auc
142+
max_name = trackers[i]
143+
print('%s(%.4f)' % (trackers[i], auc))
144+
145+
print('\nTC128 Best: %s(%.4f)' % (max_name, max_auc))
146+
147+
148+
if __name__ == "__main__":
149+
if len(sys.argv) < 5:
150+
print('python eval_otb.py OTB2015 DCFNet_test* 0 10')
151+
exit()
152+
dataset = sys.argv[1]
153+
tracker_reg = sys.argv[2]
154+
start = int(sys.argv[3])
155+
end = int(sys.argv[4])
156+
eval_auc(dataset, tracker_reg, start, end)

0 commit comments

Comments
 (0)