Skip to content

Commit

Permalink
updated model_checkpoint.py to add the facility of retaining periodic…
Browse files Browse the repository at this point in the history
… checkpoints
  • Loading branch information
arijit-hub committed Jan 14, 2025
1 parent a944e77 commit 3bc6e9f
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class ModelCheckpoint(Checkpoint):
If this is ``False``, then the check runs at the end of the validation.
enable_version_counter: Whether to append a version to the existing file name.
If this is ``False``, then the checkpoint files will be overwritten.
retain_periodic_ckpt: Whether to retain previously saved periodic checkpoints instead of removing
them each time a new checkpoint is saved. If this is ``False``, only the latest checkpoint is kept.
If this is ``True``, leave ``save_top_k`` at its default value.
Default: ``False``.
Note:
For extra customization, ModelCheckpoint includes the following attributes:
Expand Down Expand Up @@ -228,6 +232,7 @@ def __init__(
every_n_epochs: Optional[int] = None,
save_on_train_epoch_end: Optional[bool] = None,
enable_version_counter: bool = True,
retain_periodic_ckpt: bool = False,
):
super().__init__()
self.monitor = monitor
Expand All @@ -247,6 +252,7 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""
self.retain_periodic_ckpt = retain_periodic_ckpt

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
Expand Down Expand Up @@ -714,7 +720,12 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
previous, self.best_model_path = self.best_model_path, filepath
self._save_checkpoint(trainer, filepath)

if self.save_top_k == 1 and previous and self._should_remove_checkpoint(trainer, previous, filepath):
if (
self.save_top_k == 1
and not self.retain_periodic_ckpt
and previous
and self._should_remove_checkpoint(trainer, previous, filepath)
):
self._remove_checkpoint(trainer, previous)

def _update_best_and_save(
Expand Down

0 comments on commit 3bc6e9f

Please sign in to comment.