-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[train][jax] Enable Jax trainer on GPU #58322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
5d3a943
test jax trainer
liulehui 030e31a
CUDA support
liulehui 9681874
add unit test
liulehui dd877f1
set JAX_PLATFORMS automatically
liulehui db73a15
lint
liulehui cf8a942
try to mock jax distributed
liulehui dcc757d
mock jax distributed
liulehui b651a8a
reset to cpu for vanilla tests
liulehui 041d9aa
modify
liulehui 8e43fc1
remove
liulehui 5f7eaf3
experimental gpu jax
liulehui f2d4f48
fix
liulehui 97e4d18
some fixs
liulehui 78ba180
shutdown gpu jax distributed
liulehui 2b92238
gpu cuda env var
liulehui f728a2b
unit tests
liulehui 34c337d
try to install jax[cuda] for gpu unit test
liulehui 6baa078
try to fix cuda jaxlib
liulehui 265a860
try jax 0.4.23
liulehui 81ae1c0
pin to jax 0.4.23
liulehui 9c19af0
pin jax to 0.4.23
liulehui 6480128
try pin to 0.4.20
liulehui f7443f2
one more time trying 0.4.20
liulehui f9f4b66
fix requirements_compiled.txt
liulehui fa31ac1
pin to 0.3.27
liulehui c9dd655
pin to 0.4.13
liulehui fe1d722
remove from train-test-requirements.txt
liulehui e2fbcb7
remove from train-test-requirements.txt
liulehui 190f2c6
remove duplicate in gpu
liulehui e9420f5
limit python version
liulehui 012e6ba
align with compiled
liulehui ca15115
fix logging
liulehui File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import sys | ||
|
|
||
| import pytest | ||
|
|
||
| from ray.train import RunConfig, ScalingConfig | ||
| from ray.train.v2._internal.constants import ( | ||
| HEALTH_CHECK_INTERVAL_S_ENV_VAR, | ||
| is_v2_enabled, | ||
| ) | ||
| from ray.train.v2.jax import JaxTrainer | ||
|
|
||
| assert is_v2_enabled() | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def reduce_health_check_interval(monkeypatch): | ||
| monkeypatch.setenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, "0.2") | ||
| yield | ||
|
|
||
|
|
||
| @pytest.mark.skipif(sys.platform == "darwin", reason="JAX GPU not supported on macOS") | ||
| def test_jax_distributed_gpu_training(ray_start_4_cpus_2_gpus, tmp_path): | ||
| """Test multi-GPU JAX distributed training. | ||
|
|
||
| This test verifies that JAX distributed initialization works correctly | ||
| across multiple GPU workers and that they can coordinate. | ||
| """ | ||
|
|
||
| def train_func(): | ||
| import jax | ||
|
|
||
| from ray import train | ||
|
|
||
| # Get JAX distributed info | ||
| devices = jax.devices() | ||
| world_rank = train.get_context().get_world_rank() | ||
| world_size = train.get_context().get_world_size() | ||
|
|
||
| # Verify distributed setup | ||
| assert world_size == 2, f"Expected world size 2, got {world_size}" | ||
| assert world_rank in [0, 1], f"Invalid rank {world_rank}" | ||
| assert len(devices) == 2, f"Expected 2 devices, got {len(devices)}" | ||
|
|
||
| train.report( | ||
| { | ||
| "world_rank": world_rank, | ||
| "world_size": world_size, | ||
| "num_devices": len(devices), | ||
| } | ||
| ) | ||
|
|
||
| trainer = JaxTrainer( | ||
| train_func, | ||
| scaling_config=ScalingConfig(num_workers=2, use_gpu=True), | ||
| run_config=RunConfig(storage_path=str(tmp_path)), | ||
| ) | ||
|
|
||
| result = trainer.fit() | ||
| assert result.error is None | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import sys | ||
|
|
||
| sys.exit(pytest.main(["-v", "-x", __file__])) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,4 @@ | ||
| evaluate==0.4.3 | ||
| mosaicml; python_version < "3.12" | ||
| sentencepiece==0.1.96 | ||
| jax==0.4.25 | ||
| jaxlib==0.4.25 | ||
| s3torchconnector==1.4.3 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: this JaxConfig has a few redundant params with the ScalingConfig, and both of these are passed through from scaling config.
Plus, this is a public API that users can modify, so you could end up with ScalingConfig != JaxConfig which is a bit confusing.
Let's discuss and address this in a followup PR. Ok to merge for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sg!