Conversation

@liulehui
Contributor

@liulehui liulehui commented Oct 30, 2025

Description

  1. This PR adds multi-host GPU support to the Ray Train JaxTrainer.
  2. Following the JAX GPU distributed doc: if ScalingConfig.use_gpu == True, we add "cuda" to JAX_PLATFORMS.
  3. If cuda is the JAX platform, we set CUDA_VISIBLE_DEVICES and initialize JAX distributed with jax.distributed.initialize: https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize
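
A minimal sketch of the three steps above, not the code merged in this PR. The helper names `build_jax_gpu_env` and `init_jax_distributed` are hypothetical, and how Ray Train actually derives the GPU IDs, coordinator address, and process rank is assumed rather than taken from the diff. The only real library call is `jax.distributed.initialize`, used with its documented `coordinator_address`, `num_processes`, and `process_id` arguments.

```python
from typing import Dict, List


def build_jax_gpu_env(
    use_gpu: bool, gpu_ids: List[int], existing_platforms: str = ""
) -> Dict[str, str]:
    """Return env vars a worker would set before importing JAX (hypothetical helper)."""
    env: Dict[str, str] = {}
    if not use_gpu:
        return env
    # Add "cuda" to JAX_PLATFORMS without duplicating an existing entry.
    platforms = [p for p in existing_platforms.split(",") if p]
    if "cuda" not in platforms:
        platforms.append("cuda")
    env["JAX_PLATFORMS"] = ",".join(platforms)
    # Restrict the worker to the GPUs assigned to it.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def init_jax_distributed(
    coordinator_address: str, num_processes: int, process_id: int
) -> None:
    """Join the multi-host JAX runtime (hypothetical wrapper)."""
    # Imported lazily so the env vars above are in place before JAX picks a backend.
    import jax

    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
```

In this sketch a worker would apply the returned dict with `os.environ.update(...)` and then call `init_jax_distributed` once per process, with every process passing the same coordinator address.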

Related issues

Additional information

  1. Tested with script here: https://gist.github.com/liulehui/b0b25065d48b730f2898b712aa92e06e

@liulehui liulehui added the go add ONLY when ready to merge, run all tests label Oct 30, 2025
@liulehui
Contributor Author

JAX GPU image build on the Anyscale platform: https://gist.github.com/liulehui/bda2419e1b3245d40d8027053a8dd26c

Signed-off-by: Lehui Liu <lehui@anyscale.com>
@liulehui liulehui marked this pull request as ready for review November 22, 2025 01:58
@liulehui liulehui requested review from a team, matthewdeng and richardliaw as code owners November 22, 2025 01:58
@ray-gardener ray-gardener bot added the train Ray Train Related Issue label Nov 22, 2025
Comment on lines 22 to +23
use_tpu: bool = False
use_gpu: bool = False
Contributor

Nit: JaxConfig has a few params that are redundant with ScalingConfig, and both of these are passed through from the scaling config.

Plus, this is a public API that users can modify, so you could end up with ScalingConfig != JaxConfig, which is a bit confusing.

Let's discuss and address this in a follow-up PR. OK to merge for now.
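
To illustrate the concern, here is a hedged sketch of how the two configs could drift apart and how a check might catch it. `ScalingConfigSketch`, `JaxConfigSketch`, and `find_mismatched_flags` are hypothetical stand-ins written for this example; they are not the real Ray Train classes, and no such check is claimed to exist in the PR.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ScalingConfigSketch:
    """Stand-in for the scaling config (assumption: only the two flags matter here)."""
    use_gpu: bool = False
    use_tpu: bool = False


@dataclass
class JaxConfigSketch:
    """Stand-in for the backend config that mirrors the same two flags."""
    use_tpu: bool = False
    use_gpu: bool = False


def find_mismatched_flags(
    scaling: ScalingConfigSketch, jax_config: JaxConfigSketch
) -> List[str]:
    """Return the names of flags whose values differ between the two configs."""
    return [
        name
        for name in ("use_gpu", "use_tpu")
        if getattr(scaling, name) != getattr(jax_config, name)
    ]
```

A user who sets `use_gpu=True` on only one of the two objects would get `["use_gpu"]` back, which is the confusing state described above.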

Contributor Author

sg!

@justinvyu justinvyu enabled auto-merge (squash) November 24, 2025 18:56
@justinvyu justinvyu merged commit b88bcc1 into ray-project:master Nov 24, 2025
7 checks passed
justinvyu pushed a commit that referenced this pull request Nov 26, 2025
1. The JAX dependency was introduced in #58322.
2. The current test environment is for CUDA 12.1, which limits the jax
version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax tests when running against Python 3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
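
The version gate described in this commit message could look roughly like the sketch below. It is an illustration, not the referenced commit's code: `should_skip_jax_tests` and `JaxSmokeTest` are hypothetical names, and the standard-library `unittest.skipIf` stands in for whatever skip mechanism the Ray test suite actually uses.

```python
import sys
import unittest
from typing import Tuple


def should_skip_jax_tests(version: Tuple[int, int]) -> bool:
    """True when the interpreter is Python 3.12 or newer (hypothetical helper)."""
    return version >= (3, 12)


@unittest.skipIf(
    should_skip_jax_tests(sys.version_info[:2]),
    "The pinned jax version does not support Python 3.12+",
)
class JaxSmokeTest(unittest.TestCase):
    def test_placeholder(self) -> None:
        # A real test would import jax and exercise the trainer here.
        self.assertTrue(True)
```

Keeping the version comparison in a small function makes the gate itself testable on any interpreter.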
KaisennHu pushed a commit to KaisennHu/ray that referenced this pull request Nov 26, 2025
aslonnie pushed a commit that referenced this pull request Nov 26, 2025
The JAX dependency was introduced in #58322.
The current test environment is for CUDA 12.1, which limits the jax version
to below 0.4.14.
jax <= 0.4.14 does not support Python 3.12.
Skip the jax tests when running against Python 3.12+.

Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
ykdojo pushed a commit to ykdojo/ray that referenced this pull request Nov 27, 2025
1. This PR adds multi-host GPU support to the Ray Train JaxTrainer.
2. Following the JAX [GPU distributed
doc](https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if
`ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
3. If cuda is the JAX platform, we set CUDA_VISIBLE_DEVICES and initialize
JAX distributed with
https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize

---------

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: YK <1811651+ykdojo@users.noreply.github.com>
SheldonTsen pushed a commit to SheldonTsen/ray that referenced this pull request Dec 1, 2025
SheldonTsen pushed a commit to SheldonTsen/ray that referenced this pull request Dec 1, 2025
SheldonTsen pushed a commit to SheldonTsen/ray that referenced this pull request Dec 1, 2025
matthewdeng pushed a commit that referenced this pull request Jan 13, 2026
## Description
1. The JAX dependency was introduced in #58322.
2. The current test environment is for CUDA 12.1, which limits the jax
version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax tests when running against Python 3.12+.

---------

Signed-off-by: Lehui Liu <lehui@anyscale.com>
rushikeshadhav pushed a commit to rushikeshadhav/ray that referenced this pull request Jan 14, 2026
jeffery4011 pushed a commit to jeffery4011/ray that referenced this pull request Jan 20, 2026
Signed-off-by: jeffery4011 <jefferyshen1015@gmail.com>
Labels

go add ONLY when ready to merge, run all tests train Ray Train Related Issue

3 participants