We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 88a3ee4 commit 6f924c3Copy full SHA for 6f924c3
src/lightning/pytorch/core/saving.py
@@ -85,9 +85,11 @@ def _load_from_checkpoint(
85
if issubclass(cls, pl.LightningDataModule):
86
return _load_state(cls, checkpoint, **kwargs)
87
if issubclass(cls, pl.LightningModule):
88
+ storage = _load_state(cls, checkpoint, strict=strict, **kwargs)
89
assert len(checkpoint["state_dict"]) > 0
90
+ assert isinstance(storage, pl.LightningModule)
91
map_location = list(checkpoint["state_dict"].values())[0].device
- return _load_state(cls, checkpoint, strict=strict, **kwargs).to(map_location)
92
+ return storage.to(map_location)
93
94
raise NotImplementedError(f"Unsupported {cls}")
95
0 commit comments