-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[Train] Add TPU multi-slice support to JaxTrainer #58629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
matthewdeng
merged 49 commits into
ray-project:master
from
ryanaoleary:jax-tpu-multi-slice
Jan 7, 2026
Merged
Changes from all commits
Commits
Show all changes
49 commits
Select commit
Hold shift + click to select a range
e629113
[Train] Add TPU multi-slice support to JaxTrainer
ryanaoleary e66a1e9
Update python/ray/util/tpu.py
ryanaoleary 4b1fdf0
Update python/ray/train/v2/_internal/execution/worker_group/worker_gr…
ryanaoleary 671c2a0
update test code
ryanaoleary 9db7f14
Add import
ryanaoleary 6566c03
Add cleanup to abort
ryanaoleary d4a8f20
Remove nested configs and set env vars
ryanaoleary 8f44feb
Fix merge
ryanaoleary ee70ef8
Fix bugbot comments
ryanaoleary adcd473
Format and add default for num_workers when None
ryanaoleary b579400
Default resources per worker to 1
ryanaoleary 95baf58
Check for accelerator type before calling slice placement group
ryanaoleary 5a8cc49
Specify SlicePlacementGroup is for TPUs
ryanaoleary 3f3a203
Add TODO for PG cleaner
ryanaoleary cfafb00
Add back use_gpu arg
ryanaoleary 0f02ba5
Make num_workers required for TPUs and add some tests
ryanaoleary e462b86
Change to Optional[dict]
ryanaoleary b3e4342
Fix import in docstring
ryanaoleary d95fe7e
remove head_pgs var
ryanaoleary 316bbb1
Move placement group logic to unified helper function
ryanaoleary 36d8888
Bound slice ID calculation
ryanaoleary 25746d4
fix merge
ryanaoleary bb2ead6
Handle edge case pointed out by bugbot
ryanaoleary bcbdecc
Add defensive check for topology and accelerator type on_start
ryanaoleary 471945a
Avoid resource leaks
ryanaoleary 91aba10
Check for negative num_slices
ryanaoleary a636b48
Move SlicePlacementGroup to WorkerGroupState
ryanaoleary d61d924
Delete TPUReservationCallback
ryanaoleary be0be9b
Remove num_slices arg and calculate it from num_workers
ryanaoleary 3a1951f
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary 0b38fb4
Remove num_slices from JaxTrainer
ryanaoleary 69a3a08
Remove double placement group cleanup
ryanaoleary 6d5f1e2
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary 988eed5
Add new TPU util and move num_slices to WorkerGroupContext
ryanaoleary c986fe6
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary b3aa2ad
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary 1756d33
Check before accessing v2 worker_group fields
ryanaoleary db4203f
Fix tests, remove config.py change, and add _validate_tpu_config
ryanaoleary d0e80d3
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary dfcc004
Remove unnecessary test from test_config
ryanaoleary d9ee6bd
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary 50f628f
Add TPU util that we added to utility.rst
ryanaoleary f120306
Make health check less aggressive to reduce CI flakiness
ryanaoleary 3348db3
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary 815ce63
Fix fixture causing CI error
ryanaoleary f3f7b2f
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary 46b5d1f
Add missing fixture
ryanaoleary fabef4e
Trying to fix test startup error due to fixture (only happens in CI)
ryanaoleary b2bf780
Merge branch 'master' into jax-tpu-multi-slice
ryanaoleary File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 0 additions & 45 deletions
45
python/ray/train/v2/_internal/callbacks/tpu_reservation_callback.py
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.