Skip to content

Commit 6f924c3

Browse files
committed
assert storage is LightningModule
1 parent 88a3ee4 commit 6f924c3

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/lightning/pytorch/core/saving.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ def _load_from_checkpoint(
8585
if issubclass(cls, pl.LightningDataModule):
8686
return _load_state(cls, checkpoint, **kwargs)
8787
if issubclass(cls, pl.LightningModule):
88+
storage = _load_state(cls, checkpoint, strict=strict, **kwargs)
8889
assert len(checkpoint["state_dict"]) > 0
90+
assert isinstance(storage, pl.LightningModule)
8991
map_location = list(checkpoint["state_dict"].values())[0].device
90-
return _load_state(cls, checkpoint, strict=strict, **kwargs).to(map_location)
92+
return storage.to(map_location)
9193

9294
raise NotImplementedError(f"Unsupported {cls}")
9395

0 commit comments

Comments
 (0)