Bug description
When saving a checkpoint with every_n_train_steps=3, the save happens inside the on_train_batch_end hook of the ModelCheckpoint class. During that save, the fit loop's state dict is snapshotted and written out, including its batch progress. However, the completed counter in batch_progress is only incremented after on_train_batch_end returns, i.e. after the checkpoint has already been saved, so the saved checkpoint contains an incorrect batch_progress that looks like this:
# in checkpoint file: checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
{'total': {'ready': 3, 'completed': 2, 'started': 3, 'processed': 3}}

The expected value is {'total': {'ready': 3, 'completed': 3, 'started': 3, 'processed': 3}}, which is what a checkpoint saved after validation contains.
This causes a problem when resuming from a checkpoint saved at batch end: the starting batch_idx is 2 while global_step is 3 inside the LightningModule's training_step (they should match), and every checkpoint saved afterwards has an incorrect step value in its file name. This doesn't seem like expected behavior; am I missing something?
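For reference, here is a minimal, hypothetical reproduction sketch (the LightningModule and dataloader are placeholders, and the checkpoint path lookup assumes the default .ckpt naming) that inspects the batch progress stored in the checkpoint written at step 3:

```python
import glob

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Placeholders: substitute your own LightningModule and train dataloader.
model = MyLightningModule()
train_loader = make_train_dataloader()

ckpt_cb = ModelCheckpoint(dirpath="checkpoints", every_n_train_steps=3, save_top_k=-1)
trainer = Trainer(max_steps=4, limit_val_batches=0, callbacks=[ckpt_cb])
trainer.fit(model, train_loader)

# Inspect the checkpoint that was written from on_train_batch_end at step 3.
ckpt_path = sorted(glob.glob("checkpoints/*.ckpt"))[0]
progress = torch.load(ckpt_path, map_location="cpu")["loops"]["fit_loop"]["epoch_loop.batch_progress"]
print(progress["total"])
# observed: {'ready': 3, 'completed': 2, 'started': 3, 'processed': 3}
# expected: {'ready': 3, 'completed': 3, 'started': 3, 'processed': 3}
```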
I'm currently working around the issue with a hack in an on_train_batch_end override:
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
    # hack: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py#L233-L237
    # At the time this hook is called, the `completed` value in batch progress has not been incremented yet.
    # If a checkpoint is saved here, it will contain an incorrect `completed` value in batch progress.
    # Resuming from such a checkpoint makes batch_idx lag one step behind global_step in the LightningModule's training_step.
    trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
    super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
    # revert the change to the `completed` value in batch progress
    trainer.fit_loop.epoch_loop.batch_progress.total.completed -= 1
    trainer.fit_loop.epoch_loop.batch_progress.current.completed -= 1
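For completeness, a hedged sketch of how this override can be wired into training; the subclass name PatchedModelCheckpoint is purely illustrative, and the model/dataloader are the same placeholders as in the sketch above:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class PatchedModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint carrying the on_train_batch_end hack shown above."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        trainer.fit_loop.epoch_loop.batch_progress.total.completed -= 1
        trainer.fit_loop.epoch_loop.batch_progress.current.completed -= 1


trainer = Trainer(
    max_steps=4,
    limit_val_batches=0,
    callbacks=[PatchedModelCheckpoint(dirpath="checkpoints", every_n_train_steps=3, save_top_k=-1)],
)
trainer.fit(model, train_loader)  # placeholders from the sketch above
```

With this in place, the checkpoint saved at step 3 stores completed == 3, so resuming no longer leaves batch_idx one step behind global_step.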
What version are you seeing the problem on?
v1.9, master
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response