Skip to content

Commit 2f108e1

Browse files
authored
Merge pull request #309 from ntumlgroup/retrain_args
Set `no_merge_train_val` in `*_tune.yml`
2 parents (128579e + e7c9578); commit 2f108e1

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

search_params.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
import json
33
import logging
44
import os
5-
import time
65
from datetime import datetime
76
from pathlib import Path
87

@@ -234,22 +233,22 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
234233
logging.info(f"Re-training with best config: \n{best_config}")
235234
trainer = TorchTrainer(config=best_config, **data)
236235
trainer.train()
236+
best_model_path = trainer.checkpoint_callback.last_model_path
237237
else:
238238
# If not merging training and validation data, load the best result from tune experiments.
239239
logging.info(f"Loading best model with best config: \n{best_config}")
240240
trainer = TorchTrainer(config=best_config, **data)
241241
best_checkpoint = os.path.join(best_log_dir, "best_model.ckpt")
242242
last_checkpoint = os.path.join(best_log_dir, "last.ckpt")
243243
trainer._setup_model(checkpoint_path=best_checkpoint)
244-
os.popen(f"cp {best_checkpoint} {os.path.join(checkpoint_dir, 'best_model.ckpt')}")
244+
best_model_path = os.path.join(checkpoint_dir, 'best_model.ckpt')
245+
os.popen(f"cp {best_checkpoint} {best_model_path}")
245246
os.popen(f"cp {last_checkpoint} {os.path.join(checkpoint_dir, 'last.ckpt')}")
246247

247248
if "test" in data["datasets"]:
248249
test_results = trainer.test()
249250
logging.info(f"Test results after re-training: {test_results}")
250-
logging.info(
251-
f"Best model saved to {trainer.checkpoint_callback.best_model_path or trainer.checkpoint_callback.last_model_path}."
252-
)
251+
logging.info(f"Best model saved to {best_model_path}.")
253252

254253

255254
def main():
@@ -341,7 +340,7 @@ def main():
341340
# Save best model after parameter search.
342341
best_config = analysis.get_best_config(f"val_{config.val_metric}", config.mode, scope="all")
343342
best_log_dir = analysis.get_best_logdir(f"val_{config.val_metric}", config.mode, scope="all")
344-
retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not args.no_merge_train_val)
343+
retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not config.no_merge_train_val)
345344

346345

347346
if __name__ == "__main__":

0 commit comments

Comments
 (0)