|
2 | 2 | import json
|
3 | 3 | import logging
|
4 | 4 | import os
|
5 |
| -import time |
6 | 5 | from datetime import datetime
|
7 | 6 | from pathlib import Path
|
8 | 7 |
|
@@ -234,22 +233,22 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
|
234 | 233 | logging.info(f"Re-training with best config: \n{best_config}")
|
235 | 234 | trainer = TorchTrainer(config=best_config, **data)
|
236 | 235 | trainer.train()
|
| 236 | + best_model_path = trainer.checkpoint_callback.last_model_path |
237 | 237 | else:
|
238 | 238 | # If not merging training and validation data, load the best result from tune experiments.
|
239 | 239 | logging.info(f"Loading best model with best config: \n{best_config}")
|
240 | 240 | trainer = TorchTrainer(config=best_config, **data)
|
241 | 241 | best_checkpoint = os.path.join(best_log_dir, "best_model.ckpt")
|
242 | 242 | last_checkpoint = os.path.join(best_log_dir, "last.ckpt")
|
243 | 243 | trainer._setup_model(checkpoint_path=best_checkpoint)
|
244 |
| - os.popen(f"cp {best_checkpoint} {os.path.join(checkpoint_dir, 'best_model.ckpt')}") |
| 244 | + best_model_path = os.path.join(checkpoint_dir, 'best_model.ckpt') |
| 245 | + os.popen(f"cp {best_checkpoint} {best_model_path}") |
245 | 246 | os.popen(f"cp {last_checkpoint} {os.path.join(checkpoint_dir, 'last.ckpt')}")
|
246 | 247 |
|
247 | 248 | if "test" in data["datasets"]:
|
248 | 249 | test_results = trainer.test()
|
249 | 250 | logging.info(f"Test results after re-training: {test_results}")
|
250 |
| - logging.info( |
251 |
| - f"Best model saved to {trainer.checkpoint_callback.best_model_path or trainer.checkpoint_callback.last_model_path}." |
252 |
| - ) |
| 251 | + logging.info(f"Best model saved to {best_model_path}.") |
253 | 252 |
|
254 | 253 |
|
255 | 254 | def main():
|
@@ -341,7 +340,7 @@ def main():
|
341 | 340 | # Save best model after parameter search.
|
342 | 341 | best_config = analysis.get_best_config(f"val_{config.val_metric}", config.mode, scope="all")
|
343 | 342 | best_log_dir = analysis.get_best_logdir(f"val_{config.val_metric}", config.mode, scope="all")
|
344 |
| - retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not args.no_merge_train_val) |
| 343 | + retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not config.no_merge_train_val) |
345 | 344 |
|
346 | 345 |
|
347 | 346 | if __name__ == "__main__":
|
|
0 commit comments