@@ -25,7 +25,7 @@ def str2bool(v):
2525parser = argparse .ArgumentParser (
2626 description = 'Receptive Field Block Net Training' )
2727parser .add_argument ('-v' , '--version' , default = 'SSD_vgg' ,
28- help = 'RFB_vgg ,RFB_E_vgg RFB_mobile SSD version.' )
28+ help = 'RFB_vgg ,RFB_E_vgg RFB_mobile SSD_vgg version.' )
2929parser .add_argument ('-s' , '--size' , default = '300' ,
3030 help = '300 or 512 input size.' )
3131parser .add_argument ('-d' , '--dataset' , default = 'VOC' ,
@@ -45,8 +45,8 @@ def str2bool(v):
4545 default = 1e-3 , type = float , help = 'initial learning rate' )
4646parser .add_argument ('--momentum' , default = 0.9 , type = float , help = 'momentum' )
4747parser .add_argument (
48- '--resume_net' , default = False , help = 'resume net for retraining' )
49- parser .add_argument ('--resume_epoch' , default = 0 ,
48+ '--resume_net' , default = True , help = 'resume net for retraining' )
49+ parser .add_argument ('--resume_epoch' , default = 250 ,
5050 type = int , help = 'resume iter for retraining' )
5151parser .add_argument ('-max' ,'--max_epoch' , default = 300 ,
5252 type = int , help = 'max epoch for retraining' )
@@ -62,8 +62,8 @@ def str2bool(v):
6262parser .add_argument ('--save_frequency' ,default = 10 )
6363parser .add_argument ('--retest' , default = False , type = bool ,
6464 help = 'test cache results' )
65- parser .add_argument ('--test_frequency' ,default = 100 )
66- parser .add_argument ('--visdom' , default = True , type = str2bool , help = 'Use visdom to for loss visualization' )
65+ parser .add_argument ('--test_frequency' ,default = 10 )
66+ parser .add_argument ('--visdom' , default = False , type = str2bool , help = 'Use visdom to for loss visualization' )
6767parser .add_argument ('--send_images_to_visdom' , type = str2bool , default = False , help = 'Sample a random image from each 10th batch, send it to visdom after augmentations step' )
6868args = parser .parse_args ()
6969
@@ -110,7 +110,7 @@ def str2bool(v):
110110 import visdom
111111 viz = visdom .Visdom ()
112112
113- net = build_net ('train' , img_dim , num_classes )
113+ net = build_net (img_dim , num_classes )
114114print (net )
115115if not args .resume_net :
116116 base_weights = torch .load (args .basenet )
@@ -318,7 +318,7 @@ def train():
318318 win = lot ,
319319 update = 'append'
320320 )
321- if iteration == 0 :
321+ if iteration % epoch_size == 0 :
322322 viz .line (
323323 X = torch .zeros ((1 , 3 )).cpu (),
324324 Y = torch .Tensor ([loc_loss , conf_loss ,
@@ -371,7 +371,7 @@ def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image
371371 x = x .cuda ()
372372
373373 _t ['im_detect' ].tic ()
374- out = net (x ) # forward pass
374+ out = net (x = x , test = True ) # forward pass
375375 boxes , scores = detector .forward (out ,priors )
376376 detect_time = _t ['im_detect' ].toc ()
377377 boxes = boxes [0 ]
@@ -424,7 +424,10 @@ def test_net(save_folder, net, detector, cuda, testset, transform, max_per_image
424424 pickle .dump (all_boxes , f , pickle .HIGHEST_PROTOCOL )
425425
426426 print ('Evaluating detections' )
427- testset .evaluate_detections (all_boxes , save_folder )
427+ if args .dataset == 'VOC' :
428+ aps ,map = testset .evaluate_detections (all_boxes , save_folder )
429+ return aps ,map
430+
428431
429432
430433if __name__ == '__main__' :
0 commit comments