# vot_SiamRPN_upd.py
import vot
from vot import Rectangle
import sys
import cv2 # imread
import torch
import numpy as np
from os.path import realpath, dirname, join
from net_upd import SiamRPNBIG
from updatenet import UpdateResNet
from run_SiamRPN_upd import SiamRPN_init, SiamRPN_track_upd
from utils import get_axis_aligned_bbox, cxy_wh_2_rect
# Load the two networks used by the tracker: the SiamRPN network (`net`) and
# the UpdateNet model (`updatenet`). Both are switched to eval mode and moved
# to the GPU.
#
# Both checkpoint paths are resolved relative to this script rather than the
# process's working directory, so loading does not depend on where the
# tracker is launched from.
script_dir = realpath(dirname(__file__))

net_file = join(script_dir, 'SiamRPNBIG.model')
net = SiamRPNBIG()
# map_location='cpu' deserializes the tensors on the CPU regardless of the
# device they were saved from; the module is moved to the GPU right after.
net.load_state_dict(torch.load(net_file, map_location='cpu'))
net.eval().cuda()

updatenet = UpdateResNet()
update_model_file = join(script_dir, '..', 'models', 'vot2016.pth.tar')
update_model = torch.load(update_model_file, map_location='cpu')['state_dict']
# Alternative loader kept for reference: it drops the first dotted component
# of every key before loading (presumably for checkpoints whose keys carry a
# wrapper prefix -- TODO confirm against the checkpoint in use).
#update_model_fix = dict()
#for i in update_model.keys():
#    update_model_fix['.'.join(i.split('.')[1:])] = update_model[i]
#updatenet.load_state_dict(update_model_fix)
updatenet.load_state_dict(update_model)
updatenet.eval().cuda()

# Optional warm-up (disabled): run a few dummy forward passes through `net`.
#for i in range(10):
#    net.temple(torch.autograd.Variable(torch.FloatTensor(1, 3, 127, 127)).cuda())
#    net(torch.autograd.Variable(torch.FloatTensor(1, 3, 255, 255)).cuda())
# Tracking loop driven by the VOT handle: read the initial region and first
# frame, initialise the tracker, then track and report one rectangle per
# subsequent frame until the handle returns no more frames.


def _read_image(path):
    """Read the image at *path* with OpenCV, failing loudly if it cannot be read.

    cv2.imread returns None instead of raising when the file is missing or
    cannot be decoded; checking here gives a clear error instead of a less
    obvious failure inside the tracker.

    Raises:
        IOError: if the image could not be read.
    """
    image = cv2.imread(path)  # HxWxC
    if image is None:
        raise IOError('Failed to read image: {}'.format(path))
    return image


handle = vot.VOT("polygon")
Polygon = handle.region()
# Convert the initial region to a centre position (cx, cy) and size (w, h).
cx, cy, w, h = get_axis_aligned_bbox(Polygon)

image_file = handle.frame()
if not image_file:
    # No first frame available: nothing to track.
    sys.exit(0)

target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
im = _read_image(image_file)
state = SiamRPN_init(im, target_pos, target_sz, net)  # init tracker

while True:
    image_file = handle.frame()
    if not image_file:
        break
    im = _read_image(image_file)
    state = SiamRPN_track_upd(state, im, updatenet)
    # Turn the tracked centre/size back into a rectangle and report it.
    res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
    handle.report(Rectangle(res[0], res[1], res[2], res[3]))