Skip to content

Commit 4e4eb96

Browse files
committed
# Update python scrypt
For multi scale training
1 parent dc6b10e commit 4e4eb96

File tree

3 files changed

+75
-30
lines changed

3 files changed

+75
-30
lines changed

cfgs/config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@ def mkdir(path, max_depth=3):
1414

1515
# input and output size
1616
############################
17+
multi_scale_inp_size = [np.array([320, 320], dtype=np.int),
18+
np.array([352, 352], dtype=np.int),
19+
np.array([384, 384], dtype=np.int),
20+
np.array([416, 416], dtype=np.int),
21+
np.array([448, 448], dtype=np.int),
22+
np.array([480, 480], dtype=np.int),
23+
np.array([512, 512], dtype=np.int),
24+
np.array([544, 544], dtype=np.int),
25+
np.array([576, 576], dtype=np.int),
26+
# np.array([608, 608], dtype=np.int),
27+
] # w, h
28+
multi_scale_out_size = [multi_scale_inp_size[0] / 32,
29+
multi_scale_inp_size[1] / 32,
30+
multi_scale_inp_size[2] / 32,
31+
multi_scale_inp_size[3] / 32,
32+
multi_scale_inp_size[4] / 32,
33+
multi_scale_inp_size[5] / 32,
34+
multi_scale_inp_size[6] / 32,
35+
multi_scale_inp_size[7] / 32,
36+
multi_scale_inp_size[8] / 32,
37+
# multi_scale_inp_size[9] / 32,
38+
] # w, h
1739
inp_size = np.array([416, 416], dtype=np.int) # w, h
1840
out_size = inp_size / 32
1941

darknet.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import utils.network as net_utils
77
import cfgs.config as cfg
88
from layers.reorg.reorg_layer import ReorgLayer
9-
from utils.cython_bbox import bbox_ious, bbox_intersections, bbox_overlaps, anchor_intersections
9+
from utils.cython_bbox import bbox_ious, anchor_intersections
1010
from utils.cython_yolo import yolo_to_bbox
11+
from functools import partial
1112

1213
from multiprocessing import Pool
14+
import multiprocessing
1315

1416

1517
def _make_layers(in_channels, net_cfg):
@@ -25,17 +27,21 @@ def _make_layers(in_channels, net_cfg):
2527
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
2628
else:
2729
out_channels, ksize = item
28-
layers.append(net_utils.Conv2d_BatchNorm(in_channels, out_channels, ksize, same_padding=True))
29-
# layers.append(net_utils.Conv2d(in_channels, out_channels, ksize, same_padding=True))
30+
layers.append(net_utils.Conv2d_BatchNorm(in_channels,
31+
out_channels,
32+
ksize,
33+
same_padding=True))
34+
# layers.append(net_utils.Conv2d(in_channels, out_channels,
35+
# ksize, same_padding=True))
3036
in_channels = out_channels
3137

3238
return nn.Sequential(*layers), in_channels
3339

3440

35-
def _process_batch(data):
36-
W, H = cfg.out_size
37-
inp_size = cfg.inp_size
38-
out_size = cfg.out_size
41+
def _process_batch(data, size_index):
42+
W, H = cfg.multi_scale_out_size[size_index]
43+
inp_size = cfg.multi_scale_inp_size[size_index]
44+
out_size = cfg.multi_scale_out_size[size_index]
3945

4046
bbox_pred_np, gt_boxes, gt_classes, dontcares, iou_pred_np = data
4147

@@ -105,7 +111,7 @@ def _process_batch(data):
105111
ious_reshaped = np.reshape(ious, [hw, num_anchors, len(cell_inds)])
106112
for i, cell_ind in enumerate(cell_inds):
107113
if cell_ind >= hw or cell_ind < 0:
108-
print(cell_ind)
114+
print('cell over {} hw {}'.format(cell_ind, hw))
109115
continue
110116
a = anchor_inds[i]
111117

@@ -154,7 +160,8 @@ def __init__(self):
154160
self.conv3, c3 = _make_layers(c2, net_cfgs[6])
155161

156162
stride = 2
157-
self.reorg = ReorgLayer(stride=2) # stride*stride times the channels of conv1s
163+
# stride*stride times the channels of conv1s
164+
self.reorg = ReorgLayer(stride=2)
158165
# cat [conv1s, conv3]
159166
self.conv4, c4 = _make_layers((c1*(stride*stride) + c3), net_cfgs[7])
160167

@@ -172,7 +179,7 @@ def __init__(self):
172179
def loss(self):
173180
return self.bbox_loss + self.iou_loss + self.cls_loss
174181

175-
def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
182+
def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None, size_index=0):
176183
conv1s = self.conv1s(im_data)
177184
conv2 = self.conv2(conv1s)
178185
conv3 = self.conv3(conv2)
@@ -201,7 +208,7 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
201208
bbox_pred_np = bbox_pred.data.cpu().numpy()
202209
iou_pred_np = iou_pred.data.cpu().numpy()
203210
_boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask = self._build_target(
204-
bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np)
211+
bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np, size_index)
205212

206213
_boxes = net_utils.np_to_variable(_boxes)
207214
_ious = net_utils.np_to_variable(_ious)
@@ -223,14 +230,16 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
223230

224231
return bbox_pred, iou_pred, prob_pred
225232

226-
def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np):
233+
def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np, size_index):
227234
"""
228235
:param bbox_pred: shape: (bsize, h x w, num_anchors, 4) : (sig(tx), sig(ty), exp(tw), exp(th))
229236
"""
230237

231238
bsize = bbox_pred_np.shape[0]
232239

233-
targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b]) for b in range(bsize)))
240+
targets = self.pool.map(partial(_process_batch, size_index=size_index),
241+
((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b])
242+
for b in range(bsize)))
234243

235244
_boxes = np.stack(tuple((row[0] for row in targets)))
236245
_ious = np.stack(tuple((row[1] for row in targets)))
@@ -250,7 +259,7 @@ def load_from_npz(self, fname, num_conv=None):
250259
keys = list(own_dict.keys())
251260

252261
for i, start in enumerate(range(0, len(keys), 5)):
253-
if num_conv is not None and i>= num_conv:
262+
if num_conv is not None and i >= num_conv:
254263
break
255264
end = min(start+5, len(keys))
256265
for key in keys[start:end]:
@@ -263,8 +272,8 @@ def load_from_npz(self, fname, num_conv=None):
263272
param = param.permute(3, 2, 0, 1)
264273
own_dict[key].copy_(param)
265274

275+
266276
if __name__ == '__main__':
267277
net = Darknet19()
268278
# net.load_from_npz('models/yolo-voc.weights.npz')
269279
net.load_from_npz('models/darknet19.weights.npz', num_conv=18)
270-

train.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import os
2-
import cv2
32
import torch
4-
import numpy as np
53
import datetime
6-
from torch.multiprocessing import Pool
74

85
from darknet import Darknet19
96

@@ -12,6 +9,7 @@
129
import utils.network as net_utils
1310
from utils.timer import Timer
1411
import cfgs.config as cfg
12+
from random import randint
1513

1614
try:
1715
from pycrayon import CrayonClient
@@ -21,12 +19,15 @@
2119

2220
# data loader
2321
imdb = VOCDataset(cfg.imdb_train, cfg.DATA_DIR, cfg.train_batch_size,
24-
yolo_utils.preprocess_train, processes=2, shuffle=True, dst_size=cfg.inp_size)
22+
yolo_utils.preprocess_train, processes=2, shuffle=True,
23+
dst_size=cfg.multi_scale_inp_size)
24+
# dst_size=cfg.inp_size)
2525
print('load data succ...')
2626

2727
net = Darknet19()
2828
# net_utils.load_net(cfg.trained_model, net)
29-
# pretrained_model = os.path.join(cfg.train_output_dir, 'darknet19_voc07trainval_exp1_63.h5')
29+
# pretrained_model = os.path.join(cfg.train_output_dir,
30+
# 'darknet19_voc07trainval_exp1_63.h5')
3031
# pretrained_model = cfg.trained_model
3132
# net_utils.load_net(pretrained_model, net)
3233
net.load_from_npz(cfg.pretrained_model, num_conv=18)
@@ -37,7 +38,8 @@
3738
# optimizer
3839
start_epoch = 0
3940
lr = cfg.init_learning_rate
40-
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
41+
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum,
42+
weight_decay=cfg.weight_decay)
4143

4244
# tensorboad
4345
use_tensorboard = cfg.use_tensorboard and CrayonClient is not None
@@ -63,19 +65,23 @@
6365
cnt = 0
6466
t = Timer()
6567
step_cnt = 0
66-
for step in range(start_epoch * imdb.batch_per_epoch, cfg.max_epoch * imdb.batch_per_epoch):
68+
size_index = 0
69+
for step in range(start_epoch * imdb.batch_per_epoch,
70+
cfg.max_epoch * imdb.batch_per_epoch):
6771
t.tic()
6872
# batch
69-
batch = imdb.next_batch()
73+
batch = imdb.next_batch(size_index)
7074
im = batch['images']
7175
gt_boxes = batch['gt_boxes']
7276
gt_classes = batch['gt_classes']
7377
dontcare = batch['dontcare']
7478
orgin_im = batch['origin_im']
7579

7680
# forward
77-
im_data = net_utils.np_to_variable(im, is_cuda=True, volatile=False).permute(0, 3, 1, 2)
78-
net(im_data, gt_boxes, gt_classes, dontcare)
81+
im_data = net_utils.np_to_variable(im,
82+
is_cuda=True,
83+
volatile=False).permute(0, 3, 1, 2)
84+
net(im_data, gt_boxes, gt_classes, dontcare, size_index)
7985

8086
# backward
8187
loss = net.loss
@@ -94,9 +100,12 @@
94100
bbox_loss /= cnt
95101
iou_loss /= cnt
96102
cls_loss /= cnt
97-
print(('epoch %d[%d/%d], loss: %.3f, bbox_loss: %.3f, iou_loss: %.3f, cls_loss: %.3f (%.2f s/batch, rest:%s)' % (
98-
imdb.epoch, step_cnt, batch_per_epoch, train_loss, bbox_loss, iou_loss, cls_loss, duration,
99-
str(datetime.timedelta(seconds=int((batch_per_epoch - step_cnt) * duration))))))
103+
print(('epoch %d[%d/%d], loss: %.3f, bbox_loss: %.3f, iou_loss: %.3f, '
104+
'cls_loss: %.3f (%.2f s/batch, rest:%s)' % (
105+
imdb.epoch, step_cnt, batch_per_epoch, train_loss, bbox_loss,
106+
iou_loss, cls_loss, duration,
107+
str(datetime.timedelta(seconds=int((batch_per_epoch - step_cnt)
108+
* duration))))))
100109

101110
if use_tensorboard and step % cfg.log_interval == 0:
102111
exp.add_scalar_value('loss_train', train_loss, step=step)
@@ -109,13 +118,18 @@
109118
bbox_loss, iou_loss, cls_loss = 0., 0., 0.
110119
cnt = 0
111120
t.clear()
121+
size_index = randint(0, len(cfg.multi_scale_inp_size) - 1)
122+
print("image_size {}".format(cfg.multi_scale_inp_size[size_index]))
112123

113124
if step > 0 and (step % imdb.batch_per_epoch == 0):
114125
if imdb.epoch in cfg.lr_decay_epochs:
115126
lr *= cfg.lr_decay
116-
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
127+
optimizer = torch.optim.SGD(net.parameters(), lr=lr,
128+
momentum=cfg.momentum,
129+
weight_decay=cfg.weight_decay)
117130

118-
save_name = os.path.join(cfg.train_output_dir, '{}_{}.h5'.format(cfg.exp_name, imdb.epoch))
131+
save_name = os.path.join(cfg.train_output_dir,
132+
'{}_{}.h5'.format(cfg.exp_name, imdb.epoch))
119133
net_utils.save_net(save_name, net)
120134
print(('save model: {}'.format(save_name)))
121135
step_cnt = 0

0 commit comments

Comments
 (0)