
Commit 5fbb892

Merge pull request longcw#42 from SnowMasaya/multi_scale
Multi scale Training
2 parents 0f6def2 + 7268683 commit 5fbb892

File tree

25 files changed, +650 −395 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -100,3 +100,6 @@ ENV/
 models
 models/*
 data/*
+
+VOCdevkit/*
+src/*
README.md

Lines changed: 6 additions & 1 deletion
@@ -82,13 +82,18 @@ and set the path in `yolo2-pytorch/cfgs/exps/darknet19_exp1.py`.
 7. (optional) Training with TensorBoard.
 
 To use the TensorBoard, install Crayon (https://github.com/torrvision/crayon)
+How to use the crayon
+```
+docker pull alband/crayon
+docker run -d -p 8888:8888 -p 8889:8889 --name crayon alband/crayon
+```
+
 and set `use_tensorboard = True` in `yolo2-pytorch/cfgs/config.py`.
 
 
 6. Run the training program: `python train.py`.
 
 
-
 ### Evaluation
 
 Set the path of the `trained_model` in `yolo2-pytorch/cfgs/config.py`.
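Editorial note: the README addition covers only the server side; the training code talks to that server through the pycrayon client. A minimal client-side sketch (assuming the pycrayon package is installed; the experiment name is illustrative, not from this commit):

```python
from pycrayon import CrayonClient

cc = CrayonClient(hostname="localhost", port=8889)  # client API port mapped above
exp = cc.create_experiment("yolo2_multi_scale")     # illustrative name
exp.add_scalar_value("train_loss", 0.42, step=1)    # appears in TensorBoard
```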

cfgs/config.py

Lines changed: 25 additions & 2 deletions
@@ -1,6 +1,6 @@
 import os
-from config_voc import *
-from exps.darknet19_exp1 import *
+from .config_voc import *  # noqa
+from .exps.darknet19_exp1 import *  # noqa
 
 
 def mkdir(path, max_depth=3):
@@ -14,6 +14,28 @@ def mkdir(path, max_depth=3):
 
 # input and output size
 ############################
+multi_scale_inp_size = [np.array([320, 320], dtype=np.int),
+                        np.array([352, 352], dtype=np.int),
+                        np.array([384, 384], dtype=np.int),
+                        np.array([416, 416], dtype=np.int),
+                        np.array([448, 448], dtype=np.int),
+                        np.array([480, 480], dtype=np.int),
+                        np.array([512, 512], dtype=np.int),
+                        np.array([544, 544], dtype=np.int),
+                        np.array([576, 576], dtype=np.int),
+                        # np.array([608, 608], dtype=np.int),
+                        ]  # w, h
+multi_scale_out_size = [multi_scale_inp_size[0] / 32,
+                        multi_scale_inp_size[1] / 32,
+                        multi_scale_inp_size[2] / 32,
+                        multi_scale_inp_size[3] / 32,
+                        multi_scale_inp_size[4] / 32,
+                        multi_scale_inp_size[5] / 32,
+                        multi_scale_inp_size[6] / 32,
+                        multi_scale_inp_size[7] / 32,
+                        multi_scale_inp_size[8] / 32,
+                        # multi_scale_inp_size[9] / 32,
+                        ]  # w, h
 inp_size = np.array([416, 416], dtype=np.int)  # w, h
 out_size = inp_size / 32
 
@@ -28,6 +50,7 @@ def _to_color(indx, base):
     g = 2 - (indx % base2) % base
     return b * 127, r * 127, g * 127
 
+
 base = int(np.ceil(pow(num_classes, 1. / 3)))
 colors = [_to_color(x, base) for x in range(num_classes)]
 
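Side note (not part of the commit): the two hand-maintained lists above could be generated programmatically, and floor division would keep the derived output sizes integral. A minimal sketch, assuming the same numpy import as the config:

```python
import numpy as np

# Same nine square scales, 320..576 in steps of 32 (608 stays disabled,
# matching the commented-out entry in the commit).
multi_scale_inp_size = [np.array([s, s], dtype=np.int)
                        for s in range(320, 608, 32)]  # w, h
# Floor-divide by the network stride so feature-map sizes stay integers.
multi_scale_out_size = [s // 32 for s in multi_scale_inp_size]
```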
cfgs/config_voc.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@
                'sheep', 'sofa', 'train', 'tvmonitor')
 num_classes = len(label_names)
 
-anchors = np.asarray([(1.08, 1.19), (3.42, 4.41), (6.63, 11.38), (9.42, 5.11), (16.62, 10.52)], dtype=np.float)
+anchors = np.asarray([(1.08, 1.19), (3.42, 4.41),
+                      (6.63, 11.38), (9.42, 5.11), (16.62, 10.52)],
+                     dtype=np.float)
 num_anchors = len(anchors)
-
darknet.py

Lines changed: 70 additions & 39 deletions
@@ -6,8 +6,9 @@
 import utils.network as net_utils
 import cfgs.config as cfg
 from layers.reorg.reorg_layer import ReorgLayer
-from utils.cython_bbox import bbox_ious, bbox_intersections, bbox_overlaps, anchor_intersections
+from utils.cython_bbox import bbox_ious, anchor_intersections
 from utils.cython_yolo import yolo_to_bbox
+from functools import partial
 
 from multiprocessing import Pool
 
@@ -25,17 +26,21 @@ def _make_layers(in_channels, net_cfg):
                 layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
             else:
                 out_channels, ksize = item
-                layers.append(net_utils.Conv2d_BatchNorm(in_channels, out_channels, ksize, same_padding=True))
-                # layers.append(net_utils.Conv2d(in_channels, out_channels, ksize, same_padding=True))
+                layers.append(net_utils.Conv2d_BatchNorm(in_channels,
+                                                         out_channels,
+                                                         ksize,
+                                                         same_padding=True))
+                # layers.append(net_utils.Conv2d(in_channels, out_channels,
+                #                                ksize, same_padding=True))
                 in_channels = out_channels
 
     return nn.Sequential(*layers), in_channels
 
 
-def _process_batch(data):
-    W, H = cfg.out_size
-    inp_size = cfg.inp_size
-    out_size = cfg.out_size
+def _process_batch(data, size_index):
+    W, H = cfg.multi_scale_out_size[size_index]
+    inp_size = cfg.multi_scale_inp_size[size_index]
+    out_size = cfg.multi_scale_out_size[size_index]
 
     bbox_pred_np, gt_boxes, gt_classes, dontcares, iou_pred_np = data
 
@@ -61,7 +66,8 @@ def _process_batch(data):
         np.ascontiguousarray(bbox_pred_np, dtype=np.float),
         anchors,
         H, W)
-    bbox_np = bbox_np[0]  # bbox_np = (hw, num_anchors, (x1, y1, x2, y2)) range: 0 ~ 1
+    # bbox_np = (hw, num_anchors, (x1, y1, x2, y2)) range: 0 ~ 1
+    bbox_np = bbox_np[0]
     bbox_np[:, :, 0::2] *= float(inp_size[0])  # rescale x
    bbox_np[:, :, 1::2] *= float(inp_size[1])  # rescale y
 
@@ -89,8 +95,10 @@ def _process_batch(data):
     target_boxes = np.empty(gt_boxes_b.shape, dtype=np.float)
     target_boxes[:, 0] = cx - np.floor(cx)  # cx
     target_boxes[:, 1] = cy - np.floor(cy)  # cy
-    target_boxes[:, 2] = (gt_boxes_b[:, 2] - gt_boxes_b[:, 0]) / inp_size[0] * out_size[0]  # tw
-    target_boxes[:, 3] = (gt_boxes_b[:, 3] - gt_boxes_b[:, 1]) / inp_size[1] * out_size[1]  # th
+    target_boxes[:, 2] = \
+        (gt_boxes_b[:, 2] - gt_boxes_b[:, 0]) / inp_size[0] * out_size[0]  # tw
+    target_boxes[:, 3] = \
+        (gt_boxes_b[:, 3] - gt_boxes_b[:, 1]) / inp_size[1] * out_size[1]  # th
 
     # for each gt boxes, match the best anchor
     gt_boxes_resize = np.copy(gt_boxes_b)
@@ -105,12 +113,14 @@ def _process_batch(data):
     ious_reshaped = np.reshape(ious, [hw, num_anchors, len(cell_inds)])
     for i, cell_ind in enumerate(cell_inds):
         if cell_ind >= hw or cell_ind < 0:
-            print cell_ind
+            print('cell inds size {}'.format(len(cell_inds)))
+            print('cell over {} hw {}'.format(cell_ind, hw))
             continue
         a = anchor_inds[i]
 
-        iou_pred_cell_anchor = iou_pred_np[cell_ind, a, :]  # 0 ~ 1, should be close to 1
-        _iou_mask[cell_ind, a, :] = cfg.object_scale * (1 - iou_pred_cell_anchor)
+        # 0 ~ 1, should be close to 1
+        iou_pred_cell_anchor = iou_pred_np[cell_ind, a, :]
+        _iou_mask[cell_ind, a, :] = cfg.object_scale * (1 - iou_pred_cell_anchor)  # noqa
         # _ious[cell_ind, a, :] = anchor_ious[a, i]
         _ious[cell_ind, a, :] = ious_reshaped[cell_ind, a, i]
 
@@ -154,13 +164,15 @@ def __init__(self):
         self.conv3, c3 = _make_layers(c2, net_cfgs[6])
 
         stride = 2
-        self.reorg = ReorgLayer(stride=2)  # stride*stride times the channels of conv1s
+        # stride*stride times the channels of conv1s
+        self.reorg = ReorgLayer(stride=2)
         # cat [conv1s, conv3]
         self.conv4, c4 = _make_layers((c1*(stride*stride) + c3), net_cfgs[7])
 
         # linear
         out_channels = cfg.num_anchors * (cfg.num_classes + 5)
         self.conv5 = net_utils.Conv2d(c4, out_channels, 1, 1, relu=False)
+        self.global_average_pool = nn.AvgPool2d((1, 1))
 
         # train
         self.bbox_loss = None
@@ -172,65 +184,83 @@ def __init__(self):
     def loss(self):
         return self.bbox_loss + self.iou_loss + self.cls_loss
 
-    def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
+    def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None,
+                size_index=0):
         conv1s = self.conv1s(im_data)
         conv2 = self.conv2(conv1s)
         conv3 = self.conv3(conv2)
         conv1s_reorg = self.reorg(conv1s)
         cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
         conv4 = self.conv4(cat_1_3)
         conv5 = self.conv5(conv4)  # batch_size, out_channels, h, w
+        global_average_pool = self.global_average_pool(conv5)
 
         # for detection
-        # bsize, c, h, w -> bsize, h, w, c -> bsize, h x w, num_anchors, 5+num_classes
-        bsize, _, h, w = conv5.size()
+        # bsize, c, h, w -> bsize, h, w, c ->
+        # bsize, h x w, num_anchors, 5+num_classes
+        bsize, _, h, w = global_average_pool.size()
         # assert bsize == 1, 'detection only support one image per batch'
-        conv5_reshaped = conv5.permute(0, 2, 3, 1).contiguous().view(bsize, -1, cfg.num_anchors, cfg.num_classes + 5)
+        global_average_pool_reshaped = \
+            global_average_pool.permute(0, 2, 3, 1).contiguous().view(bsize,
+                -1, cfg.num_anchors, cfg.num_classes + 5)  # noqa
 
         # tx, ty, tw, th, to -> sig(tx), sig(ty), exp(tw), exp(th), sig(to)
-        xy_pred = F.sigmoid(conv5_reshaped[:, :, :, 0:2])
-        wh_pred = torch.exp(conv5_reshaped[:, :, :, 2:4])
+        xy_pred = F.sigmoid(global_average_pool_reshaped[:, :, :, 0:2])
+        wh_pred = torch.exp(global_average_pool_reshaped[:, :, :, 2:4])
         bbox_pred = torch.cat([xy_pred, wh_pred], 3)
-        iou_pred = F.sigmoid(conv5_reshaped[:, :, :, 4:5])
+        iou_pred = F.sigmoid(global_average_pool_reshaped[:, :, :, 4:5])
 
-        score_pred = conv5_reshaped[:, :, :, 5:].contiguous()
-        prob_pred = F.softmax(score_pred.view(-1, score_pred.size()[-1])).view_as(score_pred)
+        score_pred = global_average_pool_reshaped[:, :, :, 5:].contiguous()
+        prob_pred = F.softmax(score_pred.view(-1, score_pred.size()[-1])).view_as(score_pred)  # noqa
 
         # for training
         if self.training:
             bbox_pred_np = bbox_pred.data.cpu().numpy()
             iou_pred_np = iou_pred.data.cpu().numpy()
-            _boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask = self._build_target(
-                bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np)
+            _boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask = \
+                self._build_target(bbox_pred_np,
+                                   gt_boxes,
+                                   gt_classes,
+                                   dontcare,
+                                   iou_pred_np,
+                                   size_index)
 
             _boxes = net_utils.np_to_variable(_boxes)
             _ious = net_utils.np_to_variable(_ious)
             _classes = net_utils.np_to_variable(_classes)
-            box_mask = net_utils.np_to_variable(_box_mask, dtype=torch.FloatTensor)
-            iou_mask = net_utils.np_to_variable(_iou_mask, dtype=torch.FloatTensor)
-            class_mask = net_utils.np_to_variable(_class_mask, dtype=torch.FloatTensor)
+            box_mask = net_utils.np_to_variable(_box_mask,
+                                                dtype=torch.FloatTensor)
+            iou_mask = net_utils.np_to_variable(_iou_mask,
+                                                dtype=torch.FloatTensor)
+            class_mask = net_utils.np_to_variable(_class_mask,
+                                                  dtype=torch.FloatTensor)
 
             num_boxes = sum((len(boxes) for boxes in gt_boxes))
 
             # _boxes[:, :, :, 2:4] = torch.log(_boxes[:, :, :, 2:4])
             box_mask = box_mask.expand_as(_boxes)
 
-            self.bbox_loss = nn.MSELoss(size_average=False)(bbox_pred * box_mask, _boxes * box_mask) / num_boxes
-            self.iou_loss = nn.MSELoss(size_average=False)(iou_pred * iou_mask, _ious * iou_mask) / num_boxes
+            self.bbox_loss = nn.MSELoss(size_average=False)(bbox_pred * box_mask, _boxes * box_mask) / num_boxes  # noqa
+            self.iou_loss = nn.MSELoss(size_average=False)(iou_pred * iou_mask, _ious * iou_mask) / num_boxes  # noqa
 
             class_mask = class_mask.expand_as(prob_pred)
-            self.cls_loss = nn.MSELoss(size_average=False)(prob_pred * class_mask, _classes * class_mask) / num_boxes
+            self.cls_loss = nn.MSELoss(size_average=False)(prob_pred * class_mask, _classes * class_mask) / num_boxes  # noqa
 
         return bbox_pred, iou_pred, prob_pred
 
-    def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np):
+    def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare,
+                      iou_pred_np, size_index):
         """
-        :param bbox_pred: shape: (bsize, h x w, num_anchors, 4) : (sig(tx), sig(ty), exp(tw), exp(th))
+        :param bbox_pred: shape: (bsize, h x w, num_anchors, 4) :
+                          (sig(tx), sig(ty), exp(tw), exp(th))
         """
 
         bsize = bbox_pred_np.shape[0]
 
-        targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b]) for b in range(bsize)))
+        targets = self.pool.map(partial(_process_batch, size_index=size_index),
+                                ((bbox_pred_np[b], gt_boxes[b],
+                                  gt_classes[b], dontcare[b], iou_pred_np[b])
+                                 for b in range(bsize)))
 
         _boxes = np.stack(tuple((row[0] for row in targets)))
         _ious = np.stack(tuple((row[1] for row in targets)))
@@ -244,27 +274,28 @@ def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_n
     def load_from_npz(self, fname, num_conv=None):
         dest_src = {'conv.weight': 'kernel', 'conv.bias': 'biases',
                     'bn.weight': 'gamma', 'bn.bias': 'biases',
-                    'bn.running_mean': 'moving_mean', 'bn.running_var': 'moving_variance'}
+                    'bn.running_mean': 'moving_mean',
+                    'bn.running_var': 'moving_variance'}
         params = np.load(fname)
         own_dict = self.state_dict()
-        keys = own_dict.keys()
+        keys = list(own_dict.keys())
 
         for i, start in enumerate(range(0, len(keys), 5)):
-            if num_conv is not None and i>= num_conv:
+            if num_conv is not None and i >= num_conv:
                 break
             end = min(start+5, len(keys))
            for key in keys[start:end]:
                 list_key = key.split('.')
                 ptype = dest_src['{}.{}'.format(list_key[-2], list_key[-1])]
                 src_key = '{}-convolutional/{}:0'.format(i, ptype)
-                print(src_key, own_dict[key].size(), params[src_key].shape)
+                print((src_key, own_dict[key].size(), params[src_key].shape))
                 param = torch.from_numpy(params[src_key])
                 if ptype == 'kernel':
                     param = param.permute(3, 2, 0, 1)
                 own_dict[key].copy_(param)
 
+
 if __name__ == '__main__':
     net = Darknet19()
     # net.load_from_npz('models/yolo-voc.weights.npz')
     net.load_from_npz('models/darknet19.weights.npz', num_conv=18)
-
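Two implementation notes on this file (editorial, not part of the commit). First, `nn.AvgPool2d((1, 1))` with its default stride of 1 averages over 1x1 windows, so `global_average_pool` is an identity mapping here; the change affects naming, not the math. Second, the move to `functools.partial` is what threads the extra `size_index` argument through `multiprocessing.Pool.map`, which hands each worker exactly one positional argument per work item; unlike a lambda, a partial object also pickles cleanly. A minimal self-contained sketch of the pattern, with illustrative names:

```python
from functools import partial
from multiprocessing import Pool


def process(data, size_index):
    # Stand-in for _process_batch: one work item plus a bound extra arg.
    return data * size_index


if __name__ == '__main__':
    with Pool(2) as pool:
        # partial binds size_index, so each worker effectively calls
        # process(item, size_index=3).
        results = pool.map(partial(process, size_index=3), [1, 2, 4])
    print(results)  # [3, 6, 12]
```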
datasets/imdb.py

Lines changed: 23 additions & 12 deletions
@@ -1,7 +1,9 @@
 import os
-import PIL
 import numpy as np
 from multiprocessing import Pool
+from functools import partial
+import cfgs.config as cfg
+import cv2
 
 
 def mkdir(path, max_depth=3):
@@ -13,8 +15,15 @@ def mkdir(path, max_depth=3):
         os.mkdir(path)
 
 
+def image_resize(im, size_index):
+    w, h = cfg.multi_scale_inp_size[size_index]
+    im = cv2.resize(im, (w, h))
+    return im
+
+
 class ImageDataset(object):
-    def __init__(self, name, datadir, batch_size, im_processor, processes=3, shuffle=True, dst_size=None):
+    def __init__(self, name, datadir, batch_size, im_processor,
+                 processes=3, shuffle=True, dst_size=None):
         self._name = name
         self._data_dir = datadir
         self._batch_size = batch_size
@@ -38,29 +47,33 @@ def __init__(self, name, datadir, batch_size, im_processor, processes=3, shuffle
         self.gen = None
         self._im_processor = im_processor
 
-    def next_batch(self):
-        batch = {'images': [], 'gt_boxes': [], 'gt_classes': [], 'dontcare': [], 'origin_im': []}
+    def next_batch(self, size_index):
+        batch = {'images': [], 'gt_boxes': [], 'gt_classes': [],
+                 'dontcare': [], 'origin_im': []}
         i = 0
         while i < self.batch_size:
             try:
-                images, gt_boxes, classes, dontcare, origin_im = self.gen.next()
+                images, gt_boxes, classes, dontcare, origin_im = next(self.gen)
+                images = image_resize(images, size_index)
                 batch['images'].append(images)
                 batch['gt_boxes'].append(gt_boxes)
                 batch['gt_classes'].append(classes)
                 batch['dontcare'].append(dontcare)
                 batch['origin_im'].append(origin_im)
                 i += 1
-            except (StopIteration, AttributeError):
+            except (StopIteration, AttributeError, TypeError):
                 indexes = np.arange(len(self.image_names), dtype=np.int)
                 if self._shuffle:
                     np.random.shuffle(indexes)
-                self.gen = self.pool.imap(self._im_processor,
-                                          ([self.image_names[i], self.get_annotation(i), self.dst_size] for i in indexes),
+                self.gen = self.pool.imap(partial(self._im_processor,
+                                                  size_index=size_index),
+                                          ([self.image_names[i],
+                                            self.get_annotation(i),
+                                            self.dst_size] for i in indexes),
                                           chunksize=self.batch_size)
                 self._epoch += 1
-                print('epoch {} start...'.format(self._epoch))
+                print(('epoch {} start...'.format(self._epoch)))
         batch['images'] = np.asarray(batch['images'])
-
         return batch
 
     def close(self):
@@ -132,5 +145,3 @@ def batch_size(self):
     @property
     def batch_per_epoch(self):
         return self.num_images // self.batch_size
-
-
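Taken together, the dataset now resizes each image to the scale selected by `size_index` before batching, and `next(self.gen)` keeps the generator usable under Python 3. The driver that actually picks the scale lives in train.py, which is not among the excerpted hunks; presumably it re-rolls the index every few batches in the spirit of the YOLOv2 paper. A hypothetical sketch of that call pattern (names like `imdb`, `net`, and `num_batches` are illustrative):

```python
import numpy as np
import cfgs.config as cfg

# Hypothetical driver loop: re-roll the scale periodically, then pass
# the same size_index to both the dataset and the network.
for step in range(num_batches):          # num_batches defined elsewhere
    if step % 10 == 0:
        size_index = np.random.randint(len(cfg.multi_scale_inp_size))
    batch = imdb.next_batch(size_index)  # imdb: an ImageDataset instance
    # bbox_pred, iou_pred, prob_pred = net(im_data, gt_boxes,
    #                                      gt_classes, dontcare, size_index)
```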
0 commit comments