Skip to content

Commit 7268683

Browse files
committed
# Update
For multi-scale test work correct
1 parent 1b320fa commit 7268683

File tree

4 files changed

+22
-9
lines changed

4 files changed

+22
-9
lines changed

datasets/pascal_voc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def _do_python_eval(self, output_dir='output'):
235235
aps += [ap]
236236
print(('AP for {} = {:.4f}'.format(cls, ap)))
237237
if output_dir is not None:
238-
with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
238+
with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
239239
pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
240240
print(('Mean AP = {:.4f}'.format(np.mean(aps))))
241241
print('~~~~~~~~')

datasets/voc_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,11 @@ def voc_eval(detpath,
115115
i + 1, len(imagenames)))
116116
# save
117117
print('Saving cached annotations to {:s}'.format(cachefile))
118-
with open(cachefile, 'w') as f:
118+
with open(cachefile, 'wb') as f:
119119
pickle.dump(recs, f)
120120
else:
121121
# load
122-
with open(cachefile, 'r') as f:
122+
with open(cachefile, 'rb') as f:
123123
recs = pickle.load(f)
124124

125125
# extract gt objects for this class

test.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import cv2
33
import numpy as np
44
import pickle
5+
import argparse
56

67
from darknet import Darknet19
78
import utils.yolo as yolo_utils
@@ -18,6 +19,13 @@ def preprocess(fname):
1819
return image, im_data
1920

2021

22+
parser = argparse.ArgumentParser(description='PyTorch Yolo')
23+
parser.add_argument('--image_size_index', type=int, default=0,
24+
metavar='image_size_index',
25+
help='setting images size index 0:320, 1:352, 2:384, 3:416, 4:448, 5:480, 6:512, 7:544, 8:576')
26+
args = parser.parse_args()
27+
28+
2129
# hyper-parameters
2230
# ------------
2331
imdb_name = cfg.imdb_test
@@ -44,10 +52,11 @@ def test_net(net, imdb, max_per_image=300, thresh=0.5, vis=False):
4452
# timers
4553
_t = {'im_detect': Timer(), 'misc': Timer()}
4654
det_file = os.path.join(output_dir, 'detections.pkl')
55+
size_index = args.image_size_index
4756

4857
for i in range(num_images):
4958

50-
batch = imdb.next_batch()
59+
batch = imdb.next_batch(size_index=size_index)
5160
ori_im = batch['origin_im'][0]
5261
im_data = net_utils.np_to_variable(batch['images'], is_cuda=True,
5362
volatile=True).permute(0, 3, 1, 2)
@@ -65,7 +74,9 @@ def test_net(net, imdb, max_per_image=300, thresh=0.5, vis=False):
6574
prob_pred,
6675
ori_im.shape,
6776
cfg,
68-
thresh)
77+
thresh,
78+
size_index
79+
)
6980
detect_time = _t['im_detect'].toc()
7081

7182
_t['misc'].tic()
@@ -122,7 +133,7 @@ def test_net(net, imdb, max_per_image=300, thresh=0.5, vis=False):
122133
# data loader
123134
imdb = VOCDataset(imdb_name, cfg.DATA_DIR, cfg.batch_size,
124135
yolo_utils.preprocess_test,
125-
processes=2, shuffle=False, dst_size=cfg.inp_size)
136+
processes=2, shuffle=False, dst_size=cfg.multi_scale_inp_size)
126137

127138
net = Darknet19()
128139
net_utils.load_net(trained_model, net)

utils/yolo.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@ def preprocess_train(data, size_index):
7878
return im, boxes, gt_classes, [], ori_im
7979

8080

81-
def preprocess_test(data):
81+
def preprocess_test(data, size_index):
8282

8383
im, _, inp_size = data
84+
inp_size = inp_size[size_index]
8485
if isinstance(im, str):
8586
im = cv2.imread(im)
8687
ori_im = np.copy(im)
@@ -94,7 +95,8 @@ def preprocess_test(data):
9495
return im, [], [], [], ori_im
9596

9697

97-
def postprocess(bbox_pred, iou_pred, prob_pred, im_shape, cfg, thresh=0.05):
98+
def postprocess(bbox_pred, iou_pred, prob_pred, im_shape, cfg, thresh=0.05,
99+
size_index=0):
98100
"""
99101
bbox_pred: (bsize, HxW, num_anchors, 4)
100102
ndarray of float (sig(tx), sig(ty), exp(tw), exp(th))
@@ -105,7 +107,7 @@ def postprocess(bbox_pred, iou_pred, prob_pred, im_shape, cfg, thresh=0.05):
105107
# num_classes, num_anchors = cfg.num_classes, cfg.num_anchors
106108
num_classes = cfg.num_classes
107109
anchors = cfg.anchors
108-
W, H = cfg.out_size
110+
W, H = cfg.multi_scale_out_size[size_index]
109111
assert bbox_pred.shape[0] == 1, 'postprocess only support one image per batch' # noqa
110112

111113
bbox_pred = yolo_to_bbox(

0 commit comments

Comments
 (0)