
Incorrect batch progress saved in checkpoint at every_n_train_steps #18060

@shuaitang5

Description


Bug description

When a checkpoint is saved with every_n_train_steps=3, the save happens inside the on_train_batch_end hook of the ModelCheckpoint class. During that save, the state dict of the fit loop, including its batch progress, is snapshotted and written out. But batch_progress.completed is only incremented after on_train_batch_end returns, i.e. after the checkpoint is saved, so the saved checkpoint contains an incorrect batch_progress that looks like this:

# in checkpoint file checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
{total: {ready: 3, completed: 2, started: 3, processed: 3}}

The expected value is {total: {ready: 3, completed: 3, started: 3, processed: 3}}, which is what a checkpoint saved after validation contains.

As a result, when we resume from a batch-end checkpoint, the starting batch_idx is 2 while the global step seen in the training_step function of the LightningModule is 3 (they should match), and all subsequently saved checkpoints have incorrect step values in their file names. This doesn't seem like expected behavior; am I missing something?
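To make the ordering concrete, here is a minimal, self-contained sketch of the counter updates (plain Python standing in for Lightning's loop internals; Progress and on_train_batch_end here are simplified stand-ins, not the real classes). The snapshot taken inside the hook is one behind on completed:

```python
from dataclasses import dataclass

@dataclass
class Progress:
    # simplified stand-in for Lightning's BatchProgress.total
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

snapshots = []

def on_train_batch_end(progress):
    # a checkpoint saved here snapshots the current progress state
    snapshots.append(vars(progress).copy())

progress = Progress()
for _ in range(3):
    progress.ready += 1
    progress.started += 1
    # training_step runs here
    progress.processed += 1
    on_train_batch_end(progress)  # every_n_train_steps checkpoint fires here
    progress.completed += 1       # incremented only AFTER the hook returns

print(snapshots[-1])
# {'ready': 3, 'started': 3, 'processed': 3, 'completed': 2}
```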

I'm currently using a hack in the on_train_batch_end override function like this to overcome this issue:

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        # hack: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py#L233-L237
        # At the time this hook is called, the `completed` counter in batch progress has not been incremented yet.
        # If a checkpoint is saved here, it records an incorrect `completed` value in the batch progress.
        # Resuming from that checkpoint leaves batch_idx one step behind the global step seen in the
        # LightningModule's training_step.
        trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        # revert the change to the `completed` counters
        trainer.fit_loop.epoch_loop.batch_progress.total.completed -= 1
        trainer.fit_loop.epoch_loop.batch_progress.current.completed -= 1
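The effect of the hack can be simulated without Lightning. The idea is bump-snapshot-revert: pre-increment `completed` so the snapshot matches reality, then restore the counter so the live loop state is untouched (Progress here is a simplified stand-in, not the real class):

```python
from dataclasses import dataclass

@dataclass
class Progress:
    # simplified stand-in for Lightning's BatchProgress.total
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

def snapshot(p):
    return vars(p).copy()

# state as seen inside on_train_batch_end at step 3
p = Progress(ready=3, started=3, processed=3, completed=2)

# workaround: bump `completed` so the snapshot is correct, then revert
p.completed += 1
saved = snapshot(p)
p.completed -= 1

print(saved)        # {'ready': 3, 'started': 3, 'processed': 3, 'completed': 3}
print(p.completed)  # 2, live loop state untouched
```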

What version are you seeing the problem on?

v1.9, master

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @carmocca @justusschock

Metadata


Assignees

No one assigned

    Labels

    bug: Something isn't working
    help wanted: Open to be worked on
    loops: Related to the Loop API
    repro needed: The issue is missing a reproducible example
    ver: 1.9.x
    ver: 2.1.x
