Bug description
Since data in the spawned region is not shared with the main process, the spawn launcher saves a checkpoint of the weights before finishing that is then loaded on the main process:
https://github.com/Lightning-AI/lightning/blob/984f49f7195ddc67e961c7c498ee6e19fc0cecb5/src/lightning/pytorch/strategies/launchers/multiprocessing.py#L190-L195
https://github.com/Lightning-AI/lightning/blob/984f49f7195ddc67e961c7c498ee6e19fc0cecb5/src/lightning/pytorch/strategies/launchers/multiprocessing.py#L162-L168
This means that only the weights make it back to the main process: the optimizer states and any other state in the trainer are not loaded.
This isn't a problem when calling `test`/`validate`/`predict` after `fit`.
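To make the failure mode concrete, here is a toy model of the weights-only handoff. It uses plain dicts and `pickle` instead of real processes and checkpoints, and every name in it (`worker_fit`, `main_process_recover`, the state keys) is hypothetical rather than taken from Lightning:

```python
import copy
import pickle


def worker_fit(state):
    """Pretend training in a spawned process, which works on its own copy of the state."""
    state = copy.deepcopy(state)  # assumption: models the lack of shared memory
    state["weights"] = [w + 1.0 for w in state["weights"]]
    state["optimizer_state"]["momentum"] = [0.9 for _ in state["weights"]]
    state["global_step"] += 10
    # Only the weights are serialized and handed back to the parent.
    return pickle.dumps({"weights": state["weights"]})


def main_process_recover(state, payload):
    """Load what the worker sent back; everything not in the payload stays stale."""
    state["weights"] = pickle.loads(payload)["weights"]
    return state


main_state = {
    "weights": [0.0, 0.0],
    "optimizer_state": {"momentum": [0.0, 0.0]},
    "global_step": 0,
}
main_state = main_process_recover(main_state, worker_fit(main_state))
```

After the round trip the weights are updated, but the optimizer state and step counter in the main process still hold their pre-training values. A second `fit` would start from exactly this kind of mixed state.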
Solution
Since this is a silent correctness issue, we should raise an error in the short term.
The launcher can check if `fit` was called and is getting called again, and then raise a `NotImplementedError`.
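A minimal sketch of such a guard is below. The class `SpawnLauncherSketch`, its `stage` argument, and the `_fit_launched` flag are hypothetical and only illustrate the check. The real launcher has a different signature and would spawn processes instead of calling the function directly:

```python
from typing import Any, Callable


class SpawnLauncherSketch:
    """Hypothetical stand-in for a spawn launcher; not Lightning's actual class."""

    def __init__(self) -> None:
        # Remembers whether a `fit` call has already gone through this launcher.
        self._fit_launched = False

    def launch(self, function: Callable[..., Any], *args: Any, stage: str, **kwargs: Any) -> Any:
        if stage == "fit":
            if self._fit_launched:
                raise NotImplementedError(
                    "Calling `fit` twice with a spawn-based launcher is not supported: "
                    "only the weights are restored in the main process, so optimizer "
                    "and trainer state would be silently lost."
                )
            self._fit_launched = True
        # A real launcher would start worker processes here; the sketch calls directly.
        return function(*args, **kwargs)
```

With this in place, `fit` followed by `test`, `validate`, or `predict` keeps working, while a second `fit` fails loudly instead of training from partially restored state.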
In the longer term, we can save a full checkpoint that contains all the relevant data and then lift this restriction.
cc @tchaton @justusschock @awaelchli @carmocca @JackCaoG @Liyang90 @gkroiz