@@ -24,17 +24,17 @@ def str2bool(v):
2424
2525parser = argparse .ArgumentParser (
2626 description = 'Receptive Field Block Net Training' )
27- parser .add_argument ('-v' , '--version' , default = 'RFB_vgg ' ,
27+ parser .add_argument ('-v' , '--version' , default = 'SSD_vgg ' ,
2828 help = 'RFB_vgg ,RFB_E_vgg RFB_mobile SSD_vgg version.' )
29- parser .add_argument ('-s' , '--size' , default = '300 ' ,
29+ parser .add_argument ('-s' , '--size' , default = '512 ' ,
3030 help = '300 or 512 input size.' )
31- parser .add_argument ('-d' , '--dataset' , default = 'VOC ' ,
31+ parser .add_argument ('-d' , '--dataset' , default = 'COCO ' ,
3232 help = 'VOC or COCO dataset' )
3333parser .add_argument (
34- '--basenet' , default = '/mnt/lvmhdd1/zuoxin/ssd_pytorch_models /vgg16_reducedfc.pth' , help = 'pretrained base model' )
34+ '--basenet' , default = 'weights /vgg16_reducedfc.pth' , help = 'pretrained base model' )
3535parser .add_argument ('--jaccard_threshold' , default = 0.5 ,
3636 type = float , help = 'Min Jaccard index for matching' )
37- parser .add_argument ('-b' , '--batch_size' , default = 32 ,
37+ parser .add_argument ('-b' , '--batch_size' , default = 8 ,
3838 type = int , help = 'Batch size for training' )
3939parser .add_argument ('--num_workers' , default = 4 ,
4040 type = int , help = 'Number of workers used in dataloading' )
@@ -45,8 +45,8 @@ def str2bool(v):
4545 default = 4e-3 , type = float , help = 'initial learning rate' )
4646parser .add_argument ('--momentum' , default = 0.9 , type = float , help = 'momentum' )
4747
48- parser .add_argument ('--resume_net' , default = True , help = 'resume net for retraining' )
49- parser .add_argument ('--resume_epoch' , default = 10 ,
48+ parser .add_argument ('--resume_net' , default = False , help = 'resume net for retraining' )
49+ parser .add_argument ('--resume_epoch' , default = 0 ,
5050 type = int , help = 'resume iter for retraining' )
5151
5252parser .add_argument ('-max' ,'--max_epoch' , default = 300 ,
@@ -59,7 +59,7 @@ def str2bool(v):
5959 type = float , help = 'Gamma update for SGD' )
6060parser .add_argument ('--log_iters' , default = True ,
6161 type = bool , help = 'Print the loss at each iteration' )
62- parser .add_argument ('--save_folder' , default = '/mnt/lvmhdd1/zuoxin/ssd_pytorch_models /' ,
62+ parser .add_argument ('--save_folder' , default = 'weights /' ,
6363 help = 'Location to save checkpoint models' )
6464parser .add_argument ('--date' ,default = '1213' )
6565parser .add_argument ('--save_frequency' ,default = 10 )
@@ -82,7 +82,7 @@ def str2bool(v):
8282 train_sets = [('2007' , 'trainval' ), ('2012' , 'trainval' )]
8383 cfg = (VOC_300 , VOC_512 )[args .size == '512' ]
8484else :
85- train_sets = [('2014 ' , 'train' ),( '2014' , 'valminusminival ' )]
85+ train_sets = [('2017 ' , 'train' )]
8686 cfg = (COCO_300 , COCO_512 )[args .size == '512' ]
8787
8888if args .version == 'RFB_vgg' :
@@ -194,7 +194,7 @@ def weights_init(m):
194194 img_dim , rgb_means , p ), AnnotationTransform ())
195195elif args .dataset == 'COCO' :
196196 testset = COCODetection (
197- COCOroot , [('2014 ' , 'minival ' )], None )
197+ COCOroot , [('2017 ' , 'val ' )], None )
198198 train_dataset = COCODetection (COCOroot , train_sets , preproc (
199199 img_dim , rgb_means , p ))
200200else :
0 commit comments