Step at which validation happens drifts for `val_check_interval` when gradient accumulation is turned on #17207
Comments
I think it is actually the moment when validation happens that drifts. The checkpoint saving is just a side effect.
Validation checks track training batches instead of training steps. According to the documentation, an integer `val_check_interval` runs validation after a fixed number of training batches. However, the number of training batches does not always equal the number of training steps (global steps): a training step corresponds to one optimizer update, so with gradient accumulation a single step spans several batches. I think it would make sense to validate after N training steps instead of N training batches. Other modules such as the Logger and ModelCheckpoint already use global steps to track training progress. I propose changing the condition to:

    elif self.trainer.val_check_batch != float("inf"):
        # if `check_val_every_n_epoch` is None, run a validation loop every n training steps
        # else condition it based on the batch_idx of the current epoch
        next_iteration = self.global_step if self.trainer.check_val_every_n_epoch is None else self.batch_idx + 1
        is_val_check_batch = next_iteration % self.trainer.val_check_batch == 0
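If a change along those lines landed, a configuration like the one below (a hypothetical sketch, not current Lightning behavior) would request validation every 5 optimizer steps regardless of the accumulation factor:

    # Hypothetical: under the proposed step-based check, this would validate every
    # 5 optimizer steps (global steps) rather than every 5 training batches.
    trainer = Trainer(
        accumulate_grad_batches=3,
        check_val_every_n_epoch=None,
        val_check_interval=5,
    )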
Is there a plan to add step-based validation checks in Lightning? Until Lightning adds official support for this, are there any recommendations on how to override the default behavior and use the number of steps instead of batches to trigger validation?
Right now, as a workaround, I have to discard the last batch so that the number of batches per epoch is a multiple of the gradient accumulation factor.
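One way to express that workaround in the Trainer configuration (a sketch using the numbers from this issue, assuming 67 training batches and an accumulation factor of 3) is to cap the batches per epoch at a multiple of the accumulation factor so that batch counting and step counting stay aligned:

    # Sketch of the workaround: limit the epoch to 66 batches (a multiple of
    # accumulate_grad_batches=3), discarding the leftover 67th batch, and express
    # the validation interval in batches: 5 optimizer steps * 3 batches/step = 15.
    trainer = Trainer(
        accumulate_grad_batches=3,
        limit_train_batches=66,
        val_check_interval=15,
    )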
Bug description
First of all, my task relies on step count rather than epochs, so I run validation checks by steps and save checkpoints afterwards. However, once I turned gradient accumulation on and the batch count was not divisible by the accumulation factor, I encountered weird drifts in the actual step at which validation is performed, and therefore in the checkpointing.
In the example below, I override the `_save_checkpoint` function to monitor the actual file name, and it turns out to be drifting. My general setting is `val_check_interval=accumulation*5` to make it validate every 5 effective optimizer steps, with `accumulation=3` and `#batches=67`, so there is one batch left over.

How to reproduce the bug
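The original reproduction script is not included above; the following is a minimal sketch of a setup that should exhibit the drift, assuming 67 single-sample training batches per epoch, `accumulation=3`, and `val_check_interval=15`. The `RandomDataset`, `BoringModel`, and `ReportValidationStep` names are illustrative, not taken from the original report, and the callback simply prints the global step each time validation starts instead of overriding `_save_checkpoint`:

    import torch
    from torch.utils.data import DataLoader, Dataset
    import pytorch_lightning as pl


    class RandomDataset(Dataset):
        # 67 samples with batch_size=1 -> 67 batches per epoch,
        # i.e. one leftover batch when accumulating over 3 batches
        def __len__(self):
            return 67

        def __getitem__(self, idx):
            return torch.randn(32)


    class BoringModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def validation_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    class ReportValidationStep(pl.Callback):
        def on_validation_start(self, trainer, pl_module):
            # If the leftover batch still triggers an optimizer step, each epoch
            # contains 23 steps, so from the second epoch on the step printed here
            # is expected to stop being a multiple of 5 (the drift in question).
            if not trainer.sanity_checking:
                print(f"validation at global_step={trainer.global_step}")


    if __name__ == "__main__":
        accumulation = 3
        trainer = pl.Trainer(
            max_epochs=3,
            accumulate_grad_batches=accumulation,
            val_check_interval=accumulation * 5,  # intended: every 5 effective optimizer steps
            callbacks=[ReportValidationStep()],
            logger=False,
            enable_checkpointing=False,
        )
        trainer.fit(
            BoringModel(),
            train_dataloaders=DataLoader(RandomDataset(), batch_size=1),
            val_dataloaders=DataLoader(RandomDataset(), batch_size=1),
        )

With batch-based counting, validation in the first epoch lands on global steps 5, 10, 15, 20; after the epoch ends at step 23, the next checks land on 28, 33, and so on, which matches the drift described above (again, assuming the leftover batch triggers its own optimizer step).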
Error messages and logs
Environment
Current environment
More info
Other than this phenomenon, I have two more questions: why is `val_check_interval` tied to the number of batches rather than `global_step`?

cc @carmocca @justusschock