Skip to content

Commit edda7d4

Browse files
committed
fix bugs
1 parent 84d05e3 commit edda7d4

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,19 @@ YOLO9000: Better, Faster, Stronger by Joseph Redmon and Ali Farhadi.
1010

1111
I used a Cython extension for postprocessing and
1212
`multiprocessing.Pool` for image preprocessing.
13-
Testing an image in VOC2007 costs about 13~20ms.
13+
Testing an image in VOC2007 costs about 13~20ms.
14+
15+
**NOTE:**
16+
This is still an experimental project.
17+
VOC07 test mAP is about 0.71 (trained on VOC07+12 trainval,
18+
reported by [@cory8249](https://github.com/longcw/yolo2-pytorch/issues/23)).
19+
See https://github.com/longcw/yolo2-pytorch/issues/1 and https://github.com/longcw/yolo2-pytorch/issues/23
20+
for more details about training.
21+
22+
BTW, I recommend writing your own dataloader using [torch.utils.data.Dataset](http://pytorch.org/docs/data.html)
23+
since `multiprocessing.Pool.imap` won't stop even when there is not enough memory.
24+
25+
1426

1527
### Installation and demo
1628
1. Clone this repository

darknet.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _process_batch(data):
3737
inp_size = cfg.inp_size
3838
out_size = cfg.out_size
3939

40-
bbox_pred_np, gt_boxes, gt_classes, dontcares = data
40+
bbox_pred_np, gt_boxes, gt_classes, dontcares, iou_pred_np = data
4141

4242
# net output
4343
hw, num_anchors, _ = bbox_pred_np.shape
@@ -61,21 +61,22 @@ def _process_batch(data):
6161
np.ascontiguousarray(bbox_pred_np, dtype=np.float),
6262
anchors,
6363
H, W)
64-
bbox_np = bbox_np[0]
65-
bbox_np[:, :, 0::2] *= float(inp_size[0])
66-
bbox_np[:, :, 1::2] *= float(inp_size[1])
64+
bbox_np = bbox_np[0] # bbox_np = (hw, num_anchors, (x1, y1, x2, y2)) range: 0 ~ 1
65+
bbox_np[:, :, 0::2] *= float(inp_size[0]) # rescale x
66+
bbox_np[:, :, 1::2] *= float(inp_size[1]) # rescale y
6767

6868
# gt_boxes_b = np.asarray(gt_boxes[b], dtype=np.float)
6969
gt_boxes_b = np.asarray(gt_boxes, dtype=np.float)
7070

71-
# for each cell
71+
# for each cell, compare predicted_bbox and gt_bbox
7272
bbox_np_b = np.reshape(bbox_np, [-1, 4])
7373
ious = bbox_ious(
7474
np.ascontiguousarray(bbox_np_b, dtype=np.float),
7575
np.ascontiguousarray(gt_boxes_b, dtype=np.float)
7676
)
7777
best_ious = np.max(ious, axis=1).reshape(_iou_mask.shape)
78-
_iou_mask[best_ious <= cfg.iou_thresh] = cfg.noobject_scale
78+
iou_penalty = 0 - iou_pred_np[best_ious < cfg.iou_thresh]
79+
_iou_mask[best_ious <= cfg.iou_thresh] = cfg.noobject_scale * iou_penalty
7980

8081
# locate the cell of each gt_box
8182
cell_w = float(inp_size[0]) / W
@@ -108,7 +109,8 @@ def _process_batch(data):
108109
continue
109110
a = anchor_inds[i]
110111

111-
_iou_mask[cell_ind, a, :] = cfg.object_scale
112+
iou_pred_cell_anchor = iou_pred_np[cell_ind, a, :] # 0 ~ 1, should be close to 1
113+
_iou_mask[cell_ind, a, :] = cfg.object_scale * (1 - iou_pred_cell_anchor)
112114
# _ious[cell_ind, a, :] = anchor_ious[a, i]
113115
_ious[cell_ind, a, :] = ious_reshaped[cell_ind, a, i]
114116

@@ -119,8 +121,8 @@ def _process_batch(data):
119121
_class_mask[cell_ind, a, :] = cfg.class_scale
120122
_classes[cell_ind, a, gt_classes[i]] = 1.
121123

122-
_boxes[:, :, 2:4] = np.maximum(_boxes[:, :, 2:4], 0.001)
123-
_boxes[:, :, 2:4] = np.log(_boxes[:, :, 2:4])
124+
# _boxes[:, :, 2:4] = np.maximum(_boxes[:, :, 2:4], 0.001)
125+
# _boxes[:, :, 2:4] = np.log(_boxes[:, :, 2:4])
124126

125127
return _boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask
126128

@@ -172,14 +174,10 @@ def loss(self):
172174

173175
def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
174176
conv1s = self.conv1s(im_data)
175-
176177
conv2 = self.conv2(conv1s)
177-
178178
conv3 = self.conv3(conv2)
179-
180179
conv1s_reorg = self.reorg(conv1s)
181180
cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
182-
183181
conv4 = self.conv4(cat_1_3)
184182
conv5 = self.conv5(conv4) # batch_size, out_channels, h, w
185183

@@ -191,11 +189,8 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
191189

192190
# tx, ty, tw, th, to -> sig(tx), sig(ty), exp(tw), exp(th), sig(to)
193191
xy_pred = F.sigmoid(conv5_reshaped[:, :, :, 0:2])
194-
195-
wh_pred = conv5_reshaped[:, :, :, 2:4]
196-
wh_pred_exp = torch.exp(wh_pred)
197-
bbox_pred = torch.cat([xy_pred, wh_pred_exp], 3)
198-
192+
wh_pred = torch.exp(conv5_reshaped[:, :, :, 2:4])
193+
bbox_pred = torch.cat([xy_pred, wh_pred], 3)
199194
iou_pred = F.sigmoid(conv5_reshaped[:, :, :, 4:5])
200195

201196
score_pred = conv5_reshaped[:, :, :, 5:].contiguous()
@@ -204,8 +199,9 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
204199
# for training
205200
if self.training:
206201
bbox_pred_np = bbox_pred.data.cpu().numpy()
202+
iou_pred_np = iou_pred.data.cpu().numpy()
207203
_boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask = self._build_target(
208-
bbox_pred_np, gt_boxes, gt_classes, dontcare)
204+
bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np)
209205

210206
_boxes = net_utils.np_to_variable(_boxes)
211207
_ious = net_utils.np_to_variable(_ious)
@@ -218,30 +214,23 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
218214

219215
# _boxes[:, :, :, 2:4] = torch.log(_boxes[:, :, :, 2:4])
220216
box_mask = box_mask.expand_as(_boxes)
221-
# self.bbox_loss = torch.sum(torch.pow(_boxes - bbox_pred, 2) * box_mask) / num_boxes
222-
bbox_pred_log = torch.cat([xy_pred, wh_pred], 3)
223-
self.bbox_loss = nn.MSELoss(size_average=False)(bbox_pred_log * box_mask, _boxes * box_mask) / num_boxes
224217

225-
# self.iou_loss = torch.sum(torch.pow(_ious - iou_pred, 2) * iou_mask) / num_boxes
218+
self.bbox_loss = nn.MSELoss(size_average=False)(bbox_pred * box_mask, _boxes * box_mask) / num_boxes
226219
self.iou_loss = nn.MSELoss(size_average=False)(iou_pred * iou_mask, _ious * iou_mask) / num_boxes
227220

228221
class_mask = class_mask.expand_as(prob_pred)
229-
# self.cls_loss = torch.sum(torch.pow(_classes - prob_pred, 2) * class_mask) / num_boxes
230222
self.cls_loss = nn.MSELoss(size_average=False)(prob_pred * class_mask, _classes * class_mask) / num_boxes
231223

232-
# wh_pred = torch.exp(conv5_reshaped[:, :, :, 2:4])
233-
# bbox_pred = torch.cat([xy_pred, wh_pred], 3)
234-
235224
return bbox_pred, iou_pred, prob_pred
236225

237-
def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare):
226+
def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np):
238227
"""
239228
:param bbox_pred: shape: (bsize, h x w, num_anchors, 4) : (sig(tx), sig(ty), exp(tw), exp(th))
240229
"""
241230

242231
bsize = bbox_pred_np.shape[0]
243232

244-
targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b]) for b in range(bsize)))
233+
targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b]) for b in range(bsize)))
245234

246235
_boxes = np.stack(tuple((row[0] for row in targets)))
247236
_ious = np.stack(tuple((row[1] for row in targets)))
@@ -250,7 +239,6 @@ def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare):
250239
_iou_mask = np.stack(tuple((row[4] for row in targets)))
251240
_class_mask = np.stack(tuple((row[5] for row in targets)))
252241

253-
254242
return _boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask
255243

256244
def load_from_npz(self, fname, num_conv=None):

0 commit comments

Comments
 (0)