1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `ModelCheckpoint` no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
- The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
- Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses through `multiprocessing` (ddp_spawn, xla) ([#18776](https://github.com/Lightning-AI/lightning/pull/18776))
- The `ModelCheckpoint` now saves a symbolic link if `save_last=True` and `save_top_k != 0` ([#18748](https://github.com/Lightning-AI/lightning/pull/18748))

### Deprecated

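For reference, a minimal usage sketch of the configuration described by the new changelog entry; the `val_loss` metric name and `max_epochs` value are illustrative assumptions, not part of this PR:

```python
# Sketch only: with save_last=True and save_top_k != 0, `last.ckpt` becomes a
# symbolic link to the most recently saved checkpoint on local filesystems.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val_loss", save_top_k=3, save_last=True)
trainer = Trainer(max_epochs=10, callbacks=[checkpoint_cb])
# After trainer.fit(model), checkpoint_cb.last_model_path points at the `last.ckpt` link.
```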
37 changes: 22 additions & 15 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -81,8 +81,9 @@ class ModelCheckpoint(Checkpoint):
the number of finished epoch and optimizer steps respectively.
monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
verbose: verbosity mode. Default: ``False``.
save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint
file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``.
save_last: When ``True``, saves a `last.ckpt` whenever a checkpoint file gets saved. On a local filesystem,
this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest
checkpoint in a deterministic manner. Default: ``None``.
save_top_k: if ``save_top_k == k``,
the best k models according to the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
@@ -241,6 +242,7 @@ def __init__(
self.best_model_score: Optional[Tensor] = None
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
@@ -371,12 +373,21 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
trainer.save_checkpoint(filepath, self.save_weights_only)

self._last_global_step_saved = trainer.global_step
self._last_checkpoint_saved = filepath

# notify loggers
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))

@staticmethod
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
if trainer.is_global_zero:
if os.path.lexists(linkpath):
os.remove(linkpath)
os.symlink(filepath, linkpath)
trainer.strategy.barrier()

def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from lightning.pytorch.trainer.states import TrainerFn

@@ -427,19 +438,12 @@ def __validate_init_configuration(self) -> None:
"should be mutually exclusive."
)

if self.monitor is None:
if self.monitor is None and self.save_top_k not in (-1, 0, 1):
# -1: save all epochs, 0: nothing is saved, 1: save last epoch
if self.save_top_k not in (-1, 0, 1):
raise MisconfigurationException(
f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
" configuration. No quantity for top_k to track."
)

if self.save_top_k == -1 and self.save_last:
rank_zero_info(
"ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
" will duplicate the last checkpoint saved."
)
raise MisconfigurationException(
f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
" configuration. No quantity for top_k to track."
)

def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
self._fs = get_filesystem(dirpath if dirpath else "")
@@ -662,7 +666,10 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[

# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
self._save_checkpoint(trainer, filepath)
if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_top_k != 0:
self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
self._save_checkpoint(trainer, filepath)
if previous and self._should_remove_checkpoint(trainer, previous, filepath):
self._remove_checkpoint(trainer, previous)

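As a standalone illustration of the control flow introduced in `_save_last_checkpoint` and `_link_checkpoint`, here is a simplified sketch; `write_last_checkpoint`, `fs_protocol`, `last_saved`, and `save_fn` are stand-in names for this sketch, not Lightning API:

```python
# Simplified sketch of the link-or-copy decision (not the PR's exact code).
import os
from typing import Callable, Optional


def write_last_checkpoint(
    linkpath: str,
    fs_protocol: str,
    last_saved: Optional[str],
    save_fn: Callable[[str], None],
) -> None:
    if fs_protocol == "file" and last_saved:
        # Local filesystem: drop any stale `last.ckpt` and link it to the latest checkpoint.
        if os.path.lexists(linkpath):
            os.remove(linkpath)
        os.symlink(last_saved, linkpath)
    else:
        # Remote filesystems (e.g. s3://) have no symlinks, so fall back to a full save.
        save_fn(linkpath)
```

In the actual change, the link is only created when `save_top_k != 0` and a checkpoint has already been saved; only rank zero touches the filesystem, followed by `trainer.strategy.barrier()`.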
@@ -62,7 +62,7 @@ def __init__(self):
self.last_coeff = 10.0

def training_step(self, batch, batch_idx):
loss = self.step(torch.ones(32))
loss = self.step(torch.ones(32, device=self.device))
loss = loss / (loss + 0.0000001)
loss += self.last_coeff
self.log("my_loss", loss)
@@ -80,8 +80,7 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)

if save_last:
# last epochs are saved every step (so double the save calls)
expected = expected * 2
expected = expected
assert save_mock.call_count == expected


19 changes: 11 additions & 8 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -18,7 +18,6 @@
import time
from argparse import Namespace
from datetime import timedelta
from logging import INFO
from pathlib import Path
from typing import Union
from unittest import mock
@@ -510,7 +509,8 @@ def test_model_checkpoint_save_last(tmpdir):
assert set(os.listdir(tmpdir)) == set(
[f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename]
)

assert os.path.islink(tmpdir / last_filename)
assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"


@@ -589,10 +589,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
max_epochs=epochs,
logger=False,
)

with caplog.at_level(INFO):
trainer.fit(model)
assert "will duplicate the last checkpoint saved" in caplog.text
trainer.fit(model)

# these should not be set if monitor is None
assert checkpoint_callback.monitor is None
@@ -606,6 +603,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])]
expected.append("last.ckpt")
assert set(os.listdir(tmpdir)) == set(expected)
assert os.path.islink(tmpdir / "last.ckpt")


@pytest.mark.parametrize("every_n_epochs", list(range(4)))
@@ -709,6 +707,8 @@ def test_model_checkpoint_topk_zero(tmpdir):
# check that only the last ckpt was created
assert os.listdir(tmpdir) == ["last.ckpt"]
assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
# 'last.ckpt' is not a symlink because there are no top-k checkpoints to link
assert not os.path.islink(checkpoint_callback.last_model_path)


def test_model_checkpoint_topk_all(tmpdir):
@@ -814,6 +814,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
path_last = str(tmpdir / "last.ckpt")
assert path_last == model_checkpoint.last_model_path
assert os.path.isfile(path_last_epoch)
assert os.path.islink(path_last)

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
@@ -1343,7 +1344,7 @@ def test_save_last_saves_correct_last_model_path(tmpdir):
trainer = Trainer(callbacks=mc)
trainer.strategy.connect(BoringModel())

mc._save_last_checkpoint(trainer, {"foo": 1})
mc._save_last_checkpoint(trainer, {"foo": torch.tensor(1)})
expected = "foo=1-last.ckpt"
assert os.listdir(tmpdir) == [expected]
full_path = str(tmpdir / expected)
@@ -1366,6 +1367,8 @@ def test_save_last_versioning(tmpdir):
)
trainer.fit(model)
assert {"last.ckpt", "last-v1.ckpt"} == set(os.listdir(tmpdir))
# 'last.ckpt' is not a symlink since `save_top_k=0` didn't save any other checkpoints to link to
assert all(not os.path.islink(tmpdir / path) for path in set(os.listdir(tmpdir)))


def test_none_monitor_saves_correct_best_model_path(tmpdir):
@@ -1385,7 +1388,7 @@ def test_last_global_step_saved():
# this should not save anything
model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
trainer = Mock()
monitor_candidates = {"foo": 123}
monitor_candidates = {"foo": torch.tensor(123)}
model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
assert model_checkpoint._last_global_step_saved == 0
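Since several of these tests assert on `os.path.islink` and `os.path.realpath`, here is a short hedged sketch of how downstream code might resolve `last.ckpt` before loading it; the path is a placeholder, and `torch.load` also works on the link directly because the OS resolves symlinks on open:

```python
import os
import torch

last_path = "checkpoints/last.ckpt"  # placeholder path
real_path = os.path.realpath(last_path) if os.path.islink(last_path) else last_path
checkpoint = torch.load(real_path, map_location="cpu")
```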
6 changes: 4 additions & 2 deletions tests/tests_pytorch/models/test_restore.py
@@ -311,9 +311,11 @@ def get_trainer_args():
"best_k_models",
"kth_best_model_path",
"kth_value",
"last_model_path",
):
assert getattr(before, attribute) == getattr(after, attribute)
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
# `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
# hence reloading that checkpoint will restore `after.last_model_path = ""`
assert after.last_model_path == ""


@RunIf(sklearn=True)
6 changes: 4 additions & 2 deletions tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
@@ -47,6 +47,7 @@ def test_checkpoint_plugin_called(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="cpu",
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
callbacks=ck,
max_epochs=2,
@@ -60,7 +61,7 @@
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.save_checkpoint.call_count == 2
assert checkpoint_plugin.remove_checkpoint.call_count == 1

trainer.test(model, ckpt_path=ck.last_model_path)
@@ -72,6 +73,7 @@ def test_checkpoint_plugin_called(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="cpu",
strategy=SingleDeviceStrategy("cpu"),
plugins=[checkpoint_plugin],
callbacks=ck,
@@ -86,7 +88,7 @@
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.save_checkpoint.call_count == 2
assert checkpoint_plugin.remove_checkpoint.call_count == 1

trainer.test(model, ckpt_path=ck.last_model_path)