Skip to content

Commit dcdeec7

Browse files
Add info message, if model weights are loaded.
1 parent 4be3352 commit dcdeec7

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

train.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def create_parser():
4242
parser.add_argument("-c", "--config-file", type=str, required=False,
4343
help="Path to a yaml configuration file.")
4444
parser.add_argument("--log", required=False, type=str, help="Path to the log file destination.")
45-
parser.add_argument("--save_path", required=False, type=str, default="",
45+
parser.add_argument("--save_path", required=False, type=str, default="",
4646
help="Path to the model destination. If empty, model won't be saved.")
47-
parser.add_argument("--load_path", required=False, type=str, default="",
47+
parser.add_argument("--load_path", required=False, type=str, default="",
4848
help="Path to the saved model. If empty, model won't be loaded.")
4949
return parser
5050

@@ -181,9 +181,10 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
181181
seq_length = 12
182182

183183
if args.load_path is not None and Path(args.load_path).is_file():
184-
model = torch.load(args.load_path)
185-
else:
186-
model = build_model(11, seq_length, args.batch_size).to(device)
184+
print("Loading model weights from: " + args.load_path)
185+
model = torch.load(args.load_path)
186+
else:
187+
model = build_model(11, seq_length, args.batch_size).to(device)
187188

188189
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
189190
floss = loss_func()
@@ -256,7 +257,7 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
256257
write_to_csv(history_item, args.log, write_header=epoch == 0, append=epoch != 0)
257258

258259
if args.save_path is not None:
259-
torch.save(model, args.save_path)
260+
torch.save(model, args.save_path)
260261

261262
# Test here
262263
test_results = test(model, dataloaders['test'], verbose)

0 commit comments

Comments
 (0)