Skip to content

[JAX][DOC] Add optimizer state offloading doc #28988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

zhenying-liu
Copy link
Contributor

  1. Add the optimizer state offloading with a code example.
  2. Add the memory usage comparison for activation/parameter/optimizer state offloading with their baseline implementations. The memory stats were collected on a GPU.

@zhenying-liu
Copy link
Contributor Author


By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement.
By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement. However, offloading may degrade performance due to memory transfers between host and device, so it's important to consider this trade-off when designing your optimization strategy.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't write this statement. With good overlap, you might not see any degradation right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We did see a lot of degradation on GPUs, especially parameter and optimizer state offloading. Activation offloading is optimized recently, but its performance is still worse than no offloading. So we want to tell the user about this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of completely removing this sentence, can we let the user be aware of the performance concern? @yashk2810 @jreiffers
"Note that offloading performance may vary significantly across device types."

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mention this below in the optimizer offloading section which is fine. No need to mention it again at the top. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Removed here at the top and mentioned in the "Limitations of Parameter Offloading" session.


### Basic Implementation

In this section, you will implement a simple model with the Adam optimizer. This implementation will help you understand the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"In this section, let's implement a simple model with the Adam optimizer". In general, prefer not using "you"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Not use "you" in this document.

@zhenying-liu zhenying-liu force-pushed the opt_offload branch 6 times, most recently from 712601b to d7dc38e Compare May 28, 2025 00:00
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 28, 2025
This implementation demonstrates how to:
1. Set up sharding specifications for `device` and `pinned_host`
2. Move optimizer states between host and device memory via {func}`jax.device_put`
3. Use `in_sharding` and `out_shardings` to ensure proper memory placement
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in_sharding shouldn't be used. Can you remove that please? The arrays should be on the correct memory kind and sharding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed all the in_shardings. Verified in the colab that the all the code still running expectedly.
So in_shardings is indeed redundant.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants