Skip to content

Commit 216ad59

Browse files
committed
# Update dark net
For add global average pool
1 parent f49dcea commit 216ad59

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

darknet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,21 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None, size_i
189189
cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
190190
conv4 = self.conv4(cat_1_3)
191191
conv5 = self.conv5(conv4) # batch_size, out_channels, h, w
192-
conv5 = self.global_average_pool(conv5)
192+
global_average_pool = self.global_average_pool(conv5)
193193

194194
# for detection
195195
# bsize, c, h, w -> bsize, h, w, c -> bsize, h x w, num_anchors, 5+num_classes
196-
bsize, _, h, w = conv5.size()
196+
bsize, _, h, w = global_average_pool.size()
197197
# assert bsize == 1, 'detection only support one image per batch'
198-
conv5_reshaped = conv5.permute(0, 2, 3, 1).contiguous().view(bsize, -1, cfg.num_anchors, cfg.num_classes + 5)
198+
global_average_pool_reshaped = global_average_pool.permute(0, 2, 3, 1).contiguous().view(bsize, -1, cfg.num_anchors, cfg.num_classes + 5)
199199

200200
# tx, ty, tw, th, to -> sig(tx), sig(ty), exp(tw), exp(th), sig(to)
201-
xy_pred = F.sigmoid(conv5_reshaped[:, :, :, 0:2])
202-
wh_pred = torch.exp(conv5_reshaped[:, :, :, 2:4])
201+
xy_pred = F.sigmoid(global_average_pool_reshaped[:, :, :, 0:2])
202+
wh_pred = torch.exp(global_average_pool_reshaped[:, :, :, 2:4])
203203
bbox_pred = torch.cat([xy_pred, wh_pred], 3)
204-
iou_pred = F.sigmoid(conv5_reshaped[:, :, :, 4:5])
204+
iou_pred = F.sigmoid(global_average_pool_reshaped[:, :, :, 4:5])
205205

206-
score_pred = conv5_reshaped[:, :, :, 5:].contiguous()
206+
score_pred = global_average_pool_reshaped[:, :, :, 5:].contiguous()
207207
prob_pred = F.softmax(score_pred.view(-1, score_pred.size()[-1])).view_as(score_pred)
208208

209209
# for training

0 commit comments

Comments
 (0)