
Commit 7fa25e1

# Update Python script for multi-scale training
1 parent 4e4eb96 commit 7fa25e1

File tree: 2 files changed (+15, -2 lines)


darknet.py

Lines changed: 3 additions & 0 deletions

@@ -111,6 +111,7 @@ def _process_batch(data, size_index):
     ious_reshaped = np.reshape(ious, [hw, num_anchors, len(cell_inds)])
     for i, cell_ind in enumerate(cell_inds):
         if cell_ind >= hw or cell_ind < 0:
+            print('cell inds size {}'.format(len(cell_inds)))
             print('cell over {} hw {}'.format(cell_ind, hw))
             continue
         a = anchor_inds[i]
@@ -168,6 +169,7 @@ def __init__(self):
         # linear
         out_channels = cfg.num_anchors * (cfg.num_classes + 5)
         self.conv5 = net_utils.Conv2d(c4, out_channels, 1, 1, relu=False)
+        self.global_average_pool = nn.AvgPool2d((1,1))

         # train
         self.bbox_loss = None
@@ -187,6 +189,7 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None, size_index
         cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
         conv4 = self.conv4(cat_1_3)
         conv5 = self.conv5(conv4)  # batch_size, out_channels, h, w
+        conv5 = self.global_average_pool(conv5)

         # for detection
         # bsize, c, h, w -> bsize, h, w, c -> bsize, h x w, num_anchors, 5+num_classes
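
A quick note on the new layer, with a minimal sanity-check sketch (assuming PyTorch is available; the 425-channel figure stands in for cfg.num_anchors * (cfg.num_classes + 5) and is only illustrative): nn.AvgPool2d((1,1)) uses a 1x1 kernel with its default stride equal to the kernel size, so applying it to conv5 keeps the (batch_size, out_channels, h, w) shape at every input scale.

    import torch
    import torch.nn as nn

    # 1x1 average pooling with default stride (= kernel size) passes the map through unchanged
    pool = nn.AvgPool2d((1, 1))
    x = torch.randn(2, 425, 13, 13)   # illustrative: 5 anchors * (80 classes + 5) = 425 channels
    y = pool(x)
    print(y.shape)                    # torch.Size([2, 425, 13, 13])
    print(torch.allclose(x, y))       # True: spatial size and values are preserved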

datasets/imdb.py

Lines changed: 12 additions & 2 deletions

@@ -2,6 +2,9 @@
 import PIL
 import numpy as np
 from multiprocessing import Pool
+from functools import partial
+import cfgs.config as cfg
+import cv2


 def mkdir(path, max_depth=3):
@@ -13,6 +16,12 @@ def mkdir(path, max_depth=3):
         os.mkdir(path)


+def image_resize(im, size_index):
+    w, h = cfg.multi_scale_inp_size[size_index]
+    im = cv2.resize(im, (w, h))
+    return im
+
+
 class ImageDataset(object):
     def __init__(self, name, datadir, batch_size, im_processor, processes=3, shuffle=True, dst_size=None):
         self._name = name
@@ -38,12 +47,13 @@ def __init__(self, name, datadir, batch_size, im_processor, processes=3, shuffle
         self.gen = None
         self._im_processor = im_processor

-    def next_batch(self):
+    def next_batch(self, size_index):
         batch = {'images': [], 'gt_boxes': [], 'gt_classes': [], 'dontcare': [], 'origin_im': []}
         i = 0
         while i < self.batch_size:
             try:
                 images, gt_boxes, classes, dontcare, origin_im = next(self.gen)
+                images = image_resize(images, size_index)
                 batch['images'].append(images)
                 batch['gt_boxes'].append(gt_boxes)
                 batch['gt_classes'].append(classes)
@@ -54,7 +64,7 @@ def next_batch(self):
                 indexes = np.arange(len(self.image_names), dtype=np.int)
                 if self._shuffle:
                     np.random.shuffle(indexes)
-                self.gen = self.pool.imap(self._im_processor,
+                self.gen = self.pool.imap(partial(self._im_processor, size_index=size_index),
                                           ([self.image_names[i], self.get_annotation(i), self.dst_size] for i in indexes),
                                           chunksize=self.batch_size)
                 self._epoch += 1
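
For context, a hedged sketch of how a training script could drive the new size_index argument (the trainer itself is not part of this commit; dataset, max_steps, and the 10-batch interval below are illustrative assumptions, not code from this repository). A scale index is drawn periodically, next_batch(size_index) resizes each image via image_resize, and partial(self._im_processor, size_index=size_index) forwards the same index to the Pool.imap preprocessing workers.

    import numpy as np
    import cfgs.config as cfg

    # `dataset` stands for an ImageDataset instance from datasets/imdb.py (assumption)
    max_steps = 1000                                   # assumed number of training batches
    size_index = len(cfg.multi_scale_inp_size) // 2    # start from a mid-range scale
    for step in range(max_steps):
        if step % 10 == 0:                             # re-draw the input scale every 10 batches
            size_index = np.random.randint(len(cfg.multi_scale_inp_size))
        batch = dataset.next_batch(size_index)
        # batch['images'] are now resized to cfg.multi_scale_inp_size[size_index];
        # the same index can be passed to the network's forward(..., size_index=...)
        # so the loss is computed on the matching grid size.
        # ... forward pass, loss, optimizer step ...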
