-
Notifications
You must be signed in to change notification settings - Fork 3k
[JAX][DOC] Add optimizer state offloading doc #28988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
zhenying-liu
commented
May 23, 2025
- Add the optimizer state offloading with a code example.
- Add the memory usage comparison for activation/parameter/optimizer state offloading with their baseline implementations. The memory stats were collected on a GPU.
docs/notebooks/host-offloading.md
Outdated
|
||
By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement. | ||
By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement. However, offloading may degrade performance due to memory transfers between host and device, so it's important to consider this trade-off when designing your optimization strategy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't write this statement. With good overlap, you might not see any degradation right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We did see a lot of degradation on GPUs, especially parameter and optimizer state offloading. Activation offloading is optimized recently, but its performance is still worse than no offloading. So we want to tell the user about this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of completely removing this sentence, can we let the user be aware of the performance concern? @yashk2810 @jreiffers
"Note that offloading performance may vary significantly across device types."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mention this below in the optimizer offloading section which is fine. No need to mention it again at the top. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Removed here at the top and mentioned in the "Limitations of Parameter Offloading" session.
docs/notebooks/host-offloading.md
Outdated
|
||
### Basic Implementation | ||
|
||
In this section, you will implement a simple model with the Adam optimizer. This implementation will help you understand the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"In this section, let's implement a simple model with the Adam optimizer". In general, prefer not using "you"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Not use "you" in this document.
712601b
to
d7dc38e
Compare
docs/notebooks/host-offloading.md
Outdated
This implementation demonstrates how to: | ||
1. Set up sharding specifications for `device` and `pinned_host` | ||
2. Move optimizer states between host and device memory via {func}`jax.device_put` | ||
3. Use `in_sharding` and `out_shardings` to ensure proper memory placement |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in_sharding shouldn't be used. Can you remove that please? The arrays should be on the correct memory kind and sharding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed all the in_shardings. Verified in the colab that the all the code still running expectedly.
So in_shardings is indeed redundant.