Calling `trainer.fit` twice with spawn strategies won't work as expected #18775
Labels: bug (Something isn't working), priority: 1 (Medium priority task), strategy: ddp (DistributedDataParallel), strategy: xla, ver: 2.0.x
Bug description
Since memory in the spawned processes is not shared with the main process, the spawn launcher saves a checkpoint of the weights before finishing, which is then loaded in the main process:
https://github.com/Lightning-AI/lightning/blob/984f49f7195ddc67e961c7c498ee6e19fc0cecb5/src/lightning/pytorch/strategies/launchers/multiprocessing.py#L190-L195
https://github.com/Lightning-AI/lightning/blob/984f49f7195ddc67e961c7c498ee6e19fc0cecb5/src/lightning/pytorch/strategies/launchers/multiprocessing.py#L162-L168
This means that the optimizer states are not loaded back in the main process, nor is any other state held by the trainer.
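Schematically, the handoff looks like this (a simplified sketch for illustration, not the launcher's actual code; the real implementation is at the permalinks above):

```python
import os
import tempfile

import torch


def worker_side(model: torch.nn.Module) -> str:
    """In the spawned process: persist only the model weights before exiting."""
    path = os.path.join(tempfile.gettempdir(), ".temp.ckpt")
    torch.save(model.state_dict(), path)  # weights only; no optimizer state
    return path


def main_side(model: torch.nn.Module, path: str) -> None:
    """Back in the main process: restore the weights.

    The optimizer and all other trainer state from the worker are lost.
    """
    model.load_state_dict(torch.load(path))
```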
This isn't a problem when calling `test`/`validate`/`predict` after `fit`, as shown in the sketch below.
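For context, a minimal reproduction might look like the following (`TinyModel` is an illustrative placeholder, not from the codebase):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


if __name__ == "__main__":  # entry-point guard required for spawn
    data = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
    model = TinyModel()
    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_epochs=1)
    trainer.fit(model, data)
    # Second fit: only the weights made it back from the spawned processes,
    # so the optimizer state (and any other trainer state) from the first
    # run is silently dropped here.
    trainer.fit(model, data)
```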
Solution
Since this is a silent correctness issue, we should raise an error in the short term.
The launcher can check whether `fit` was already called and is being called again, and raise a `NotImplementedError`. In the longer term, we can save a full checkpoint that contains all the relevant state and then lift this restriction.
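A minimal sketch of the proposed short-term guard, assuming a hypothetical helper and `_already_fit` flag; the actual hook point and attribute names inside the multiprocessing launcher may differ:

```python
def check_fit_not_repeated(launcher, is_fitting: bool) -> None:
    """Raise instead of silently dropping optimizer/trainer state.

    `launcher` is the spawn launcher instance; `_already_fit` is a
    hypothetical flag tracking whether `fit` already ran through it.
    """
    if is_fitting and getattr(launcher, "_already_fit", False):
        raise NotImplementedError(
            "Calling `trainer.fit()` twice with a spawn-based strategy is not"
            " supported: only the model weights are transferred back to the"
            " main process, so optimizer and trainer state from the first run"
            " would be silently lost."
        )
    if is_fitting:
        launcher._already_fit = True
```

The launcher would invoke such a check at the start of its launch routine whenever the trainer is in the fitting state.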
cc @tchaton @justusschock @awaelchli @carmocca @JackCaoG @Liyang90 @gkroiz