Commit dd7b8e1

add refinedet 320
1 parent f9c0f1e commit dd7b8e1

10 files changed: +893 -15 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 *.so
+__pycache__
+build

README.md

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 * SSD [SSD: Single Shot Multibox Detector](https://arxiv.org/abs/1512.02325)
 * FSSD [FSSD: Feature Fusion Single Shot Multibox Detector](https://arxiv.org/abs/1712.00960)
 * RFB-SSD [Receptive Field Block Net for Accurate and Fast Object Detection](https://arxiv.org/abs/1711.07767)
+* RefineDet [Single-Shot Refinement Neural Network for Object Detection](https://arxiv.org/pdf/1711.06897.pdf)
 
 ### VOC2007 Test
 | System | *mAP* | **FPS** (Titan X Maxwell) |

data/config.py

Lines changed: 18 additions & 0 deletions
@@ -99,3 +99,21 @@
 
     'clip' : True,
 }
+
+VOC_320 = {
+    'feature_maps' : [40, 20, 10, 5],
+
+    'min_dim' : 320,
+
+    'steps' : [8, 16, 32, 64],
+
+    'min_sizes' : [32, 64, 128, 256],
+
+    'max_sizes' : [],
+
+    'aspect_ratios' : [[2], [2], [2], [2]],
+
+    'variance' : [0.1, 0.2],
+
+    'clip' : True,
+}
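
For a 320x320 input, the 'steps' above are simply 'min_dim' divided by each feature-map size (320/40 = 8, 320/20 = 16, 320/10 = 32, 320/5 = 64). With 'max_sizes' empty and a single extra aspect ratio of 2 per level, the usual SSD-style prior generator yields 3 boxes per cell: one square box at 'min_size' plus a 2:1 and a 1:2 variant. A quick sanity-check sketch, assuming that standard rule; count_priors is only an illustrative helper, not part of the repo:

    # Rough anchor count for the 320 config, assuming the standard SSD PriorBox rule:
    # 1 box at min_size, +1 if a max_size exists, +2 per extra aspect ratio (ar and 1/ar).
    cfg = {  # values copied from VOC_320 above
        'feature_maps': [40, 20, 10, 5],
        'max_sizes': [],
        'aspect_ratios': [[2], [2], [2], [2]],
    }

    def count_priors(cfg):
        total = 0
        for k, f in enumerate(cfg['feature_maps']):
            per_cell = 1 + (1 if cfg['max_sizes'] else 0) + 2 * len(cfg['aspect_ratios'][k])
            total += f * f * per_cell
        return total

    print(count_priors(cfg))  # (1600 + 400 + 100 + 25) * 3 = 6375 priors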

layers/functions/detection.py

Lines changed: 23 additions & 11 deletions
@@ -3,7 +3,7 @@
 import torch.backends.cudnn as cudnn
 from torch.autograd import Function
 from torch.autograd import Variable
-from utils.box_utils import decode, nms
+from utils.box_utils import decode, nms, center_size
 
 
 class Detect(Function):
@@ -12,15 +12,16 @@ class Detect(Function):
     scores and threshold to a top_k number of output predictions for both
     confidence score and locations.
     """
-    def __init__(self, num_classes, bkg_label, cfg):
+    def __init__(self, num_classes, bkg_label, cfg, object_score=0):
         self.num_classes = num_classes
         self.background_label = bkg_label
+        self.object_score = object_score
         #self.thresh = thresh
 
         # Parameters used in nms.
         self.variance = cfg['variance']
 
-    def forward(self, predictions, prior):
+    def forward(self, predictions, prior, arm_data=None):
         """
         Args:
             loc_data: (tensor) Loc preds from loc layers
@@ -32,28 +33,39 @@ def forward(self, predictions, prior):
         """
 
         loc, conf = predictions
-
         loc_data = loc.data
         conf_data = conf.data
         prior_data = prior.data
         num = loc_data.size(0)  # batch size
+        if arm_data:
+            arm_loc, arm_conf = arm_data
+            arm_loc_data = arm_loc.data
+            arm_conf_data = arm_conf.data
+            arm_object_conf = arm_conf_data[:, 1:]
+            no_object_index = arm_object_conf <= self.object_score
+            conf_data[no_object_index.expand_as(conf_data)] = 0
+
         self.num_priors = prior_data.size(0)
-        self.boxes = torch.zeros(1, self.num_priors, 4)
-        self.scores = torch.zeros(1, self.num_priors, self.num_classes)
+        self.boxes = torch.zeros(num, self.num_priors, 4)
+        self.scores = torch.zeros(num, self.num_priors, self.num_classes)
 
         if num == 1:
             # size batch x num_classes x num_priors
             conf_preds = conf_data.unsqueeze(0)
 
         else:
-            conf_preds = conf_data.view(num, num_priors,
+            conf_preds = conf_data.view(num, self.num_priors,
                                         self.num_classes)
-            self.boxes.expand_(num, self.num_priors, 4)
-            self.scores.expand_(num, self.num_priors, self.num_classes)
-
+            self.boxes.expand(num, self.num_priors, 4)
+            self.scores.expand(num, self.num_priors, self.num_classes)
         # Decode predictions into bboxes.
         for i in range(num):
-            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
+            if arm_data:
+                default = decode(arm_loc_data[i], prior_data, self.variance)
+                default = center_size(default)
+            else:
+                default = prior_data
+            decoded_boxes = decode(loc_data[i], default, self.variance)
             # For each class, perform nms
             conf_scores = conf_preds[i].clone()
             '''
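
The new arm_data path above is the two-step decoding used by RefineDet: ARM offsets are first decoded against the original priors to produce refined anchors, those are converted back to center-size form, and the ODM offsets are then decoded against them; in addition, ODM confidences are zeroed wherever the ARM objectness score is at or below object_score, so clearly-negative anchors never reach NMS. A minimal single-image sketch of the same idea, reusing the repo's decode/center_size helpers; the function name cascaded_decode is illustrative only:

    from utils.box_utils import decode, center_size  # helpers already used by Detect above

    # arm_loc, odm_loc: (num_priors, 4) predicted offsets; priors: (num_priors, 4) center-size anchors.
    def cascaded_decode(arm_loc, odm_loc, priors, variance=(0.1, 0.2)):
        refined = decode(arm_loc, priors, variance)   # ARM offsets -> refined anchors (corner form)
        refined = center_size(refined)                # back to (cx, cy, w, h) so they can serve as priors
        return decode(odm_loc, refined, variance)     # ODM offsets are relative to the refined anchors

    # Objectness filtering: suppress ODM class scores where the ARM says "no object".
    # arm_conf: (num_priors, 2) softmax scores, odm_conf: (num_priors, num_classes):
    #     odm_conf[arm_conf[:, 1] <= object_score] = 0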

layers/functions/prior_box.py

Lines changed: 3 additions & 2 deletions
@@ -45,8 +45,9 @@ def forward(self):
 
                 # aspect_ratio: 1
                 # rel size: sqrt(s_k * s_(k+1))
-                s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
-                mean += [cx, cy, s_k_prime, s_k_prime]
+                if self.max_sizes:
+                    s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
+                    mean += [cx, cy, s_k_prime, s_k_prime]
 
                 # rest of aspect ratios
                 for ar in self.aspect_ratios[k]:
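
This guard matters for the new VOC_320 entry, where 'max_sizes' is an empty list: without the check, self.max_sizes[k] would raise an IndexError, and with it the extra sqrt(s_k * s_(k+1)) box is simply skipped, leaving the 3 priors per location counted in the sketch after the config diff above.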

layers/modules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .multibox_loss import MultiBoxLoss
+from .refine_multibox_loss import RefineMultiBoxLoss
 from .l2norm import L2Norm
 
 __all__ = ['MultiBoxLoss','L2Norm']
layers/modules/refine_multibox_loss.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from utils.box_utils import match, refine_match, log_sum_exp, decode
+GPU = False
+if torch.cuda.is_available():
+    GPU = True
+    torch.set_default_tensor_type('torch.cuda.FloatTensor')
+
+
+class RefineMultiBoxLoss(nn.Module):
+    """SSD Weighted Loss Function
+    Compute Targets:
+        1) Produce Confidence Target Indices by matching ground truth boxes
+           with (default) 'priorboxes' that have jaccard index > threshold parameter
+           (default threshold: 0.5).
+        2) Produce localization target by 'encoding' variance into offsets of ground
+           truth boxes and their matched 'priorboxes'.
+        3) Hard negative mining to filter the excessive number of negative examples
+           that comes with using a large number of default bounding boxes.
+           (default negative:positive ratio 3:1)
+    Objective Loss:
+        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
+        Where Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
+        weighted by α, which is set to 1 by cross validation.
+        Args:
+            c: class confidences,
+            l: predicted boxes,
+            g: ground truth boxes
+            N: number of matched default boxes
+        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
+    """
+
+    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, object_score=0):
+        super(RefineMultiBoxLoss, self).__init__()
+        self.num_classes = num_classes
+        self.threshold = overlap_thresh
+        self.background_label = bkg_label
+        self.encode_target = encode_target
+        self.use_prior_for_matching = prior_for_matching
+        self.do_neg_mining = neg_mining
+        self.negpos_ratio = neg_pos
+        self.neg_overlap = neg_overlap
+        self.object_score = object_score
+        self.variance = [0.1, 0.2]
+
+    def forward(self, odm_data, priors, targets, arm_data=None, filter_object=False):
+        """Multibox Loss
+        Args:
+            odm_data (tuple): A tuple containing loc preds and conf preds.
+                conf shape: torch.size(batch_size, num_priors, num_classes)
+                loc shape: torch.size(batch_size, num_priors, 4)
+            priors shape: torch.size(num_priors, 4)
+
+            targets (tensor): Ground truth boxes and labels for a batch,
+                shape: [batch_size, num_objs, 5] (last idx is the label).
+            arm_data (tuple): ARM branch output containing arm_loc and arm_conf.
+            filter_object: whether to filter out predictions according to the ARM conf score.
+        """
+
+        loc_data, conf_data = odm_data
+        if arm_data:
+            arm_loc, arm_conf = arm_data
+        priors = priors.data
+        num = loc_data.size(0)
+        num_priors = priors.size(0)
+
+        # match priors (default boxes) and ground truth boxes
+        loc_t = torch.Tensor(num, num_priors, 4)
+        conf_t = torch.LongTensor(num, num_priors)
+        for idx in range(num):
+            truths = targets[idx][:, :-1].data
+            labels = targets[idx][:, -1].data
+            # for binary objectness training, collapse labels to object vs. background
+            if self.num_classes == 2:
+                labels = labels > 0
+            if arm_data:
+                refine_match(self.threshold, truths, priors, self.variance, labels, loc_t, conf_t, idx, arm_loc[idx].data)
+            else:
+                match(self.threshold, truths, priors, self.variance, labels, loc_t, conf_t, idx)
+        if GPU:
+            loc_t = loc_t.cuda()
+            conf_t = conf_t.cuda()
+        # wrap targets
+        loc_t = Variable(loc_t, requires_grad=False)
+        conf_t = Variable(conf_t, requires_grad=False)
+        if arm_data and filter_object:
+            arm_conf_data = arm_conf.data[:, :, 1]
+            pos = conf_t > 0
+            object_score_index = arm_conf_data <= self.object_score
+            pos[object_score_index] = 0
+
+        else:
+            pos = conf_t > 0
+
+        # Localization Loss (Smooth L1)
+        # Shape: [batch, num_priors, 4]
+        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
+        loc_p = loc_data[pos_idx].view(-1, 4)
+        loc_t = loc_t[pos_idx].view(-1, 4)
+        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
+
+        # Compute max conf across batch for hard negative mining
+        batch_conf = conf_data.view(-1, self.num_classes)
+        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
+
+        # Hard Negative Mining
+        loss_c[pos] = 0  # filter out pos boxes for now
+        loss_c = loss_c.view(num, -1)
+        _, loss_idx = loss_c.sort(1, descending=True)
+        _, idx_rank = loss_idx.sort(1)
+        num_pos = pos.long().sum(1, keepdim=True)
+        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
+        neg = idx_rank < num_neg.expand_as(idx_rank)
+
+        # Confidence Loss Including Positive and Negative Examples
+        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
+        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
+        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
+        targets_weighted = conf_t[(pos + neg).gt(0)]
+        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)
+
+        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
+        N = num_pos.data.sum()
+        loss_l /= N
+        loss_c /= N
+        return loss_l, loss_c
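
In a RefineDet-style training loop, this loss is typically instantiated twice: once with num_classes=2 for the ARM branch (labels collapse to object vs. background and anchors are matched against the original priors) and once with the full class count for the ODM branch, which is passed arm_data so targets are matched against the ARM-refined anchors, plus filter_object=True so anchors with ARM objectness at or below object_score are ignored. A hedged sketch of such a step; the network interface (net, images, targets, priors) and the object_score value are assumptions, not part of this commit:

    from layers.modules import RefineMultiBoxLoss

    # ARM: binary objectness / coarse regression; ODM: full 21-class VOC prediction.
    arm_criterion = RefineMultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False)
    odm_criterion = RefineMultiBoxLoss(21, 0.5, True, 0, True, 3, 0.5, False, object_score=0.01)

    arm_data, odm_data = net(images)      # each assumed to be a (loc, conf) tuple
    arm_loss_l, arm_loss_c = arm_criterion(arm_data, priors, targets)
    odm_loss_l, odm_loss_c = odm_criterion(odm_data, priors, targets, arm_data, filter_object=True)
    loss = arm_loss_l + arm_loss_c + odm_loss_l + odm_loss_c
    loss.backward()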
