@@ -42,9 +42,9 @@ def create_parser():
42
42
parser .add_argument ("-c" , "--config-file" , type = str , required = False ,
43
43
help = "Path to a yaml configuration file." )
44
44
parser .add_argument ("--log" , required = False , type = str , help = "Path to the log file destination." )
45
- parser .add_argument ("--save_path" , required = False , type = str , default = "" ,
45
+ parser .add_argument ("--save_path" , required = False , type = str , default = "" ,
46
46
help = "Path to the model destination. If empty, model won't be saved." )
47
- parser .add_argument ("--load_path" , required = False , type = str , default = "" ,
47
+ parser .add_argument ("--load_path" , required = False , type = str , default = "" ,
48
48
help = "Path to the saved model. If empty, model won't be loaded." )
49
49
return parser
50
50
@@ -181,9 +181,10 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
181
181
seq_length = 12
182
182
183
183
if args .load_path is not None and Path (args .load_path ).is_file ():
184
- model = torch .load (args .load_path )
185
- else :
186
- model = build_model (11 , seq_length , args .batch_size ).to (device )
184
+ print ("Loading model weights from: " + args .load_path )
185
+ model = torch .load (args .load_path )
186
+ else :
187
+ model = build_model (11 , seq_length , args .batch_size ).to (device )
187
188
188
189
optimizer = torch .optim .Adam (model .parameters (), lr = args .lr )
189
190
floss = loss_func ()
@@ -256,7 +257,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
256
257
write_to_csv (history_item , args .log , write_header = epoch == 0 , append = epoch != 0 )
257
258
258
259
if args .save_path is not None :
259
- torch .save (model , args .save_path )
260
+ torch .save (model , args .save_path )
260
261
261
262
# Test here
262
263
test_results = test (model , dataloaders ['test' ], verbose )
0 commit comments