Describe the bug
Gradient checkpointing is gated on self.training, so the log_validations run during training also go through this codepath unnecessarily. I found that I have to explicitly disable gradient checkpointing for validation even though it runs under torch.no_grad(); I looked and saw that the official examples do not do this either. Disabling it gives a substantial performance boost, making validation behave more like a normal inference script running outside of a training loop.
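For illustration, here is a minimal sketch of the pattern described above (the Block class, its ff layer, and the gradient_checkpointing attribute are assumptions for the example, not the actual library code): the checkpointed branch is gated only on self.training, so a validation pass run under torch.no_grad() while the module is still in train mode takes it; gating additionally on torch.is_grad_enabled(), or putting the model in eval mode around validation, avoids that.

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    """Illustrative module, not the actual library code."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = torch.nn.Linear(dim, dim)
        self.gradient_checkpointing = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated only on self.training: validation under torch.no_grad()
        # still enters this branch if the module was left in train mode.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self.ff, x, use_reentrant=False)
        return self.ff(x)


class GuardedBlock(Block):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Also checking torch.is_grad_enabled() skips checkpointing when no
        # gradients are needed, e.g. inside a torch.no_grad() validation run.
        if (
            self.training
            and self.gradient_checkpointing
            and torch.is_grad_enabled()
        ):
            return checkpoint(self.ff, x, use_reentrant=False)
        return self.ff(x)
```

The workaround mentioned above amounts to the same thing done from the training script: turning the checkpointing flag off (or switching the model to eval mode) before the validation pass and restoring it afterwards.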
Reproduction
Add print statements to the checkpointing function.
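A hedged sketch of what that reproduction could look like (module and names are illustrative, not the actual library code); the print fires during a no-grad validation pass because the module is still in train mode:

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    """Illustrative stand-in for a checkpointed block, not the library code."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = torch.nn.Linear(dim, dim)
        self.gradient_checkpointing = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and self.gradient_checkpointing:
            # Print statement added for the reproduction.
            print("checkpoint branch hit, grad enabled:", torch.is_grad_enabled())
            return checkpoint(self.ff, x, use_reentrant=False)
        return self.ff(x)


block = Block()
block.train()  # the training loop leaves the module in train mode

with torch.no_grad():
    # Prints "checkpoint branch hit, grad enabled: False": the checkpointed
    # path is taken during validation even though no gradients are needed.
    _ = block(torch.randn(2, 64))
```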
Logs
No response