6 changes: 6 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
 
 
+- Added the option `ModelCheckpoint(save_last='link')` to create a symbolic link for the 'last.ckpt' file ([#19191](https://github.com/Lightning-AI/lightning/pull/19191))
+
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
@@ -47,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The columns in the `metrics.csv` file produced by `CSVLogger` are now sorted alphabetically ([#19159](https://github.com/Lightning-AI/lightning/pull/19159))
 
 
+- Reverted back to creating a checkpoint copy when `ModelCheckpoint(save_last=True)` instead of creating a symbolic link ([#19191](https://github.com/Lightning-AI/lightning/pull/19191))
+
+
 ### Deprecated
 
 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
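Taken together, the two changelog entries mean `save_last=True` now writes a real copy of the latest checkpoint again, while the new `'link'` value opts into the symlink behavior. A minimal usage sketch; the directory name and `monitor` key are placeholders, assuming the model logs a `val_loss` metric:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# save_last=True keeps `last.ckpt` as an independent copy of the newest
# checkpoint (works on any filesystem, including remote storage).
# save_last="link" keeps `last.ckpt` as a symbolic link instead, which avoids
# duplicating large files but requires a local path.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",  # hypothetical output directory
    monitor="val_loss",      # assumes the LightningModule logs this metric
    save_top_k=3,
    save_last="link",
)
trainer = Trainer(max_epochs=10, callbacks=[checkpoint_cb])
```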
16 changes: 10 additions & 6 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -26,7 +26,7 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, Optional, Set
+from typing import Any, Dict, Literal, Optional, Set
 from weakref import proxy
 
 import torch
@@ -83,9 +83,9 @@ class ModelCheckpoint(Checkpoint):
             the number of finished epoch and optimizer steps respectively.
         monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
         verbose: verbosity mode. Default: ``False``.
-        save_last: When ``True``, saves a `last.ckpt` whenever a checkpoint file gets saved. On a local filesystem,
-            this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest
-            checkpoint in a deterministic manner. Default: ``None``.
+        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+            in a deterministic manner. Default: ``None``.
         save_top_k: if ``save_top_k == k``,
             the best k models according to the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -216,7 +216,7 @@ def __init__(
         filename: Optional[str] = None,
         monitor: Optional[str] = None,
         verbose: bool = False,
-        save_last: Optional[bool] = None,
+        save_last: Optional[Literal[True, False, "link"]] = None,
         save_top_k: int = 1,
         save_weights_only: bool = False,
         mode: str = "min",
@@ -272,6 +272,10 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self._fs = get_filesystem(self.dirpath or "")
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
+        if self.save_last == "link" and not _is_local_file_protocol(self.dirpath):
+            raise ValueError(
+                f"`ModelCheckpoint(save_last='link')` is only supported for local file paths, got `dirpath={dirpath}`."
+            )
 
     @override
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -684,7 +688,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
 
         # set the last model path before saving because it will be part of the state.
         previous, self.last_model_path = self.last_model_path, filepath
-        if _is_local_file_protocol(filepath) and self._last_checkpoint_saved and self.save_top_k != 0:
+        if self.save_last == "link" and self._last_checkpoint_saved and self.save_top_k != 0:
             self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
         else:
             self._save_checkpoint(trainer, filepath)
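The `_save_last_checkpoint` change above is the core of the revert: the symlink path is now taken only when the user explicitly passed `save_last='link'`, not merely because the checkpoint path happens to be local. A standalone sketch of that link-or-copy decision, assuming a stale `last.ckpt` is simply removed before rewriting; the real callback saves through `trainer.save_checkpoint` and the `_link_checkpoint` helper rather than this function:

```python
import os
import shutil

def write_last_checkpoint(src: str, last_path: str, save_last) -> None:
    """Sketch of the save_last behavior: symlink for "link", full copy for True."""
    if os.path.lexists(last_path):
        # Remove a stale `last.ckpt` (regular file or dangling link) first.
        os.remove(last_path)
    if save_last == "link":
        os.symlink(src, last_path)    # cheap pointer; local filesystems only
    else:
        shutil.copy2(src, last_path)  # independent copy; safe on remote storage
```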
20 changes: 15 additions & 5 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -485,13 +485,14 @@ def test_model_checkpoint_file_extension(tmpdir):
     assert set(expected) == set(os.listdir(tmpdir))
 
 
-def test_model_checkpoint_save_last(tmpdir, monkeypatch):
+@pytest.mark.parametrize("save_last", [True, "link"])
+def test_model_checkpoint_save_last(save_last, tmpdir, monkeypatch):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = LogInTwoMethods()
     epochs = 3
     monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_NAME_LAST", "last-{epoch}")
-    model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True)
+    model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=save_last)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[model_checkpoint],
@@ -509,10 +510,19 @@ def test_model_checkpoint_save_last(tmpdir, monkeypatch):
     assert set(os.listdir(tmpdir)) == set(
         [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename]
     )
-    assert os.path.islink(tmpdir / last_filename)
+    if save_last == "link":
+        assert os.path.islink(tmpdir / last_filename)
+    else:
+        assert os.path.isfile(tmpdir / last_filename)
     assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
 
 
+def test_model_checkpoint_save_last_as_link_not_local(tmp_path):
+    callback = ModelCheckpoint(dirpath="memory://not-a-filesystem-path", save_last="link")
+    with pytest.raises(ValueError, match="save_last='link'.* is only supported for local file paths"):
+        callback.setup(trainer=Trainer(), pl_module=BoringModel(), stage="fit")
+
+
 def test_model_checkpoint_link_checkpoint(tmp_path):
     """Test that linking a checkpoint works and overwrites an existing link if present."""
     trainer = Mock()
@@ -676,7 +686,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])]
     expected.append("last.ckpt")
     assert set(os.listdir(tmpdir)) == set(expected)
-    assert os.path.islink(tmpdir / "last.ckpt")
+    assert os.path.isfile(tmpdir / "last.ckpt")
 
 
 @pytest.mark.parametrize("every_n_epochs", list(range(4)))
@@ -887,7 +897,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
     assert os.path.isfile(path_last_epoch)
-    assert os.path.islink(path_last)
+    assert os.path.isfile(path_last)
 
     ckpt_last_epoch = torch.load(path_last_epoch)
     ckpt_last = torch.load(path_last)
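The new negative test feeds the callback a `memory://` URI so the `setup` guard fires before any training starts. As a rough illustration of what a scheme check like `_is_local_file_protocol` can do (an assumption for the sketch; the library's actual helper may be implemented differently):

```python
def is_local_file_protocol(path) -> bool:
    """Hypothetical stand-in for Lightning's `_is_local_file_protocol`."""
    # fsspec-style URLs carry their protocol before "://"; plain paths and
    # explicit file:// URIs count as local, everything else (s3://, memory://,
    # gs://, ...) as remote.
    path = str(path)
    return "://" not in path or path.startswith("file://")

assert is_local_file_protocol("checkpoints/last.ckpt")
assert is_local_file_protocol("file:///tmp/last.ckpt")
assert not is_local_file_protocol("memory://not-a-filesystem-path")
```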