Commit 504e640

bug fix and support gpu,100fps
1 parent e996b40 commit 504e640

2 files changed: 26 additions & 32 deletions


track/DCFNet.py

Lines changed: 24 additions & 30 deletions

@@ -33,21 +33,14 @@ class TrackerConfig(object):
     net_average_image = np.array([123, 117, 104]).reshape(-1, 1, 1).astype(np.float32)
     output_sigma = crop_sz / (1 + padding) * output_sigma_factor
     y = gaussian_shaped_labels(output_sigma, net_input_size)
-    # cv2.imshow('gaussian', y)
-    # cv2.waitKey(0)
     yf_ = np.fft.fft2(y)
     yf = torch.Tensor(1, 1, crop_sz, crop_sz, 2)
     yf_real = torch.Tensor(np.real(yf_))
     yf_imag = torch.Tensor(np.imag(yf_))
     yf[0, 0, :, :, 0] = yf_real
     yf[0, 0, :, :, 1] = yf_imag
     yf = yf.cuda()
-    # y = torch.irfft(yf, 2, onesided=False)
-    # cv2.imshow('gaussian', y[0,0].data.cpu().numpy())
-    # cv2.waitKey(0)
     cos_window = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda()
-    # cv2.imshow('cos window', cos_window.data.cpu().numpy())
-    # cv2.waitKey(0)


 def DCFNet_init(im, target_pos, target_sz, use_gpu=True):
@@ -59,6 +52,7 @@ def DCFNet_track(state, im):


 if __name__ == '__main__':
+    # base dataset path and setting
     raw_data_path = '/media/sensetime/memo/OTB2015'
     if not isdir(raw_data_path):
         raw_data_path = '/data1/qwang/OTB100'
@@ -67,32 +61,34 @@ def DCFNet_track(state, im):
     json_path = join('dataset', dataset + '.json')
     annos = json.load(open(json_path, 'r'))
     videos = sorted(annos.keys())
+
     use_gpu = True
     visualization = False
-    for video_id, video in enumerate(videos[30:]):  # run without resetting
+
+    # default parameter and load feature extractor network
+    config = TrackerConfig()
+    net = DCFNet(config)
+    net.load_param(config.feature_path)
+    net.eval()
+    net.cuda()
+
+    # loop videos
+    for video_id, video in enumerate(videos):  # run without resetting
         video_path_name = annos[video]['name']
         init_rect = np.array(annos[video]['init_rect']).astype(np.float)
         image_files = [join(raw_data_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']]
         n_images = len(image_files)

-        target_pos, target_sz = rect1_2_cxy_wh(init_rect)
+        target_pos, target_sz = rect1_2_cxy_wh(init_rect)  # OTB label is 1-indexed

         im = cv2.imread(image_files[0])  # HxWxC
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

-        # cv2.imshow('image', im)
-        # cv2.waitKey(0)
-
-        # init tracker
-        config = TrackerConfig()
-        net = DCFNet(config)
-        net.load_param(config.feature_path)
-        net.eval()
-        net.cuda()
+        # confine results
         min_sz = np.maximum(config.min_scale_factor * target_sz, 4)
-        [im_h, im_w, _] = im.shape
         max_sz = np.minimum(im.shape[:2], config.max_scale_factor * target_sz)

+        # crop template
         window_sz = target_sz * (1 + config.padding)
         bbox = cxy_wh_2_bbox(target_pos, window_sz)
         patch = resample(im, bbox, config.net_input_size, [0, 0, 0])
@@ -104,22 +100,20 @@ def DCFNet_track(state, im):
         res = [cxy_wh_2_rect1(target_pos, target_sz)]  # save in .txt
         tic = time.time()
         patch_crop = np.zeros((config.num_scale, patch.shape[0], patch.shape[1], patch.shape[2]), np.float32)
-        for f in range(1, n_images):
-            im = cv2.imread(image_files[0])
+        for f in range(1, n_images):  # track
+            im = cv2.imread(image_files[f])
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
-            # track
-            for i in range(config.num_scale):
+
+            for i in range(config.num_scale):  # crop multi-scale search region
                 window_sz = target_sz * (config.scale_factor[i] * (1 + config.padding))
                 bbox = cxy_wh_2_bbox(target_pos, window_sz)
-                patch_crop[i,:] = resample(im, bbox, config.net_input_size, [0, 0, 0])
+                patch_crop[i, :] = resample(im, bbox, config.net_input_size, [0, 0, 0])
             # cv2.imwrite('crop2.jpg', np.transpose(patch_crop[0,::-1,:,:], (1, 2, 0)))
             # cv2.imshow('crop.jpg', np.transpose(patch_crop[i], (1, 2, 0)).astype(np.float32) / 255)
            # cv2.waitKey(0)

             search = patch_crop - config.net_average_image
             response = net(torch.Tensor(search).cuda()).cpu()
-            response_cpu = response.data.cpu().numpy()
-            cv2.imwrite('response_map.jpg', response_cpu[0,0,:])
             peak, idx = torch.max(response.view(config.num_scale, -1), 1)
             peak = peak.data.numpy() * config.scale_factor
             best_scale = np.argmax(peak)
@@ -131,22 +125,22 @@ def DCFNet_track(state, im):
                 c_max = c_max - config.net_input_size[1]
             window_sz = target_sz * (config.scale_factor[best_scale] * (1 + config.padding))

-            target_pos = target_pos + np.array([r_max, c_max]) * window_sz / config.net_input_size
+            target_pos = target_pos + np.array([c_max, r_max]) * window_sz / config.net_input_size
             target_sz = np.minimum(np.maximum(window_sz / (1 + config.padding), min_sz), max_sz)

             # model update
             window_sz = target_sz * (1 + config.padding)
             bbox = cxy_wh_2_bbox(target_pos, window_sz)
             patch = resample(im, bbox, config.net_input_size, [0, 0, 0])
             target = patch - config.net_average_image
-            net.update(torch.Tensor(np.expand_dims(target, axis=0)), lr=config.interp_factor)
+            net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda(), lr=config.interp_factor)

             res.append(cxy_wh_2_rect1(target_pos, target_sz))  # 1-index

             if visualization:
                 im_show = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
-                cv2.rectangle(im_show, (int(target_pos[1] - target_sz[1] / 2), int(target_pos[0] - target_sz[0] / 2)),
-                              (int(target_pos[1] + target_sz[1] / 2), int(target_pos[0] + target_sz[0] / 2)),
+                cv2.rectangle(im_show, (int(target_pos[0] - target_sz[0] / 2), int(target_pos[1] - target_sz[1] / 2)),
+                              (int(target_pos[0] + target_sz[0] / 2), int(target_pos[1] + target_sz[1] / 2)),
                               (0, 255, 0), 3)
                 cv2.putText(im_show, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2, cv2.LINE_AA)
                 cv2.imshow(video, im_show)
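
Note on the fixes above: the tracking loop now reads frame f instead of re-reading frame 0, the network is built and moved to the GPU once before the video loop, the update patch is sent to the GPU, and the response-map peak offset is reordered from (row, col) into the (x, y) convention that target_pos uses before scaling. Below is a minimal sketch of that last coordinate conversion, assuming a square response map and a (w, h) window size; peak_to_displacement and its arguments are illustrative names, not part of the repository.

import numpy as np

def peak_to_displacement(r_max, c_max, net_input_size, window_sz):
    # hypothetical helper mirroring the fix above: the response peak comes
    # back as (row, col), but target_pos is stored as (x, y), so the offset
    # must be reordered to (c_max, r_max) before scaling
    if r_max > net_input_size[0] / 2:  # wrap peaks past the half-size back
        r_max -= net_input_size[0]     # to negative offsets (circular correlation)
    if c_max > net_input_size[1] / 2:
        c_max -= net_input_size[1]
    # scale from response-map pixels to image pixels, returned in (x, y) order
    return np.array([c_max, r_max]) * window_sz / net_input_size

# peak one row down, two columns right on a 125x125 map, 250x250 search window
print(peak_to_displacement(1, 2, np.array([125, 125]), np.array([250.0, 250.0])))  # -> [4. 2.]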

track/net.py

Lines changed: 2 additions & 2 deletions

@@ -66,8 +66,8 @@ def update(self, z, lr=1.):
             self.model_alphaf = alphaf
             self.model_zf = zf
         else:
-            self.model_alphaf = (1 - lr) * self.model_alphaf + lr * alphaf
-            self.model_zf = (1 - lr) * self.model_zf + lr * zf
+            self.model_alphaf = (1 - lr) * self.model_alphaf.data + lr * alphaf.data
+            self.model_zf = (1 - lr) * self.model_zf.data + lr * zf.data

     def load_param(self, path='param.pth'):
         self.feature.load_state_dict(torch.load(path))
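
Note on the change above: adding .data to both sides of the running-average template update blends raw tensor values, so the interpolated filter templates do not keep accumulating autograd history frame after frame. A minimal sketch of the same pattern using the newer detach()/no_grad() idiom follows; TemplateBank and its lr > 0.99 initialization rule are assumptions for illustration, not the repository's DCFNet class.

import torch

class TemplateBank:
    # hypothetical sketch of the same running-average update, using
    # detach()/no_grad() instead of .data to keep it out of the autograd graph
    def __init__(self):
        self.model_zf = None

    def update(self, zf, lr=1.0):
        with torch.no_grad():
            if self.model_zf is None or lr > 0.99:
                self.model_zf = zf.detach().clone()  # (re)initialize the template
            else:
                # blend slowly toward the new template, detached from the graph
                self.model_zf = (1 - lr) * self.model_zf + lr * zf.detach()

bank = TemplateBank()
bank.update(torch.randn(1, 32, 121, 121, 2), lr=1.0)   # first frame: copy
bank.update(torch.randn(1, 32, 121, 121, 2), lr=0.01)  # later frames: slow blend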
