@@ -69,7 +69,7 @@ def get_parser():
6969 parser .add_argument ('--numworker' ,default = 12 ,type = int )
7070 parser .add_argument ('--freezeBN' , choices = dict (true = True , false = False ), default = True , action = LookupChoices )
7171 parser .add_argument ('--step' , default = 10 , type = int )
72- parser .add_argument ('--classes' , default = 7 , type = int )
72+ parser .add_argument ('--classes' , default = 20 , type = int )
7373 parser .add_argument ('--testInterval' , default = 10 , type = int )
7474 parser .add_argument ('--loadmodel' ,default = '' ,type = str )
7575 parser .add_argument ('--pretrainedModel' , default = '' , type = str )
@@ -176,7 +176,7 @@ def main(opts):
176176
177177 # Network definition
178178 if backbone == 'xception' :
179- net_ = deeplab_xception_transfer .deeplab_xception_transfer_projection_savemem (n_classes = 20 , os = 16 ,
179+ net_ = deeplab_xception_transfer .deeplab_xception_transfer_projection_savemem (n_classes = opts . classes , os = 16 ,
180180 hidden_layers = opts .hidden_layers , source_classes = 7 , )
181181 elif backbone == 'resnet' :
182182 # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
@@ -195,7 +195,7 @@ def main(opts):
195195 if not model_path == '' :
196196 x = torch .load (model_path )
197197 net_ .load_state_dict_new (x )
198- print ('load pretrainedModel.' )
198+ print ('load pretrainedModel:' , model_path )
199199 else :
200200 print ('no pretrainedModel.' )
201201 if not opts .loadmodel == '' :
@@ -320,7 +320,7 @@ def main(opts):
320320 # One testing epoch
321321 if useTest and epoch % nTestInterval == (nTestInterval - 1 ):
322322 val_cihp (net_ ,testloader = testloader , testloader_flip = testloader_flip , test_graph = test_graph ,
323- epoch = epoch ,writer = writer ,criterion = criterion )
323+ epoch = epoch ,writer = writer ,criterion = criterion , classes = opts . classes )
324324 torch .cuda .empty_cache ()
325325
326326
0 commit comments