
gradient checkpointing runs during validations in the training examples #10107

Closed
@bghira

Description

Describe the bug

Gradient checkpointing is gated on self.training, but the log_validation calls in the training examples also go through this codepath unnecessarily.

I found that I have to disable it explicitly, since running under torch.no_grad() does not change self.training; I checked, and the official examples do not do this either.

Disabling it during validation gives a substantial performance boost, bringing validation speed much closer to that of a normal inference script run outside a training loop.
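For reference, a minimal self-contained sketch of the pattern (the Block class and names below are illustrative, not the actual diffusers modules): the checkpointing branch is gated on self.training, so it still runs inside torch.no_grad() unless the model is switched to eval mode (or checkpointing is disabled) around validation.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

class Block(nn.Module):
    """Toy module mirroring the gate used in the diffusers blocks."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.gradient_checkpointing = True

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            # Taken during validation too if the model was left in train mode.
            print("checkpointed forward")
            return torch.utils.checkpoint.checkpoint(self.linear, x, use_reentrant=False)
        print("plain forward")
        return self.linear(x)

block = Block()
x = torch.randn(2, 8)

block.train()
with torch.no_grad():   # validation as written in the training examples
    block(x)            # -> "checkpointed forward" (unnecessary overhead)

block.eval()
with torch.no_grad():   # validation after switching to eval mode
    block(x)            # -> "plain forward"
block.train()           # restore training mode before the next optimization step
```

In the training scripts the equivalent fix would be to call model.eval() before log_validation and model.train() afterwards, or to toggle disable_gradient_checkpointing() / enable_gradient_checkpointing() around it.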

Reproduction

Add print statements to the checkpointing function.
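One way to do this without editing library code, assuming the model looks up torch.utils.checkpoint.checkpoint on the module at call time (as the diffusers blocks generally do), is to wrap the function and log each invocation; log_validation here refers to the helper in the training script:

```python
import torch.utils.checkpoint as checkpoint_module

_original_checkpoint = checkpoint_module.checkpoint

def _logged_checkpoint(*args, **kwargs):
    # Log every checkpointed forward so it is visible when validation runs.
    print("gradient checkpointing invoked")
    return _original_checkpoint(*args, **kwargs)

checkpoint_module.checkpoint = _logged_checkpoint

# Apply the patch, then run log_validation from the training script: the
# message still prints, showing the checkpointing path is taken during
# validation even under torch.no_grad().
```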

Logs

No response

System Info

Who can help?

@linoytsaban @yiyixuxu


Labels: bug (Something isn't working), stale (Issues that haven't received updates)
