[train][jax] Enable Jax trainer on GPU #58322
Jax GPU image build on the Anyscale platform: https://gist.github.com/liulehui/bda2419e1b3245d40d8027053a8dd26c
Signed-off-by: Lehui Liu <lehui@anyscale.com>
use_tpu: bool = False
use_gpu: bool = False
Nit: this JaxConfig has a few redundant params with the ScalingConfig, and both of these are passed through from scaling config.
Plus, this is a public API that users can modify, so you could end up with ScalingConfig != JaxConfig which is a bit confusing.
Let's discuss and address this in a followup PR. Ok to merge for now
sg!
1. The Jax dependency was introduced in #58322.
2. The current test environment uses CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax tests when running on Python 3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
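The skip rule in the list above can be sketched as follows. This is a minimal illustration, not the code from the referenced commit: the helper `should_skip_jax_tests`, the constant `JAX_UNSUPPORTED_FROM`, and the test class are hypothetical names, and the standard library's `unittest.skipIf` stands in for whatever skip mechanism the actual change uses.

```python
import sys
import unittest

# Assumption: the pinned jax build is unavailable from this Python version on.
JAX_UNSUPPORTED_FROM = (3, 12)


def should_skip_jax_tests(version_info=sys.version_info) -> bool:
    """Return True when the interpreter is too new for the pinned jax version."""
    return tuple(version_info[:2]) >= JAX_UNSUPPORTED_FROM


@unittest.skipIf(
    should_skip_jax_tests(),
    "pinned jax version does not support Python 3.12+",
)
class JaxSmokeTest(unittest.TestCase):
    """Hypothetical test case; it is skipped entirely on Python 3.12+."""

    def test_jax_importable(self):
        # Imported inside the test so that collecting this module never
        # requires jax to be installed.
        import jax

        self.assertTrue(hasattr(jax, "devices"))
```

Keeping the version check in a plain function makes the rule easy to verify without jax installed.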
The Jax dependency was introduced in #58322. The current test environment uses CUDA 12.1, which limits the jax version to below 0.4.14. jax <= 0.4.14 does not support Python 3.12. Skip the jax tests when running on Python 3.12+. Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
1. This PR adds multihost GPU support for the Ray Train JaxTrainer.
2. Following the Jax [GPU distributed doc](https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if `ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
3. If cuda is the jax platform, set CUDA_VISIBLE_DEVICES and initialize jax distributed with https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: YK <1811651+ykdojo@users.noreply.github.com>
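The GPU setup described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's implementation: `build_jax_gpu_env` and `init_jax_distributed` are hypothetical helper names, and appending "cuda" to any existing `JAX_PLATFORMS` value is a guess at what "add" means here. The only real jax API used is `jax.distributed.initialize` with its documented `coordinator_address`, `num_processes`, and `process_id` arguments.

```python
from typing import Dict, List, Optional


def build_jax_gpu_env(
    use_gpu: bool,
    gpu_ids: List[int],
    base_env: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Return a copy of base_env with the GPU-related variables applied.

    Hypothetical helper: how the PR merges with existing values is not shown.
    """
    env = dict(base_env or {})
    if use_gpu:
        # JAX_PLATFORMS is a comma-separated list; add "cuda" if it is missing.
        platforms = [p for p in env.get("JAX_PLATFORMS", "").split(",") if p]
        if "cuda" not in platforms:
            platforms.append("cuda")
        env["JAX_PLATFORMS"] = ",".join(platforms)
        # Restrict the worker to the GPUs assigned to it.
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def init_jax_distributed(
    coordinator_address: str, num_processes: int, process_id: int
) -> None:
    """Connect this worker to the jax distributed runtime."""
    # Imported lazily so the environment variables above can be applied to
    # os.environ before jax is first imported.
    import jax

    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
```

A worker would typically apply the returned mapping with `os.environ.update(...)` before calling `init_jax_distributed`, since jax reads these variables when it initializes its backends.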
## Description
1. The Jax dependency was introduced in ray-project#58322.
2. The current test environment uses CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax tests when running on Python 3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: jeffery4011 <jefferyshen1015@gmail.com>
Description
If `ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
Related issues
Additional information