
Commit c41de5b

ryan597, pre-commit-ci[bot], carmocca, and Borda committed
Fix load_from_checkpoint to return model on correct device (#17308)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

(cherry picked from commit e1ce887)
1 parent d6b472d commit c41de5b
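
For context, a minimal sketch of the user-facing behavior this commit fixes (the checkpoint path and CUDA availability are assumptions for illustration):

```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel

# Remap every tensor in the checkpoint onto the first CUDA device.
model = BoringModel.load_from_checkpoint("checkpoint.ckpt", map_location="cuda:0")

# Before this fix, the returned model always ended up on CPU regardless of
# map_location; after it, the model lives on the device the mapping chose.
assert model.device == torch.device("cuda", 0)
```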

File tree: 4 files changed (+79, -4 lines)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 13 additions & 0 deletions

@@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+
+## [UnReleased] - 2023-04-DD
+
+### Changed
+
+-
+
+
+### Fixed
+
+- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))
+
+
 ## [1.9.5] - 2023-03-30
 
 ### Changed

src/pytorch_lightning/core/saving.py

Lines changed: 8 additions & 3 deletions

@@ -154,8 +154,6 @@ def _load_from_checkpoint(
     strict: Optional[bool] = None,
     **kwargs: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
-    if map_location is None:
-        map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)
 
@@ -185,7 +183,14 @@ def _load_from_checkpoint
     if issubclass(cls, pl.LightningDataModule):
         return _load_state(cls, checkpoint, **kwargs)
     if issubclass(cls, pl.LightningModule):
-        return _load_state(cls, checkpoint, strict=strict, **kwargs)
+        storage = _load_state(cls, checkpoint, strict=strict, **kwargs)
+        state_dict = checkpoint["state_dict"]
+        if not state_dict:
+            raise ValueError(f"The state dict in {checkpoint_path!r} contains no parameters.")
+        map_location = list(state_dict.values())[0].device
+        assert isinstance(storage, pl.LightningModule)
+        return storage.to(map_location)
+
     raise NotImplementedError(f"Unsupported {cls}")
 
 
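
The heart of the fix: the removed lines substituted a default `map_location` callable that kept every storage on CPU, and the loaded module was never moved afterwards. The new code lets `torch.load` honor the caller's mapping (or the tensors' saved devices when it is `None`) and then moves the assembled module to the device of the first remapped state-dict tensor. A standalone sketch of that inference, with a hypothetical helper name and a plain `torch` checkpoint layout assumed:

```python
import torch
from torch import nn


def load_on_checkpoint_device(module: nn.Module, checkpoint_path: str, map_location=None) -> nn.Module:
    """Hypothetical helper mirroring the fixed logic above."""
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    if not state_dict:
        raise ValueError(f"The state dict in {checkpoint_path!r} contains no parameters.")
    module.load_state_dict(state_dict)
    # torch.load has already applied map_location to every tensor, so the
    # first entry's device tells us where the whole module should live.
    device = list(state_dict.values())[0].device
    return module.to(device)
```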

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import pytest
+import torch
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.demos.boring_classes import BoringModel
+from tests_pytorch.helpers.runif import RunIf
+
+
+def create_boring_checkpoint(tmp_path, model, accelerator="cuda"):
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="checkpoint")
+    trainer = pl.Trainer(
+        devices=1,
+        accelerator=accelerator,
+        max_epochs=1,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        callbacks=[checkpoint_callback],
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    "map_location", (None, "cpu", torch.device("cpu"), lambda storage, loc: storage, {"cpu": "cpu"})
+)
+def test_load_from_checkpoint_map_location_cpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cpu"
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize(
+    "map_location", (None, "cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
+)
+def test_load_from_checkpoint_map_location_gpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cuda")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cuda"
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize("map_location", ("cpu", torch.device("cpu"), lambda storage, loc: storage, {"cuda": "cpu"}))
+def test_load_from_checkpoint_map_location_gpu_to_cpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cuda")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cpu"
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize(
+    "map_location", ("cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
+)
+def test_load_from_checkpoint_map_location_cpu_to_gpu(tmp_path, map_location):
+    create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
+    model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
+    assert model.device.type == "cuda"
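
The parametrizations above walk through every `map_location` form that `torch.load` accepts; a quick reference sketch of the same matrix:

```python
import torch

# Each entry is a valid map_location argument, mirroring the test matrix:
map_location_variants = [
    None,                          # restore tensors to the devices they were saved from
    "cpu",                         # device string
    torch.device("cpu"),           # device object
    lambda storage, loc: storage,  # callable over (storage, original location)
    {"cuda:0": "cpu"},             # dict remapping source locations to targets
]
```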

tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
 
     # Assert model parameters are identical after loading
     for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
-        assert torch.equal(ddp_param.float().cpu(), shard_param)
+        assert torch.equal(ddp_param, shard_param)
 
 
 @RunIf(min_torch="1.12")
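
The `.float().cpu()` conversion could be dropped presumably because the loaded model now comes back on the checkpoint's original device, so both sides of the comparison already live together. A minimal sketch of the resulting direct-comparison pattern (helper name hypothetical):

```python
import torch
from torch import nn


def state_dicts_equal(a: nn.Module, b: nn.Module) -> bool:
    """Pairwise-compare two modules' parameters, like the test loop above."""
    return all(
        torch.equal(p, q)
        for p, q in zip(a.state_dict().values(), b.state_dict().values())
    )


# Example: a module and a copy that loaded its weights compare equal.
m1 = nn.Linear(4, 4)
m2 = nn.Linear(4, 4)
m2.load_state_dict(m1.state_dict())
assert state_dicts_equal(m1, m2)
```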
