                     help='dir dataset to download and/or load images')
 parser.add_argument('--data_split', default='train', type=str,
                     help='Options: (default) train | val | test')
-parser.add_argument('--arch', '-a', default='resnet152',
+parser.add_argument('--arch', '-a', default='fbresnet152',
                     choices=convnets.model_names,
                     help='model architecture: ' +
                          ' | '.join(convnets.model_names) +
                          ' (default: fbresnet152)')
-parser.add_argument('--workers', default=4, type=int,
+parser.add_argument('--workers', default=4, type=int,
                     help='number of data loading workers (default: 4)')
-parser.add_argument('--batch_size', '-b', default=80, type=int,
+parser.add_argument('--batch_size', '-b', default=80, type=int,
                     help='mini-batch size (default: 80)')
 parser.add_argument('--mode', default='both', type=str,
                     help='Options: att | noatt | (default) both')
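
For reference, a hypothetical invocation using the flags above (the file name extract.py and the data/coco path are assumptions; --dataset and --dir_data are defined elsewhere in the parser):

    python extract.py --dataset coco --dir_data data/coco --data_split train \
        --arch fbresnet152 --mode both --batch_size 80 --workers 4
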
@@ -56,7 +56,7 @@ def main():
     if args.dataset == 'coco':
         if 'coco' not in args.dir_data:
             raise ValueError('"coco" string not in dir_data')
-        dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
+        dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
             transform=transforms.Compose([
                 transforms.Scale(args.size),
                 transforms.CenterCrop(args.size),
@@ -68,7 +68,7 @@ def main():
             raise ValueError('train split is required for vgenome')
         if 'vgenome' not in args.dir_data:
             raise ValueError('"vgenome" string not in dir_data')
-        dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
+        dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
             transform=transforms.Compose([
                 transforms.Scale(args.size),
                 transforms.CenterCrop(args.size),
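
Side note: both dataset branches build the same torchvision preprocessing pipeline. transforms.Scale has since been deprecated in favour of transforms.Resize, so an equivalent sketch on a recent torchvision (448 is only a placeholder for args.size) would be:

    from torchvision import transforms

    size = 448  # placeholder for args.size
    preprocess = transforms.Compose([
        transforms.Resize(size),      # same behaviour as the deprecated transforms.Scale
        transforms.CenterCrop(size),
    ])
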
@@ -122,7 +122,7 @@ def extract(data_loader, model, path_file, mode):
 
         nb_regions = output_att.size(2) * output_att.size(3)
         output_noatt = output_att.sum(3).sum(2).div(nb_regions).view(-1, 2048)
-
+
         batch_size = output_att.size(0)
         if mode == 'both' or mode == 'att':
             hdf5_att[idx:idx+batch_size] = output_att.data.cpu().numpy()
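
To make the att/noatt relationship explicit: output_att is the convolutional feature map of shape (batch, 2048, H, W), and output_noatt is simply its average over the H*W spatial regions. A minimal sketch, assuming a 14x14 grid:

    import torch

    output_att = torch.randn(80, 2048, 14, 14)             # (batch, channels, H, W); the 14x14 grid is an assumption
    nb_regions = output_att.size(2) * output_att.size(3)   # 196 spatial regions
    output_noatt = output_att.sum(3).sum(2).div(nb_regions).view(-1, 2048)

    # equivalent reduction on a recent PyTorch
    assert torch.allclose(output_noatt, output_att.mean(dim=(2, 3)))
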
@@ -141,7 +141,7 @@ def extract(data_loader, model, path_file, mode):
                    i, len(data_loader),
                    batch_time=batch_time,
                    data_time=data_time,))
-
+
     hdf5_file.close()
 
     # Saving image names in the same order as extraction
@@ -154,4 +154,4 @@ def extract(data_loader, model, path_file, mode):
 
 
 if __name__ == '__main__':
-    main()
+    main()