Potential off by 1 error when resuming training of mid-epoch checkpoint #19367
Labels
bug, help wanted, loops, ver: 2.1.x
Bug description
During the fit loop, here's a simple log of the `global_step` and `batch_idx` values during `on_train_batch_start` and `on_train_batch_end`. Notice that `global_step` and `batch_idx` are equal during batch start, and that `global_step` is 1 greater than `batch_idx` during batch end.
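The exact logging code isn't shown here; a minimal sketch of a callback that would produce such a log (the `StepLogger` name is hypothetical) is:

```python
import lightning.pytorch as pl


class StepLogger(pl.Callback):
    """Hypothetical callback: prints global_step next to batch_idx."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"batch_start: global_step={trainer.global_step} batch_idx={batch_idx}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print(f"batch_end:   global_step={trainer.global_step} batch_idx={batch_idx}")
```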
Now, if I save a mid-epoch checkpoint after 5 training steps and resume training, the two values are off by 1 during batch start and off by 2 during batch end. This seems to be an issue because it changes when validation and checkpointing run. In both runs, I have `Trainer(val_check_interval=5, ...)` and `ModelCheckpoint(every_n_train_steps=5, ...)`. In the original run, validation happens after 5 and 10 training steps, as expected. In the resumed run, validation only happens once, after 6 training steps (a reproduction sketch is given under "How to reproduce the bug" below).
My initial guess is that this is happening because `self.batch_progress.increment_completed()` in `src/lightning/pytorch/loops/training_epoch_loop.py` is called after the `on_train_batch_end` hooks (where `ModelCheckpoint` saves), so the checkpoint thinks we've only completed `global_step - 1` training steps.
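For illustration, here is a simplified paraphrase of that ordering as hypothetical pseudocode, not the actual loop source; `call_hooks` and `run_optimizer_step` are stand-in names:

```python
# Simplified, hypothetical sketch of the training epoch loop's advance();
# the real code is in src/lightning/pytorch/loops/training_epoch_loop.py.
def advance(self, batch, batch_idx):
    self.batch_progress.increment_started()
    call_hooks("on_train_batch_start", batch, batch_idx)

    run_optimizer_step(batch)                  # global_step is incremented here
    self.batch_progress.increment_processed()

    # ModelCheckpoint(every_n_train_steps=...) saves from this hook, so the
    # saved state still shows the current batch as not completed: one fewer
    # completed batch than global_step implies.
    call_hooks("on_train_batch_end", batch, batch_idx)

    self.batch_progress.increment_completed()  # too late to reach the checkpoint
```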
What version are you seeing the problem on?
v2.1
How to reproduce the bug
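A minimal, self-contained sketch of the setup described above. The model, data, and checkpoint filename are assumptions; the filename follows `ModelCheckpoint`'s default `{epoch}-{step}` naming, and `StepLogger` is the logging callback sketched earlier:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint


class ToyModel(pl.LightningModule):
    """Placeholder model; any LightningModule should reproduce the behavior."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).pow(2).mean()

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self.layer(batch[0]).pow(2).mean())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)


def make_trainer():
    ckpt_cb = ModelCheckpoint(dirpath="ckpts", every_n_train_steps=5, save_top_k=-1)
    return pl.Trainer(max_steps=10, val_check_interval=5, limit_val_batches=1,
                      callbacks=[ckpt_cb, StepLogger()])


# Original run: validation and checkpointing fire after steps 5 and 10.
make_trainer().fit(ToyModel(), loader(), loader())

# Resumed run from the mid-epoch step-5 checkpoint: validation now fires
# only once, after 6 training steps.
make_trainer().fit(ToyModel(), loader(), loader(),
                   ckpt_path="ckpts/epoch=0-step=5.ckpt")
```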
Error messages and logs
Environment
More info
No response
cc @carmocca @justusschock