
Inconsistency in loading from checkpoint in LightningCLI #20801

Open
@Northo

Description

Bug description

When using a checkpoint in LightningCLI, the model is first instantiated from the CLI arguments, and only then is the checkpoint loaded, by passing it to the ckpt_path argument of the Trainer method being run.

The problem is that the hyperparameters stored in the checkpoint are not used when instantiating the model, and therefore not when allocating its tensors. This can cause checkpoint loading to fail when tensor sizes do not match. Furthermore, if the model has complicated instantiation logic, this may lead to other silent bugs or failures.
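The mechanism can be reproduced in plain PyTorch, without Lightning. The class below is a hypothetical stand-in for DemoModel (a single linear layer whose output size is a hyperparameter, matching the l1.weight shapes in the traceback further down):

```python
import torch.nn as nn

# Hypothetical stand-in for DemoModel: one linear layer whose output
# size is a hyperparameter.
class TinyModel(nn.Module):
    def __init__(self, out_dim: int):
        super().__init__()
        self.l1 = nn.Linear(32, out_dim)

# "fit" run: trained with --model.out_dim 2, state dict saved to a checkpoint.
ckpt_state = TinyModel(out_dim=2).state_dict()

# "predict" run: the model is rebuilt from CLI/config defaults
# (out_dim=10), ignoring the checkpoint's hyperparameters...
fresh = TinyModel(out_dim=10)

# ...so restoring the state dict fails with a size mismatch.
err = None
try:
    fresh.load_state_dict(ckpt_state)
except RuntimeError as e:
    err = e
assert err is not None and "size mismatch" in str(err)
```

This is the same failure mode as the RuntimeError in the logs below, just without the Trainer machinery around it.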

This was first raised as a discussion in #20715

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Here is a minimal example using the predict subcommand. We set out_dim during fit so that the last layer has a non-default size, causing loading to fail in predict.

# cli.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

class DemoModelWithHyperparameters(DemoModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Store the init arguments in the checkpoint under "hyper_parameters".
        self.save_hyperparameters()

def cli_main():
    cli = LightningCLI(DemoModelWithHyperparameters, BoringDataModule)

if __name__ == "__main__":
    cli_main()

and then run

$ python src/lightning_cli_load_checkpoint/cli.py fit --trainer.max_epochs 1 --model.out_dim 2
$ python src/lightning_cli_load_checkpoint/cli.py predict --ckpt_path <path_to_checkpoint>

Error messages and logs

Restoring states from the checkpoint path at lightning_logs/version_23/checkpoints/epoch=0-step=64.ckpt
Traceback (most recent call last):
  File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 20, in <module>
    cli_main()
  File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 16, in cli_main
    cli = MyLightningCLI(DemoModelWithHyperparameters, datamodule_class=BoringDataModule)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 398, in __init__
    self._run_subcommand(self.subcommand)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
    fn(**fn_kwargs)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 887, in predict
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 928, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 409, in _restore_modules_and_callbacks
    self.restore_model()
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 286, in restore_model
    self.trainer.strategy.load_model_state_dict(
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 372, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for DemoModelWithHyperparameters:
	size mismatch for l1.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([10, 32]).
	size mismatch for l1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([10]).

Expected behavior

I'd expect the loading to respect the checkpoint's arguments. In other words, while the current implementation roughly follows this logic:

model = Model(**cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)

I'd expect it to be closer to

model = Model.load_from_checkpoint(ckpt_path, **cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)
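The expected flow can also be sketched in plain PyTorch, assuming the checkpoint contains the hyperparameters (which Lightning stores under the "hyper_parameters" key when save_hyperparameters() is used). TinyModel is a hypothetical stand-in for the Lightning module:

```python
import torch.nn as nn

# Hypothetical stand-in for the Lightning module: output size is a
# hyperparameter.
class TinyModel(nn.Module):
    def __init__(self, out_dim: int):
        super().__init__()
        self.l1 = nn.Linear(32, out_dim)

# A checkpoint as written after `fit --model.out_dim 2`: weights under
# "state_dict", and the init arguments under "hyper_parameters" (saved
# because save_hyperparameters() was called).
ckpt = {
    "state_dict": TinyModel(out_dim=2).state_dict(),
    "hyper_parameters": {"out_dim": 2},
}

# Expected behavior: instantiate from the checkpoint's hyperparameters
# first (essentially what Model.load_from_checkpoint does), so the
# tensor shapes line up and loading succeeds.
model = TinyModel(**ckpt["hyper_parameters"])
model.load_state_dict(ckpt["state_dict"])
assert model.l1.out_features == 2
```

Instantiating from the checkpoint's hyperparameters before restoring the state dict would also give any custom instantiation logic the same inputs it had during fit.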

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • lightning: 2.5.1
    • lightning-cli-load-checkpoint: 0.1.0
    • lightning-utilities: 0.14.3
    • pytorch-lightning: 2.5.1
    • torch: 2.6.0
    • torchmetrics: 1.7.1
  • Packages:
    • aiohappyeyeballs: 2.6.1
    • aiohttp: 3.11.16
    • aiosignal: 1.3.2
    • antlr4-python3-runtime: 4.9.3
    • attrs: 25.3.0
    • autocommand: 2.2.2
    • backports.tarfile: 1.2.0
    • contourpy: 1.3.1
    • cycler: 0.12.1
    • docstring-parser: 0.16
    • filelock: 3.18.0
    • fonttools: 4.57.0
    • frozenlist: 1.5.0
    • fsspec: 2025.3.2
    • hydra-core: 1.3.2
    • idna: 3.10
    • importlib-metadata: 8.0.0
    • importlib-resources: 6.5.2
    • inflect: 7.3.1
    • jaraco.collections: 5.1.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jinja2: 3.1.6
    • jsonargparse: 4.38.0
    • kiwisolver: 1.4.8
    • lightning: 2.5.1
    • lightning-cli-load-checkpoint: 0.1.0
    • lightning-utilities: 0.14.3
    • markdown-it-py: 3.0.0
    • markupsafe: 3.0.2
    • matplotlib: 3.10.1
    • mdurl: 0.1.2
    • more-itertools: 10.3.0
    • mpmath: 1.3.0
    • multidict: 6.4.3
    • networkx: 3.4.2
    • numpy: 2.2.4
    • omegaconf: 2.3.0
    • packaging: 24.2
    • pillow: 11.2.1
    • platformdirs: 4.2.2
    • propcache: 0.3.1
    • protobuf: 6.30.2
    • pygments: 2.19.1
    • pyparsing: 3.2.3
    • python-dateutil: 2.9.0.post0
    • pytorch-lightning: 2.5.1
    • pyyaml: 6.0.2
    • rich: 13.9.4
    • setuptools: 78.1.0
    • six: 1.17.0
    • sympy: 1.13.1
    • tensorboardx: 2.6.2.2
    • tomli: 2.0.1
    • torch: 2.6.0
    • torchmetrics: 1.7.1
    • tqdm: 4.67.1
    • typeguard: 4.3.0
    • typeshed-client: 2.7.0
    • typing-extensions: 4.13.2
    • wheel: 0.45.1
    • yarl: 1.19.0
    • zipp: 3.19.2
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: arm
    • python: 3.12.7
    • release: 24.4.0
    • version: Darwin Kernel Version 24.4.0: Fri Apr 11 18:33:47 PDT 2025; root:xnu-11417.101.15~117/RELEASE_ARM64_T6000

Labels

bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x
