Calling trainer.fit twice with spawn strategies won't work as expected #18775

Open
@carmocca

Bug description

Since data in the spawned region is not shared with the main process, the spawn launcher saves a checkpoint of the weights before finishing, and that checkpoint is then loaded in the main process:

https://github.com/Lightning-AI/lightning/blob/984f49f7195ddc67e961c7c498ee6e19fc0cecb5/src/lightning/pytorch/strategies/launchers/multiprocessing.py#L190-L195

https://github.com/Lightning-AI/lightning/blob/984f49f7195ddc67e961c7c498ee6e19fc0cecb5/src/lightning/pytorch/strategies/launchers/multiprocessing.py#L162-L168

This means that neither the optimizer states nor any other trainer state is loaded back into the main process.
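To make the effect concrete, here is a toy model of the hand-off that uses plain dictionaries instead of real Lightning objects. The names `worker_fit` and `restore_on_main` are hypothetical and only mimic the behavior described above; they are not part of Lightning's API.

```python
import copy


def worker_fit(state):
    """Pretend to train inside a spawned worker (hypothetical helper)."""
    state = copy.deepcopy(state)  # the worker operates on its own copy
    state["weights"] = [w - 0.1 for w in state["weights"]]
    state["optimizer"]["step"] += 10
    state["optimizer"]["momentum"] = [0.5 for _ in state["weights"]]
    # Only the weights travel back, mirroring the weights-only checkpoint.
    return {"weights": state["weights"]}


def restore_on_main(main_state, worker_output):
    """Load whatever the worker sent back into the main-process state."""
    main_state["weights"] = worker_output["weights"]
    return main_state


main_state = {
    "weights": [1.0, 2.0],
    "optimizer": {"step": 0, "momentum": [0.0, 0.0]},
}
main_state = restore_on_main(main_state, worker_fit(main_state))

weights_updated = main_state["weights"] != [1.0, 2.0]
optimizer_step_on_main = main_state["optimizer"]["step"]
```

After the hand-off the weights reflect the training run, but the optimizer state on the main process is still the untouched initial one, which is what a second fit call would start from.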

This isn't a problem when calling test/validate/predict after fit, since those only need the weights.

Solution

Since this is a silent correctness issue, we should raise an error in the short term.

The launcher can check whether fit has already been called and, if it is being called again, raise a NotImplementedError.
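A minimal sketch of such a guard could look like the following. `SpawnLauncherSketch`, its `_fit_launched` flag, and the `stage` keyword are all assumptions made for illustration; the real launcher may track the trainer stage differently.

```python
class SpawnLauncherSketch:
    """Hypothetical stand-in for a spawn launcher; not Lightning's class."""

    def __init__(self):
        # Remembers whether a fit run has already gone through this launcher.
        self._fit_launched = False

    def launch(self, function, *args, stage, **kwargs):
        # `stage` is an assumed way of telling the launcher what is being run.
        if stage == "fit":
            if self._fit_launched:
                raise NotImplementedError(
                    "Calling fit twice with a spawn strategy is not supported: "
                    "optimizer and trainer state are not restored in the main process."
                )
            self._fit_launched = True
        return function(*args, **kwargs)


launcher = SpawnLauncherSketch()
first = launcher.launch(lambda: "fit-1", stage="fit")
after = launcher.launch(lambda: "test-1", stage="test")  # still allowed after fit

second_fit_error = None
try:
    launcher.launch(lambda: "fit-2", stage="fit")
except NotImplementedError as err:
    second_fit_error = str(err)
```

Only a repeated fit is rejected here, so test/validate/predict after fit keep working as before.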

In the longer term, we can save a full checkpoint that contains all the relevant data and then lift this restriction.
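As a rough illustration of the longer-term idea, the sketch below round-trips a "full" state through a temporary file with `pickle`. The helper names and the dictionary keys (`weights`, `optimizer_state`, `trainer_state`) are assumptions chosen for this example; an actual fix would presumably reuse Lightning's own checkpoint format rather than this ad hoc layout.

```python
import os
import pickle
import tempfile


def save_full_checkpoint(path, weights, optimizer_state, trainer_state):
    """Write everything a later fit call would need (hypothetical layout)."""
    payload = {
        "weights": weights,
        "optimizer_state": optimizer_state,
        "trainer_state": trainer_state,
    }
    with open(path, "wb") as f:
        pickle.dump(payload, f)


def load_full_checkpoint(path):
    """Read the payload back, as the main process would after the workers exit."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "full.ckpt")
    # Worker side: persist weights plus optimizer and trainer progress.
    save_full_checkpoint(
        ckpt_path,
        weights=[0.9, 1.9],
        optimizer_state={"step": 10},
        trainer_state={"current_epoch": 1, "global_step": 10},
    )
    # Main-process side: everything is available again, not just the weights.
    restored = load_full_checkpoint(ckpt_path)
```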

cc @tchaton @justusschock @awaelchli @carmocca @JackCaoG @Liyang90 @gkroiz
