### Describe the bug

Gradient checkpointing is gated on `self.training`, but the validation logging run during training also hits this code path unnecessarily, because the model is still in training mode. I found that I have to disable gradient checkpointing explicitly even when validation runs under `torch.no_grad()`, and the official examples do not do this either. Disabling it gives a substantial performance boost, much closer to a normal inference script run outside a training loop.

### Reproduction

Add print statements to the checkpointing function and observe that they fire during validation.

### Logs

_No response_

### System Info

-

### Who can help?

@linoytsaban @yiyixuxu
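
Below is a minimal, self-contained sketch (not the actual diffusers model code) illustrating the gating pattern described above. The `ToyBlock` module and its names are hypothetical; it only mirrors the `if self.training and self.gradient_checkpointing:` check to show why validation wrapped in `torch.no_grad()` alone still enters the checkpointing branch, and how switching to eval mode (or disabling checkpointing) around validation avoids it.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical block that gates checkpointing on self.training."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            # Print statement added as in the reproduction above.
            print("checkpointing branch taken")
            return checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)


model = ToyBlock()
model.gradient_checkpointing = True
model.train()  # the training loop leaves the model in train mode

x = torch.randn(2, 64)

# Validation as typically written inside a training loop: only torch.no_grad()
# is used, so self.training is still True and the checkpoint branch fires.
with torch.no_grad():
    model(x)  # prints "checkpointing branch taken"

# Workaround: switch to eval mode (or disable checkpointing) for validation.
model.eval()
with torch.no_grad():
    model(x)  # checkpoint branch is skipped
model.train()
```

In a real training script the equivalent workaround would presumably be to put the trained model in eval mode (or turn gradient checkpointing off) before the validation call and restore the previous state afterwards.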