
Commit e1ce887

Fix load_from_checkpoint to return model on correct device (Lightning-AI#17308)

Authored by ryan597
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

1 parent 84eb82a · commit e1ce887

File tree

4 files changed: +69 −6 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion

@@ -49,7 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-
+- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))
+
 
 
 ## [2.0.1.post0] - 2023-04-11
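
In practice, the fixed behavior looks like the sketch below. This is illustration only, not part of the commit; "checkpoint.ckpt" is a hypothetical path, and a CUDA device is assumed to be available:

import torch
from lightning.pytorch.demos.boring_classes import BoringModel

# Before this fix, the loaded model always came back on CPU regardless of
# map_location; after it, the requested device is honored.
model = BoringModel.load_from_checkpoint("checkpoint.ckpt", map_location=torch.device("cuda"))
print(model.device)  # cuda:0 with this fix; previously cpu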

src/lightning/pytorch/core/saving.py

Lines changed: 9 additions & 4 deletions

@@ -21,7 +21,7 @@
 from copy import deepcopy
 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, IO, Optional, Type, Union
+from typing import Any, Callable, Dict, IO, Optional, Type, Union
 from warnings import warn
 
 import yaml

@@ -56,8 +56,6 @@ def _load_from_checkpoint(
     strict: Optional[bool] = None,
     **kwargs: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
-    if map_location is None:
-        map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)
 

@@ -87,7 +85,14 @@ def _load_from_checkpoint(
     if issubclass(cls, pl.LightningDataModule):
         return _load_state(cls, checkpoint, **kwargs)
     if issubclass(cls, pl.LightningModule):
-        return _load_state(cls, checkpoint, strict=strict, **kwargs)
+        storage = _load_state(cls, checkpoint, strict=strict, **kwargs)
+        state_dict = checkpoint["state_dict"]
+        if not state_dict:
+            raise ValueError(f"The state dict in {checkpoint_path!r} contains no parameters.")
+        map_location = list(state_dict.values())[0].device
+        assert isinstance(storage, pl.LightningModule)
+        return storage.to(map_location)
+
     raise NotImplementedError(f"Unsupported {cls}")
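
For illustration only, a standalone sketch of the device-inference step the new code performs, using toy tensors rather than the library's API:

import torch

# A toy state dict standing in for checkpoint["state_dict"].
state_dict = {"layer.weight": torch.randn(2, 2), "layer.bias": torch.randn(2)}
if not state_dict:
    raise ValueError("The state dict contains no parameters.")

# The target device is read off the first tensor, exactly as in the diff above;
# all parameters in a checkpoint are assumed to share one device.
device = list(state_dict.values())[0].device
print(device)  # cpu for these toy tensors; cuda:0 for a GPU-saved checkpoint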

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import pytest
+import torch
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.demos.boring_classes import BoringModel
+from tests_pytorch.helpers.runif import RunIf
+
+
+def create_boring_checkpoint(tmp_path, model, accelerator="cuda"):
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="checkpoint")
+    trainer = pl.Trainer(
+        devices=1,
+        accelerator=accelerator,
+        max_epochs=1,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        callbacks=[checkpoint_callback],
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    "map_location", (None, "cpu", torch.device("cpu"), lambda storage, loc: storage, {"cpu": "cpu"})
+)
+def test_load_from_checkpoint_map_location_cpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cpu"
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize(
+    "map_location", (None, "cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
+)
+def test_load_from_checkpoint_map_location_gpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cuda")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cuda"
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize("map_location", ("cpu", torch.device("cpu"), lambda storage, loc: storage, {"cuda": "cpu"}))
+def test_load_from_checkpoint_map_location_gpu_to_cpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cuda")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cpu"
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize(
+    "map_location", ("cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
+)
+def test_load_from_checkpoint_map_location_cpu_to_gpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cuda"
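
The parametrizations above exercise every map_location form that torch.load accepts: None, a device string, a torch.device, a callable, and a remapping dict. A quick standalone check of those forms (hypothetical scratch file name, toy tensor):

import torch

torch.save(torch.randn(2), "tensor.pt")  # hypothetical scratch file
for map_location in ("cpu", torch.device("cpu"), lambda storage, loc: storage, {"cpu": "cpu"}):
    loaded = torch.load("tensor.pt", map_location=map_location)
    assert loaded.device.type == "cpu"  # every form resolves to CPU here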

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
 
     # Assert model parameters are identical after loading
     for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
-        assert torch.equal(ddp_param.float().cpu(), shard_param)
+        assert torch.equal(ddp_param, shard_param)
 
 
 @RunIf(min_torch="1.12")
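
The dropped .float().cpu() round-trip was only needed while loaded models always came back on CPU. A brief sketch, with toy tensors, of the invariant the simplified assertion relies on:

import torch

a = torch.randn(3)
b = a.clone()
# torch.equal requires matching dtype, device, shape, and values; with the fix
# above, loaded parameters already match the saved ones without any casting.
assert torch.equal(a, b)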
