Skip to content

Commit 1450e20

Browse files
committed
tune dcfnet
1 parent 0eb001f commit 1450e20

File tree

2 files changed

+112
-8
lines changed

2 files changed

+112
-8
lines changed

track/DCFNet.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def track(self, im):
8383
else:
8484
response = self.net(torch.Tensor(search))
8585
peak, idx = torch.max(response.view(self.config.num_scale, -1), 1)
86-
peak = peak.data.numpy() * self.config.scale_factor
86+
peak = peak.data.cpu().numpy() * self.config.scale_penalties
8787
best_scale = np.argmax(peak)
8888
r_max, c_max = np.unravel_index(idx[best_scale], self.config.net_input_size)
8989

@@ -108,9 +108,6 @@ def track(self, im):
108108

109109
if __name__ == '__main__':
110110
# base dataset path and setting
111-
raw_data_path = '/media/sensetime/memo/OTB2015'
112-
if not isdir(raw_data_path):
113-
raw_data_path = '/data1/qwang/OTB100'
114111
dataset = 'OTB2015'
115112
base_path = join('dataset', dataset)
116113
json_path = join('dataset', dataset + '.json')
@@ -131,7 +128,7 @@ def track(self, im):
131128
for video_id, video in enumerate(videos): # run without resetting
132129
video_path_name = annos[video]['name']
133130
init_rect = np.array(annos[video]['init_rect']).astype(np.float)
134-
image_files = [join(raw_data_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']]
131+
image_files = [join(base_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']]
135132
n_images = len(image_files)
136133

137134
target_pos, target_sz = rect1_2_cxy_wh(init_rect) # OTB label is 1-indexed
@@ -168,9 +165,9 @@ def track(self, im):
168165
# cv2.waitKey(0)
169166

170167
search = patch_crop - config.net_average_image
171-
response = net(torch.Tensor(search).cuda())
168+
response = net(torch.Tensor(search).cuda()).cpu()
172169
peak, idx = torch.max(response.view(config.num_scale, -1), 1)
173-
peak = peak.cpu().data.numpy() * config.scale_penalties
170+
peak = peak.data.numpy() * config.scale_penalties
174171
best_scale = np.argmax(peak)
175172
r_max, c_max = np.unravel_index(idx[best_scale], config.net_input_size)
176173

@@ -212,4 +209,4 @@ def track(self, im):
212209
for x in res:
213210
f.write(','.join(['{:.2f}'.format(i) for i in x]) + '\n')
214211

215-
eval_auc('OTB2015', 'DCFNet_test', 0, 1)
212+
eval_auc('OTB2015', 'DCFNet_test', 0, 1)

track/tune_otb.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import argparse
2+
import cv2
3+
import numpy as np
4+
from os import makedirs
5+
from os.path import isfile, isdir, join
6+
from util import cxy_wh_2_rect1
7+
import torch
8+
import json
9+
from DCFNet import *
10+
11+
# Command-line interface: the only option toggles on-screen visualization
# of the tracked box while tuning.
parser = argparse.ArgumentParser(
    description='Tune parameters for DCFNet tracker on OTB2015')
parser.add_argument(
    '-v', '--visualization',
    dest='visualization',
    action='store_true',
    help='whether visualize result')

# Parsed once at import time; tune_otb() reads args.visualization.
args = parser.parse_args()
16+
17+
18+
def tune_otb(param):
    """Track one video with one hyper-parameter setting and save the boxes.

    The result is written to
    ``result/<dataset>/<network_name>_scale_step_X_scale_penalty_Y_interp_factor_Z/<video>.txt``
    with one comma-separated box per frame. If that file already exists the
    call returns immediately, so finished (or claimed) combinations are
    skipped.

    Args:
        param: dict with the keys
            'dataset'      - dataset name, used as result sub-directory,
            'network_name' - prefix of the tracker directory name,
            'config'       - object exposing ``scale_step``, ``scale_penalty``
                             and ``interp_factor``; passed to the tracker,
            'video'        - video name (result file stem / window title),
            'v'            - index of the video, only used in the log line,
            'ims'          - sequence of frames,
            'init_rect'    - box used to initialise the tracker on frame 0.

    Returns:
        None. The result is the file written to disk.
    """
    config = param['config']
    video_name = param['video']

    # save result
    benchmark_result_path = join('result', param['dataset'])
    tracker_path = join(benchmark_result_path, (param['network_name'] +
                        '_scale_step_{:.3f}'.format(config.scale_step) +
                        '_scale_penalty_{:.3f}'.format(config.scale_penalty) +
                        '_interp_factor_{:.3f}'.format(config.interp_factor)))
    result_path = join(tracker_path, '{:s}.txt'.format(video_name))
    if isfile(result_path):
        return
    if not isdir(tracker_path):
        makedirs(tracker_path)
    # Occupation: create an empty placeholder right away so that the
    # isfile() check above skips this combination while it is being computed.
    with open(result_path, 'w'):
        pass

    ims = param['ims']
    regions = []  # one box per frame
    toc = 0
    tracker = None
    for frame_idx, im in enumerate(ims):
        tic = cv2.getTickCount()
        if frame_idx == 0:  # init
            rect = param['init_rect']
            tracker = DCFNetTraker(im, rect, config=config)
        else:  # tracking
            rect = tracker.track(im)
        regions.append(rect)
        toc += cv2.getTickCount() - tic  # drawing below is not timed

        if args.visualization:
            if frame_idx == 0:
                cv2.destroyAllWindows()
            location = [int(l) for l in rect]  # int
            # Draw on a copy: the frames in param['ims'] belong to the caller
            # and must not be painted on.
            canvas = im.copy()
            cv2.rectangle(canvas, (location[0], location[1]),
                          (location[0] + location[2], location[1] + location[3]),
                          (0, 255, 255), 3)
            cv2.putText(canvas, str(frame_idx), (40, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

            cv2.imshow(video_name, canvas)
            cv2.waitKey(1)

    toc /= cv2.getTickFrequency()
    n_frames = len(ims)
    fps = n_frames / toc if toc > 0 else 0.0
    print('{:2d} Video: {:12s} Time: {:2.1f}s Speed: {:3.1f}fps'.format(
        param['v'], video_name, toc, fps))

    regions = np.array(regions)
    if regions.size:
        # NOTE(review): shifts x/y by one for 1-indexed output; this assumes
        # the collected boxes are 0-indexed - confirm against the tracker.
        regions[:, :2] += 1  # 1-index
    with open(result_path, 'w') as result_file:
        for box in regions:
            result_file.write(','.join(['{:.2f}'.format(i) for i in box]) + '\n')
62+
63+
64+
# Hyper-parameter grid searched by this script. Every combination of
# scale_step x scale_penalty x interp_factor is evaluated on every video.
params = {'dataset': ['OTB2015'], 'network': ['param.pth'],
          'scale_step': np.arange(1.01, 1.05, 0.005, np.float32),
          # NOTE(review): the step (0.025) is larger than the 0.98..1.0 span,
          # so this grid contains the single value 0.98 - confirm intended.
          'scale_penalty': np.arange(0.98, 1.0, 0.025, np.float32),
          'interp_factor': np.arange(0.001, 0.015, 0.001, np.float32)}

# `p` is the single parameter dict handed to tune_otb(); it is updated in
# place as the loops advance.
p = dict()
p['config'] = TrackerConfig()
for network in params['network']:
    p['network_name'] = network
    np.random.shuffle(params['dataset'])
    for dataset in params['dataset']:
        base_path = join('dataset', dataset)
        json_path = join('dataset', dataset + '.json')
        with open(json_path, 'r') as json_file:
            annos = json.load(json_file)
        # Materialise the keys: np.random.shuffle needs a mutable sequence,
        # and dict.keys() is only a view on Python 3.
        videos = list(annos.keys())
        p['dataset'] = dataset
        np.random.shuffle(videos)
        for v, video in enumerate(videos):
            p['v'] = v
            p['video'] = video
            video_path_name = annos[video]['name']
            # Builtin float replaces the removed np.float alias.
            init_rect = np.array(annos[video]['init_rect']).astype(float)
            image_files = [join(base_path, video_path_name, 'img', im_f)
                           for im_f in annos[video]['image_files']]

            # Load all frames once; they are reused for every parameter
            # combination below.
            ims = []
            for image_file in image_files:
                im = cv2.imread(image_file)
                if im is None:
                    raise IOError('cannot read image: {}'.format(image_file))
                if im.ndim == 2 or im.shape[2] == 1:
                    # cvtColor returns the converted image; keep the result.
                    im = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB)
                ims.append(im)
            p['ims'] = ims
            p['init_rect'] = init_rect

            # Visit the grid in random order.
            np.random.shuffle(params['scale_step'])
            np.random.shuffle(params['scale_penalty'])
            np.random.shuffle(params['interp_factor'])
            for scale_step in params['scale_step']:
                for scale_penalty in params['scale_penalty']:
                    for interp_factor in params['interp_factor']:
                        p['config'].scale_step = float(scale_step)
                        p['config'].scale_penalty = float(scale_penalty)
                        p['config'].interp_factor = float(interp_factor)
                        tune_otb(p)

0 commit comments

Comments
 (0)