import argparse
import cv2
import numpy as np
from os import makedirs
from os.path import isfile, isdir, join
from util import cxy_wh_2_rect1
import torch
import json
from DCFNet import *

parser = argparse.ArgumentParser(description='Tune parameters for DCFNet tracker on OTB2015')
parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
                    help='whether to visualize the tracking result')

args = parser.parse_args()


def tune_otb(param):
    regions = []  # tracked rectangle for every frame
    # result file for this video and parameter combination
    benchmark_result_path = join('result', param['dataset'])
    tracker_path = join(benchmark_result_path, (param['network_name'] +
                                                '_scale_step_{:.3f}'.format(param['config'].scale_step) +
                                                '_scale_penalty_{:.3f}'.format(param['config'].scale_penalty) +
                                                '_interp_factor_{:.3f}'.format(param['config'].interp_factor)))
    result_path = join(tracker_path, '{:s}.txt'.format(param['video']))
    if isfile(result_path):  # already evaluated (or claimed by another run)
        return
    if not isdir(tracker_path):
        makedirs(tracker_path)
    # create an empty placeholder file so concurrent runs skip this combination
    open(result_path, 'w').close()

    ims = param['ims']
    toc = 0
    for f, im in enumerate(ims):
        tic = cv2.getTickCount()
        if f == 0:  # init
            init_rect = param['init_rect']
            tracker = DCFNetTraker(ims[f], init_rect, config=param['config'])
            regions.append(init_rect)
        else:  # tracking
            rect = tracker.track(ims[f])
            regions.append(rect)
        toc += cv2.getTickCount() - tic

        if args.visualization:  # visualization
            if f == 0:
                cv2.destroyAllWindows()
            location = [int(l) for l in regions[-1]]  # current rectangle as ints
            cv2.rectangle(im, (location[0], location[1]),
                          (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3)
            cv2.putText(im, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            cv2.imshow(param['video'], im)
            cv2.waitKey(1)
    toc /= cv2.getTickFrequency()
    print('{:2d} Video: {:12s} Time: {:2.1f}s Speed: {:3.1f}fps'.format(
        param['v'], param['video'], toc, f / toc))
    regions = np.array(regions)
    regions[:, :2] += 1  # convert x, y to 1-indexed coordinates
    with open(result_path, 'w') as f:
        for x in regions:
            f.write(','.join(['{:.2f}'.format(i) for i in x]) + '\n')


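# Search grid: every combination of scale_step, scale_penalty and interp_factor
# below is evaluated on each OTB2015 video with the given network weights.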
params = {'dataset': ['OTB2015'], 'network': ['param.pth'],
          'scale_step': np.arange(1.01, 1.05, 0.005, np.float32),
          'scale_penalty': np.arange(0.98, 1.0, 0.025, np.float32),
          'interp_factor': np.arange(0.001, 0.015, 0.001, np.float32)}

p = dict()
p['config'] = TrackerConfig()
for network in params['network']:
    p['network_name'] = network
    np.random.shuffle(params['dataset'])
    for dataset in params['dataset']:
        base_path = join('dataset', dataset)
        json_path = join('dataset', dataset + '.json')
        annos = json.load(open(json_path, 'r'))
        videos = list(annos.keys())  # list() so np.random.shuffle also works on Python 3
        p['dataset'] = dataset
        np.random.shuffle(videos)
        for v, video in enumerate(videos):
            p['v'] = v
            p['video'] = video
            video_path_name = annos[video]['name']
            init_rect = np.array(annos[video]['init_rect']).astype(float)
            image_files = [join(base_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']]
            target_pos = np.array([init_rect[0] + init_rect[2] / 2 - 1, init_rect[1] + init_rect[3] / 2 - 1])  # 0-index
            target_sz = np.array([init_rect[2], init_rect[3]])
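            # pre-load every frame once so all parameter combinations reuse the same images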
            ims = []
            for image_file in image_files:
                im = cv2.imread(image_file)
                if im.ndim == 2:  # grayscale frame: replicate to three channels
                    im = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB)
                ims.append(im)
            p['ims'] = ims
            p['init_rect'] = init_rect

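            # randomize the search order; tune_otb skips combinations whose result file
            # already exists, so repeated or parallel runs do not redo finished work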
            np.random.shuffle(params['scale_step'])
            np.random.shuffle(params['scale_penalty'])
            np.random.shuffle(params['interp_factor'])
            for scale_step in params['scale_step']:
                for scale_penalty in params['scale_penalty']:
                    for interp_factor in params['interp_factor']:
                        p['config'].scale_step = float(scale_step)
                        p['config'].scale_penalty = float(scale_penalty)
                        p['config'].interp_factor = float(interp_factor)
                        tune_otb(p)