From 3bc6e9fe54a322b980962f112be47249207fb0ce Mon Sep 17 00:00:00 2001
From: Arijit Ghosh
Date: Wed, 15 Jan 2025 00:48:57 +0100
Subject: [PATCH] updated model_checkpoint.py to add the facility of retaining
 periodic checkpoints

---
 src/lightning/pytorch/callbacks/model_checkpoint.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 85bfb65c0ea6e..e760fd2581482 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -136,6 +136,10 @@ class ModelCheckpoint(Checkpoint):
             If this is ``False``, then the check runs at the end of the validation.
         enable_version_counter: Whether to append a version to the existing file name.
             If this is ``False``, then the checkpoint files will be overwritten.
+        retain_periodic_ckpt: Whether to retain the periodic checkpoints when multiple checkpoints are
+            saved. If this is ``False``, only the latest periodic checkpoint is kept. If this is ``True``,
+            leave ``save_top_k`` at its default value so that earlier periodic checkpoints are not removed.
+            Default: ``False``.
 
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
@@ -228,6 +232,7 @@ def __init__(
         every_n_epochs: Optional[int] = None,
         save_on_train_epoch_end: Optional[bool] = None,
         enable_version_counter: bool = True,
+        retain_periodic_ckpt: bool = False,
     ):
         super().__init__()
         self.monitor = monitor
@@ -247,6 +252,7 @@ def __init__(
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
+        self.retain_periodic_ckpt = retain_periodic_ckpt
 
         self.kth_value: Tensor
         self.dirpath: Optional[_PATH]
@@ -714,7 +720,12 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
         previous, self.best_model_path = self.best_model_path, filepath
         self._save_checkpoint(trainer, filepath)
 
-        if self.save_top_k == 1 and previous and self._should_remove_checkpoint(trainer, previous, filepath):
+        if (
+            self.save_top_k == 1
+            and not self.retain_periodic_ckpt
+            and previous
+            and self._should_remove_checkpoint(trainer, previous, filepath)
+        ):
             self._remove_checkpoint(trainer, previous)
 
     def _update_best_and_save(
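
Below is a minimal usage sketch of the option proposed in this patch. It assumes the patch is
applied (``retain_periodic_ckpt`` is not part of any released Lightning version), and the
``dirpath``/``filename``/``every_n_epochs`` values are purely illustrative:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # With monitor=None and the default save_top_k=1, ModelCheckpoint normally
    # overwrites the previous periodic checkpoint. With the flag proposed in
    # this patch set to True, every periodic checkpoint is kept on disk.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",          # illustrative path
        filename="periodic-{epoch}",     # illustrative filename pattern
        every_n_epochs=2,                # save a checkpoint every 2 epochs
        retain_periodic_ckpt=True,       # parameter added by this patch
    )

    trainer = Trainer(max_epochs=10, callbacks=[checkpoint_callback])
    # trainer.fit(model, datamodule=dm)  # model and dm are user-defined

With this configuration, the modified condition in ``_save_none_monitor_checkpoint`` skips the
removal of the previous checkpoint, so the checkpoint directory accumulates one file every two
epochs instead of holding only the most recent one.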