Skip to content

Commit 8c7ac4b

Browse files
committed
fix for test
1 parent 1ee80e9 commit 8c7ac4b

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

datasets/imdb.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ def next_batch(self, size_index):
6464

6565
# multi-scale
6666
w, h = cfg.multi_scale_inp_size[size_index]
67-
gt_boxes = gt_boxes.astype(float)
68-
gt_boxes[:, 0::2] *= float(w) / images.shape[1]
69-
gt_boxes[:, 1::2] *= float(h) / images.shape[0]
67+
gt_boxes = np.asarray(gt_boxes, dtype=np.float)
68+
if len(gt_boxes) > 0:
69+
gt_boxes[:, 0::2] *= float(w) / images.shape[1]
70+
gt_boxes[:, 1::2] *= float(h) / images.shape[0]
7071
images = cv2.resize(images, (w, h))
7172

7273
batch['images'].append(images)

test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
imdb_name = cfg.imdb_test
2525
# trained_model = cfg.trained_model
2626
trained_model = os.path.join(cfg.train_output_dir,
27-
'darknet19_voc07trainval_exp3_118.h5')
27+
'darknet19_voc07trainval_exp3_73.h5')
2828
output_dir = cfg.test_output_dir
2929

3030
max_per_image = 300
@@ -126,7 +126,7 @@ def test_net(net, imdb, max_per_image=300, thresh=0.5, vis=False):
126126
# data loader
127127
imdb = VOCDataset(imdb_name, cfg.DATA_DIR, cfg.batch_size,
128128
yolo_utils.preprocess_test,
129-
processes=2, shuffle=False, dst_size=cfg.multi_scale_inp_size)
129+
processes=1, shuffle=False, dst_size=cfg.multi_scale_inp_size)
130130

131131
net = Darknet19()
132132
net_utils.load_net(trained_model, net)

utils/yolo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,13 @@ def preprocess_train(data, size_index):
8888
def preprocess_test(data, size_index):
8989

9090
im, _, inp_size = data
91-
inp_size = inp_size[size_index]
91+
9292
if isinstance(im, str):
9393
im = cv2.imread(im)
9494
ori_im = np.copy(im)
9595

96-
if inp_size is not None:
96+
if inp_size is not None and size_index is not None:
97+
inp_size = inp_size[size_index]
9798
w, h = inp_size
9899
im = cv2.resize(im, (w, h))
99100
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

0 commit comments

Comments
 (0)