Merged

32 commits
5d3a943 - test jax trainer (liulehui, Aug 5, 2025)
030e31a - CUDA support (liulehui, Oct 16, 2025)
9681874 - add unit test (liulehui, Oct 29, 2025)
dd877f1 - set JAX_PLATFORMS automatically (liulehui, Oct 16, 2025)
db73a15 - lint (liulehui, Oct 29, 2025)
cf8a942 - try to mock jax distributed (liulehui, Oct 29, 2025)
dcc757d - mock jax distributed (liulehui, Oct 29, 2025)
b651a8a - reset to cpu for vanilla tests (liulehui, Oct 29, 2025)
041d9aa - modify (liulehui, Oct 29, 2025)
8e43fc1 - remove (liulehui, Oct 29, 2025)
5f7eaf3 - experimental gpu jax (liulehui, Oct 30, 2025)
f2d4f48 - fix (liulehui, Oct 30, 2025)
97e4d18 - some fixs (liulehui, Nov 3, 2025)
78ba180 - shutdown gpu jax distributed (liulehui, Nov 21, 2025)
2b92238 - gpu cuda env var (liulehui, Nov 21, 2025)
f728a2b - unit tests (liulehui, Nov 22, 2025)
34c337d - try to install jax[cuda] for gpu unit test (liulehui, Nov 22, 2025)
6baa078 - try to fix cuda jaxlib (liulehui, Nov 22, 2025)
265a860 - try jax 0.4.23 (liulehui, Nov 22, 2025)
81ae1c0 - pin to jax 0.4.23 (liulehui, Nov 22, 2025)
9c19af0 - pin jax to 0.4.23 (liulehui, Nov 22, 2025)
6480128 - try pin to 0.4.20 (liulehui, Nov 22, 2025)
f7443f2 - one more time trying 0.4.20 (liulehui, Nov 22, 2025)
f9f4b66 - fix requirements_compiled.txt (liulehui, Nov 22, 2025)
fa31ac1 - pin to 0.3.27 (liulehui, Nov 22, 2025)
c9dd655 - pin to 0.4.13 (liulehui, Nov 23, 2025)
fe1d722 - remove from train-test-requirements.txt (liulehui, Nov 23, 2025)
e2fbcb7 - remove from train-test-requirements.txt (liulehui, Nov 23, 2025)
190f2c6 - remove duplicate in gpu (liulehui, Nov 23, 2025)
e9420f5 - limit python version (liulehui, Nov 23, 2025)
012e6ba - align with compiled (liulehui, Nov 23, 2025)
ca15115 - fix logging (liulehui, Nov 23, 2025)
16 changes: 16 additions & 0 deletions python/ray/train/v2/BUILD.bazel
@@ -245,6 +245,22 @@ py_test(
    ],
)

py_test(
    name = "test_jax_gpu",
    size = "medium",
    srcs = ["tests/test_jax_gpu.py"],
    env = {"RAY_TRAIN_V2_ENABLED": "1"},
    tags = [
        "exclusive",
        "team:ml",
        "train_v2_gpu",
    ],
    deps = [
        ":conftest",
        "//:ray_lib",
    ],
)

py_test(
    name = "test_lightgbm_trainer",
    size = "small",
39 changes: 34 additions & 5 deletions python/ray/train/v2/jax/config.py
@@ -20,14 +20,20 @@
@dataclass
class JaxConfig(BackendConfig):
    use_tpu: bool = False
    use_gpu: bool = False
Comment on lines 22 to +23

Contributor:

Nit: this JaxConfig has a few redundant params with the ScalingConfig, and both of these are passed through from the scaling config.

Plus, this is a public API that users can modify, so you could end up with ScalingConfig != JaxConfig, which is a bit confusing.

Let's discuss and address this in a follow-up PR. OK to merge for now.

Contributor Author:

sg!

    @property
    def backend_cls(self):
        return _JaxBackend


def _setup_jax_distributed_environment(
    master_addr_with_port: str, num_workers: int, index: int, use_tpu: bool
    master_addr_with_port: str,
    num_workers: int,
    index: int,
    use_tpu: bool,
    use_gpu: bool,
    resources_per_worker: dict,
):
    """Set up distributed Jax training information.

@@ -40,6 +46,9 @@ def _setup_jax_distributed_environment(
        index: Index of this worker.
        use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
            already set, it will be set to "tpu".
        use_gpu: Whether to configure for GPU. If True and JAX_PLATFORMS is not
            already set, it will be set to "cuda".
        resources_per_worker: The resources per worker.
    """
    # Get JAX_PLATFORMS from environment if already set
    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
@@ -48,12 +57,31 @@
        os.environ["JAX_PLATFORMS"] = "tpu"
        jax_platforms = "tpu"

    # TODO(lehui): Add env vars for JAX on GPU.
    if not jax_platforms and use_gpu:
        os.environ["JAX_PLATFORMS"] = "cuda"
        jax_platforms = "cuda"

    if "cuda" in jax_platforms.split(","):
        num_gpus_per_worker = resources_per_worker.get("GPU", 0)
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(i) for i in range(num_gpus_per_worker)
        )

    import jax

    if "tpu" in jax_platforms.split(","):
        jax.distributed.initialize(master_addr_with_port, num_workers, index)
        logger.info("Initialized JAX distributed on TPU.")

    if "cuda" in jax_platforms.split(","):
        if num_gpus_per_worker > 0:
            local_device_ids = list(range(num_gpus_per_worker))
        else:
            local_device_ids = 0
        jax.distributed.initialize(
            master_addr_with_port, num_workers, index, local_device_ids
        )
        logger.info("Initialized JAX distributed on CUDA.")


def _shutdown_jax_distributed():
@@ -72,14 +100,13 @@ def _shutdown_jax_distributed():

class _JaxBackend(Backend):
    def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
        if not backend_config.use_tpu:
        if not backend_config.use_tpu and not backend_config.use_gpu:
            return

        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
        master_addr_with_port = f"{master_addr}:{master_port}"

        # Set up JAX distributed environment on all workers
        # This sets JAX_PLATFORMS env var and initializes JAX distributed
        setup_futures = []
        for i in range(len(worker_group)):
            setup_futures.append(
@@ -90,13 +117,15 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
                    num_workers=len(worker_group),
                    index=i,
                    use_tpu=backend_config.use_tpu,
                    use_gpu=backend_config.use_gpu,
                    resources_per_worker=worker_group.get_resources_per_worker(),
                )
            )
        ray.get(setup_futures)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
        """Cleanup JAX distributed resources when shutting down worker group."""
        if not backend_config.use_tpu:
        if not backend_config.use_tpu and not backend_config.use_gpu:
            return

        # Shutdown JAX distributed on all workers
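Note (not part of the diff): the platform-selection precedence that the new code in _setup_jax_distributed_environment implements can be summarized in isolation. A minimal sketch, assuming a resources_per_worker dict shaped like the one the worker group reports (e.g. {"GPU": 1}); the helper name select_jax_platform is hypothetical:

import os


def select_jax_platform(use_tpu: bool, use_gpu: bool, resources_per_worker: dict) -> str:
    # An explicitly set JAX_PLATFORMS env var always wins.
    platform = os.environ.get("JAX_PLATFORMS", "").lower()
    if not platform and use_tpu:
        os.environ["JAX_PLATFORMS"] = platform = "tpu"
    if not platform and use_gpu:
        os.environ["JAX_PLATFORMS"] = platform = "cuda"
    # On CUDA, restrict visible devices to the GPUs reserved for this worker.
    if "cuda" in platform.split(","):
        num_gpus = int(resources_per_worker.get("GPU", 0))
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(num_gpus))
    return platform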
1 change: 1 addition & 0 deletions python/ray/train/v2/jax/jax_trainer.py
@@ -132,6 +132,7 @@ def __init__(
        if not jax_config:
            jax_config = JaxConfig(
                use_tpu=scaling_config.use_tpu,
                use_gpu=scaling_config.use_gpu,
            )
        super(JaxTrainer, self).__init__(
            train_loop_per_worker=train_loop_per_worker,
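With the one-line passthrough above, enabling the GPU backend from user code only requires the scaling config; JaxTrainer then builds a JaxConfig(use_gpu=True) itself. A minimal usage sketch (worker count and the body of the train loop are illustrative; the new test below exercises the same path with assertions):

from ray.train import ScalingConfig
from ray.train.v2.jax import JaxTrainer


def train_loop_per_worker():
    import jax

    # After the backend has called jax.distributed.initialize(), each
    # worker sees the global device mesh, not just its local GPUs.
    print(jax.process_index(), jax.local_device_count(), jax.device_count())


trainer = JaxTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
result = trainer.fit()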
65 changes: 65 additions & 0 deletions python/ray/train/v2/tests/test_jax_gpu.py
@@ -0,0 +1,65 @@
import sys

import pytest

from ray.train import RunConfig, ScalingConfig
from ray.train.v2._internal.constants import (
    HEALTH_CHECK_INTERVAL_S_ENV_VAR,
    is_v2_enabled,
)
from ray.train.v2.jax import JaxTrainer

assert is_v2_enabled()


@pytest.fixture(autouse=True)
def reduce_health_check_interval(monkeypatch):
    monkeypatch.setenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, "0.2")
    yield


@pytest.mark.skipif(sys.platform == "darwin", reason="JAX GPU not supported on macOS")
def test_jax_distributed_gpu_training(ray_start_4_cpus_2_gpus, tmp_path):
    """Test multi-GPU JAX distributed training.

    This test verifies that JAX distributed initialization works correctly
    across multiple GPU workers and that they can coordinate.
    """

    def train_func():
        import jax

        from ray import train

        # Get JAX distributed info
        devices = jax.devices()
        world_rank = train.get_context().get_world_rank()
        world_size = train.get_context().get_world_size()

        # Verify distributed setup
        assert world_size == 2, f"Expected world size 2, got {world_size}"
        assert world_rank in [0, 1], f"Invalid rank {world_rank}"
        assert len(devices) == 2, f"Expected 2 devices, got {len(devices)}"

        train.report(
            {
                "world_rank": world_rank,
                "world_size": world_size,
                "num_devices": len(devices),
            }
        )

    trainer = JaxTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
        run_config=RunConfig(storage_path=str(tmp_path)),
    )

    result = trainer.fit()
    assert result.error is None


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", "-x", __file__]))
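The test above only checks that every worker sees the global device set. As an illustration of what the distributed initialization enables (not something this PR adds), a train loop could run a cross-worker collective; this sketch assumes one GPU per worker, as in the test fixture:

import jax
import jax.numpy as jnp


def all_reduce_ranks():
    # Each process maps over its local devices; psum with a named axis
    # reduces across all devices on all workers.
    local = jnp.ones((jax.local_device_count(),)) * jax.process_index()
    summed = jax.pmap(lambda x: jax.lax.psum(x, axis_name="i"), axis_name="i")(local)
    # With 2 workers and 1 GPU each, every device ends up holding 0 + 1 = 1.0.
    return summed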
4 changes: 4 additions & 0 deletions python/requirements/ml/dl-cpu-requirements.txt
@@ -27,3 +27,7 @@ torch-spline-conv==1.2.2
torch-geometric==2.5.3

cupy-cuda12x==13.1.0; sys_platform != 'darwin'

# Keep JAX version consistent with dl-gpu-requirements.txt
jax==0.4.13; python_version < '3.12' and sys_platform != 'darwin'
jaxlib==0.4.13; python_version < '3.12' and sys_platform != 'darwin'
4 changes: 4 additions & 0 deletions python/requirements/ml/dl-gpu-requirements.txt
@@ -17,3 +17,7 @@ torch-spline-conv==1.2.2+pt23cu121

cupy-cuda12x==13.1.0; sys_platform != 'darwin'
nixl==0.4.0; sys_platform != 'darwin'

--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Downgrading to JAX 0.4.13 to be compatible with CUDA 12.1
jaxlib==0.4.13+cuda12.cudnn89; python_version < '3.12' and sys_platform != 'darwin'
2 changes: 0 additions & 2 deletions python/requirements/ml/train-test-requirements.txt
@@ -1,6 +1,4 @@
evaluate==0.4.3
mosaicml; python_version < "3.12"
sentencepiece==0.1.96
jax==0.4.25
jaxlib==0.4.25
s3torchconnector==1.4.3
8 changes: 4 additions & 4 deletions python/requirements_compiled.txt
@@ -862,10 +862,10 @@ isoduration==20.11.0
    # via jsonschema
itsdangerous==2.1.2
    # via flask
jax==0.4.25
    # via -r python/requirements/ml/train-test-requirements.txt
jaxlib==0.4.25
    # via -r python/requirements/ml/train-test-requirements.txt
jax==0.4.13 ; python_version < "3.12" and sys_platform != "darwin"
    # via -r python/requirements/ml/dl-cpu-requirements.txt
jaxlib==0.4.13 ; python_version < "3.12" and sys_platform != "darwin"
    # via -r python/requirements/ml/dl-cpu-requirements.txt
jedi==0.19.1
    # via ipython
jinja2==3.1.6