
[scale] Implement activation checkpointing for transformers. #3864

Merged: 5 commits merged from checkpoint into master on Jul 28, 2021

Conversation

stephenroller (Contributor) commented Jul 26, 2021

Patch description
Leverage FairScale to add activation checkpointing. Activation checkpointing works by throwing away the activations during the forward pass and recomputing them just-in-time during the backward pass, trading extra compute for memory savings.

It is particularly powerful when combined with FSDP. Enable it with --checkpoint-activations true.
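
For reference, a minimal sketch of how FairScale's checkpoint_wrapper is typically applied to a layer stack. This is not the exact ParlAI wiring; nn.TransformerEncoderLayer stands in for ParlAI's own transformer layer, and the helper name is hypothetical.

```python
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper


def build_layers(n_layers: int, checkpoint_activations: bool) -> nn.ModuleList:
    """Build a transformer layer stack, optionally wrapping each layer.

    When checkpoint_activations is True, each layer discards its activations
    on the forward pass and recomputes them during the backward pass.
    """
    layers = []
    for _ in range(n_layers):
        # Stand-in for ParlAI's own transformer layer class.
        layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        if checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        layers.append(layer)
    return nn.ModuleList(layers)
```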

Testing steps
New CI test that just ensures the path doesn't crash; since the change is a thin wrapper around FairScale, I rely on FairScale's own tests for correctness. A rough sketch of such a test is below.
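
This is only a sketch of what a smoke test could look like, assuming ParlAI's testing helper parlai.utils.testing.train_model and a checkpoint_activations opt key corresponding to the CLI flag; the actual test added in this PR may differ.

```python
import parlai.utils.testing as testing_utils


def test_checkpoint_activations_does_not_crash():
    # Train a tiny transformer for a handful of steps with activation
    # checkpointing on; the test only asserts that training completes.
    valid, test = testing_utils.train_model(
        dict(
            task='integration_tests',
            model='transformer/generator',
            checkpoint_activations=True,
            max_train_steps=8,
            batchsize=2,
        )
    )
    assert valid is not None
```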

Manual testing: parlai train -m bart -t taskmaster2 -bs 4 -dynb full -tstep 100 --checkpoint-activations true. I took some very non-scientific measurements, but they were exactly as expected:

| Checkpoint | Memory (nvidia-smi) | Updates/sec |
| --- | --- | --- |
| No (baseline) | 10823 MB | 1.01 u/s |
| Yes | 7853 MB | 0.77 u/s |

meganung (Contributor) left a comment


lgtm! a clean wrapper add!

stephenroller merged commit fa49002 into master on Jul 28, 2021
stephenroller deleted the checkpoint branch on July 28, 2021 19:35