Skip to content

Commit 3c2bffd

Browse files
committed
fix the test bug
1 parent 8959cae commit 3c2bffd

File tree

8 files changed

+48
-65
lines changed

8 files changed

+48
-65
lines changed

data/config.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
import os.path
33

44
# gets home dir cross platform
5-
home = os.path.expanduser("~")
6-
ddir = os.path.join(home,"data/VOCdevkit/")
75

86
# note: if you used our download scripts, this should be right
9-
VOCroot = ddir # path to VOCdevkit root dir
10-
COCOroot = os.path.join(home,"data/COCO/")
7+
VOCroot = '/mnt/lvmhdd1/zuoxin/dataset/VOCdevkit' # path to VOCdevkit root dir
8+
COCOroot = ''
119

1210

1311
#RFB CONFIGS
@@ -32,7 +30,7 @@
3230
VOC_512= {
3331
'feature_maps' : [64, 32, 16, 8, 4, 2, 1],
3432

35-
'min_dim' : 524,
33+
'min_dim' : 512,
3634

3735
'steps' : [8, 16, 32, 64, 128, 256, 512],
3836

data/voc0712.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ def evaluate_detections(self, all_boxes, output_dir=None):
250250
all_boxes[class][image] = [] or np.array of shape #dets x 5
251251
"""
252252
self._write_voc_results_file(all_boxes)
253-
self._do_python_eval(output_dir)
253+
aps,map = self._do_python_eval(output_dir)
254+
return aps,map
254255

255256
def _get_voc_results_file_template(self):
256257
filename = 'comp4_det_test' + '_{:s}.txt'
@@ -327,6 +328,7 @@ def _do_python_eval(self, output_dir='output'):
327328
print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
328329
print('-- Thanks, The Management')
329330
print('--------------------------------------------------------------')
331+
return aps,np.mean(aps)
330332

331333
def detection_collate(batch):
332334
"""Custom collate fn for dealing with batches of images that have a different

models/RFB_Net_E_vgg.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,8 @@ def forward(self,x):
185185

186186
class RFBNet(nn.Module):
187187

188-
def __init__(self, phase, size, base, extras, head, num_classes):
188+
def __init__(self, size, base, extras, head, num_classes):
189189
super(RFBNet, self).__init__()
190-
self.phase = phase
191190
self.num_classes = num_classes
192191
self.size = size
193192

@@ -209,10 +208,9 @@ def __init__(self, phase, size, base, extras, head, num_classes):
209208

210209
self.loc = nn.ModuleList(head[0])
211210
self.conf = nn.ModuleList(head[1])
212-
if self.phase == 'test':
213-
self.softmax = nn.Softmax()
211+
self.softmax = nn.Softmax()
214212

215-
def forward(self, x):
213+
def forward(self, x,test=False):
216214
"""Applies network layers and ops on input image(s) x.
217215
218216
Args:
@@ -268,7 +266,7 @@ def forward(self, x):
268266
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
269267
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
270268

271-
if self.phase == "test":
269+
if test:
272270
output = (
273271
loc.view(loc.size(0), -1, 4), # loc preds
274272
self.softmax(conf.view(-1, self.num_classes)), # conf preds
@@ -368,14 +366,11 @@ def multibox(size, vgg, extra_layers, cfg, num_classes):
368366
}
369367

370368

371-
def build_net(phase, size=300, num_classes=21):
372-
if phase != "test" and phase != "train":
373-
print("Error: Phase not recognized")
374-
return
369+
def build_net(size=300, num_classes=21):
375370
if size != 300 and size != 512:
376371
print("Error: Sorry only RFB300 and RFB512 are supported!")
377372
return
378373

379-
return RFBNet(phase, size, *multibox(size, vgg(vgg_base[str(size)], 3),
374+
return RFBNet(size, *multibox(size, vgg(vgg_base[str(size)], 3),
380375
add_extras(size, extras[str(size)], 1024),
381-
mbox[str(size)], num_classes), num_classes)
376+
mbox[str(size)], num_classes), num_classes=num_classes)

models/RFB_Net_mobile.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,9 @@ def __init__(self, phase, size, base, extras, head, num_classes):
154154

155155
self.loc = nn.ModuleList(head[0])
156156
self.conf = nn.ModuleList(head[1])
157-
if self.phase == 'test':
158-
self.softmax = nn.Softmax()
157+
self.softmax = nn.Softmax()
159158

160-
def forward(self, x):
159+
def forward(self, x,test=False):
161160
"""Applies network layers and ops on input image(s) x.
162161
163162
Args:
@@ -209,7 +208,7 @@ def forward(self, x):
209208
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
210209
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
211210

212-
if self.phase == "test":
211+
if test:
213212
output = (
214213
loc.view(loc.size(0), -1, 4), # loc preds
215214
self.softmax(conf.view(-1, self.num_classes)), # conf preds
@@ -335,14 +334,11 @@ def multibox(size, base, extra_layers, cfg, num_classes):
335334
}
336335

337336

338-
def build_net(phase, size=300, num_classes=21):
339-
if phase != "test" and phase != "train":
340-
print("Error: Phase not recognized")
341-
return
337+
def build_net(size=300, num_classes=21):
342338
if size != 300:
343339
print("Error: Sorry only RFB300_mobile is supported!")
344340
return
345341

346-
return RFBNet(phase, size, *multibox(size, MobileNet(),
342+
return RFBNet(size, *multibox(size, MobileNet(),
347343
add_extras(size, extras[str(size)], 1024),
348-
mbox[str(size)], num_classes), num_classes)
344+
mbox[str(size)], num_classes), num_classes=num_classes)

models/RFB_Net_vgg.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,8 @@ class RFBNet(nn.Module):
138138
head: "multibox head" consists of loc and conf conv layers
139139
"""
140140

141-
def __init__(self, phase, size, base, extras, head, num_classes):
141+
def __init__(self,size, base, extras, head, num_classes):
142142
super(RFBNet, self).__init__()
143-
self.phase = phase
144143
self.num_classes = num_classes
145144
self.size = size
146145

@@ -159,10 +158,9 @@ def __init__(self, phase, size, base, extras, head, num_classes):
159158

160159
self.loc = nn.ModuleList(head[0])
161160
self.conf = nn.ModuleList(head[1])
162-
if self.phase == 'test':
163-
self.softmax = nn.Softmax()
161+
self.softmax = nn.Softmax()
164162

165-
def forward(self, x):
163+
def forward(self, x, test = False):
166164
"""Applies network layers and ops on input image(s) x.
167165
168166
Args:
@@ -214,7 +212,7 @@ def forward(self, x):
214212
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
215213
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
216214

217-
if self.phase == "test":
215+
if test:
218216
output = (
219217
loc.view(loc.size(0), -1, 4), # loc preds
220218
self.softmax(conf.view(-1, self.num_classes)), # conf preds
@@ -311,14 +309,11 @@ def multibox(size, vgg, extra_layers, cfg, num_classes):
311309
}
312310

313311

314-
def build_net(phase, size=300, num_classes=21):
315-
if phase != "test" and phase != "train":
316-
print("Error: Phase not recognized")
317-
return
312+
def build_net(size=300, num_classes=21):
318313
if size != 300 and size != 512:
319314
print("Error: Sorry only RFBNet300 and RFBNet512 are supported!")
320315
return
321316

322-
return RFBNet(phase, size, *multibox(size, vgg(vgg_base[str(size)], 3),
317+
return RFBNet(size, *multibox(size, vgg(vgg_base[str(size)], 3),
323318
add_extras(size, extras[str(size)], 1024),
324-
mbox[str(size)], num_classes), num_classes)
319+
mbox[str(size)], num_classes), num_classes=num_classes)

models/SSD_vgg.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ class SSD(nn.Module):
2323
head: "multibox head" consists of loc and conf conv layers
2424
"""
2525

26-
def __init__(self, phase, base, extras, head, num_classes):
26+
def __init__(self,base, extras, head, num_classes):
2727
super(SSD, self).__init__()
28-
self.phase = phase
2928
self.num_classes = num_classes
3029
# TODO: implement __call__ in PriorBox
3130
self.size = 300
@@ -39,11 +38,9 @@ def __init__(self, phase, base, extras, head, num_classes):
3938
self.loc = nn.ModuleList(head[0])
4039
self.conf = nn.ModuleList(head[1])
4140

42-
if phase == 'test':
43-
self.softmax = nn.Softmax()
44-
self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
41+
self.softmax = nn.Softmax()
4542

46-
def forward(self, x):
43+
def forward(self, x,test=False):
4744
"""Applies network layers and ops on input image(s) x.
4845
4946
Args:
@@ -91,8 +88,8 @@ def forward(self, x):
9188

9289
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
9390
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
94-
if self.phase == "test":
95-
output = self.detect(
91+
if test:
92+
output =(
9693
loc.view(loc.size(0), -1, 4), # loc preds
9794
self.softmax(conf.view(-1, self.num_classes)), # conf preds
9895
)
@@ -159,14 +156,11 @@ def multibox(vgg, extra_layers, cfg, num_classes):
159156
}
160157

161158

162-
def build_net(phase, size=300, num_classes=21):
163-
if phase != "test" and phase != "train":
164-
print("Error: Phase not recognized")
165-
return
159+
def build_net(size=300, num_classes=21):
166160
if size != 300 and size != 512:
167161
print("Error: Sorry only SSD300 and SSD512 is supported currently!")
168162
return
169163

170-
return SSD(phase, *multibox(vgg(vgg_base[str(size)], 3),
164+
return SSD(*multibox(vgg(vgg_base[str(size)], 3),
171165
add_extras(extras[str(size)], 1024),
172-
mbox[str(size)], num_classes), num_classes)
166+
mbox[str(size)], num_classes), num_classes=num_classes)

test_RFB.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image
8585
x = x.cuda()
8686

8787
_t['im_detect'].tic()
88-
out = net(x) # forward pass
88+
out = net(x,test = True) # forward pass
8989
boxes, scores = detector.forward(out,priors)
9090
detect_time = _t['im_detect'].toc()
9191
boxes = boxes[0]
@@ -145,7 +145,7 @@ def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image
145145
# load net
146146
img_dim = (300,512)[args.size=='512']
147147
num_classes = (21, 81)[args.dataset == 'COCO']
148-
net = build_net('test', img_dim, num_classes) # initialize detector
148+
net = build_net(img_dim, num_classes) # initialize detector
149149
state_dict = torch.load(args.trained_model)
150150
# create new OrderedDict that does not contain `module.`
151151

train_test.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def str2bool(v):
2525
parser = argparse.ArgumentParser(
2626
description='Receptive Field Block Net Training')
2727
parser.add_argument('-v', '--version', default='SSD_vgg',
28-
help='RFB_vgg ,RFB_E_vgg RFB_mobile SSD version.')
28+
help='RFB_vgg ,RFB_E_vgg RFB_mobile SSD_vgg version.')
2929
parser.add_argument('-s', '--size', default='300',
3030
help='300 or 512 input size.')
3131
parser.add_argument('-d', '--dataset', default='VOC',
@@ -45,8 +45,8 @@ def str2bool(v):
4545
default=1e-3, type=float, help='initial learning rate')
4646
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
4747
parser.add_argument(
48-
'--resume_net', default=False, help='resume net for retraining')
49-
parser.add_argument('--resume_epoch', default=0,
48+
'--resume_net', default=True, help='resume net for retraining')
49+
parser.add_argument('--resume_epoch', default=250,
5050
type=int, help='resume iter for retraining')
5151
parser.add_argument('-max','--max_epoch', default=300,
5252
type=int, help='max epoch for retraining')
@@ -62,8 +62,8 @@ def str2bool(v):
6262
parser.add_argument('--save_frequency',default=10)
6363
parser.add_argument('--retest', default=False, type=bool,
6464
help='test cache results')
65-
parser.add_argument('--test_frequency',default=100)
66-
parser.add_argument('--visdom', default=True, type=str2bool, help='Use visdom to for loss visualization')
65+
parser.add_argument('--test_frequency',default=10)
66+
parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom to for loss visualization')
6767
parser.add_argument('--send_images_to_visdom', type=str2bool, default=False, help='Sample a random image from each 10th batch, send it to visdom after augmentations step')
6868
args = parser.parse_args()
6969

@@ -110,7 +110,7 @@ def str2bool(v):
110110
import visdom
111111
viz = visdom.Visdom()
112112

113-
net = build_net('train', img_dim, num_classes)
113+
net = build_net(img_dim, num_classes)
114114
print(net)
115115
if not args.resume_net:
116116
base_weights = torch.load(args.basenet)
@@ -318,7 +318,7 @@ def train():
318318
win=lot,
319319
update='append'
320320
)
321-
if iteration == 0:
321+
if iteration%epoch_size == 0:
322322
viz.line(
323323
X=torch.zeros((1, 3)).cpu(),
324324
Y=torch.Tensor([loc_loss, conf_loss,
@@ -371,7 +371,7 @@ def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image
371371
x = x.cuda()
372372

373373
_t['im_detect'].tic()
374-
out = net(x) # forward pass
374+
out = net(x = x,test = True) # forward pass
375375
boxes, scores = detector.forward(out,priors)
376376
detect_time = _t['im_detect'].toc()
377377
boxes = boxes[0]
@@ -424,7 +424,10 @@ def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image
424424
pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
425425

426426
print('Evaluating detections')
427-
testset.evaluate_detections(all_boxes, save_folder)
427+
if args.dataset == 'VOC':
428+
aps,map = testset.evaluate_detections(all_boxes, save_folder)
429+
return aps,map
430+
428431

429432

430433
if __name__ == '__main__':

0 commit comments

Comments
 (0)