Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))


- Fixed an issue where the `ModelCheckpoint` callback did not create relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))


## [2.1.3] - 2023-12-21

### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> Non
elif os.path.isdir(linkpath):
shutil.rmtree(linkpath)
try:
os.symlink(filepath, linkpath)
os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath)
except OSError:
# on Windows, special permissions are required to create symbolic links as a regular user
# fall back to copying the file
Expand Down
21 changes: 21 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,20 +534,23 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(file)
assert not os.path.isabs(os.readlink(link))

# link exists (is a file)
new_file1 = tmp_path / "new_file1"
new_file1.touch()
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(new_file1)
assert not os.path.isabs(os.readlink(link))

# link exists (is a link)
new_file2 = tmp_path / "new_file2"
new_file2.touch()
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(new_file2)
assert not os.path.isabs(os.readlink(link))

# link exists (is a folder)
folder = tmp_path / "folder"
Expand All @@ -557,13 +560,15 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link))
assert os.path.islink(folder_link)
assert os.path.realpath(folder_link) == str(folder)
assert not os.path.isabs(os.readlink(folder_link))

# link exists (is a link to a folder)
new_folder = tmp_path / "new_folder"
new_folder.mkdir()
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link))
assert os.path.islink(folder_link)
assert os.path.realpath(folder_link) == str(new_folder)
assert not os.path.isabs(os.readlink(folder_link))

# simulate permission error on Windows (creation of symbolic links requires privileges)
file = tmp_path / "win_file"
Expand All @@ -575,6 +580,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
assert os.path.isfile(link) # fall back to copying instead of linking


def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch):
    """Check that ``_link_checkpoint`` stores a relative symlink target.

    Both paths are passed in absolute form; the link read back from disk must
    still refer to its sibling file through a relative path.

    """
    # Run from inside the temporary directory so the relative ``Path`` objects below resolve there.
    monkeypatch.chdir(tmp_path)
    mock_trainer = Mock()

    ckpt_dir = Path("x/z/z")
    ckpt_dir.mkdir(parents=True)
    target = ckpt_dir / "file"
    target.touch()
    symlink = ckpt_dir / "link"

    ModelCheckpoint._link_checkpoint(
        mock_trainer,
        filepath=str(target.absolute()),
        linkpath=str(symlink.absolute()),
    )

    # Confirm a link exists before reading it, so a missing link fails as an assertion.
    assert os.path.islink(symlink)
    stored_target = os.readlink(symlink)
    assert Path(stored_target) == target.relative_to(ckpt_dir)
    assert not os.path.isabs(stored_target)


def test_invalid_top_k(tmpdir):
"""Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):
Expand Down