[Train] Add TPU multi-slice support to JaxTrainer #58629
Conversation
Code Review
This pull request introduces multi-slice TPU support for JaxTrainer by refactoring the accelerator configuration and leveraging ray.util.tpu.SlicePlacementGroup. The changes include a new AcceleratorConfig API, which provides a cleaner way to specify GPU and TPU resources, and deprecates older fields like use_gpu and use_tpu. The logic for reserving TPU slices and determining worker configurations is now encapsulated within SlicePlacementGroup, which correctly handles multi-slice reservations and auto-detects num_workers and resources_per_worker. The test suite has been significantly improved to cover various single-host, multi-host, and multi-slice scenarios. Overall, this is a well-structured and comprehensive update that greatly enhances TPU support in Ray Train. I have a couple of suggestions to address a potential runtime error and a syntax issue.
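For context on the auto-detection behavior the review describes, here is a minimal sketch of the arithmetic involved. The concrete chip and host counts below are illustrative assumptions, not values taken from this PR; the real values come from the accelerator type and topology that the TPU config validates.

```python
import math

# Illustrative assumptions: a slice topology with 16 chips, 4 chips per
# host, reserved twice (num_slices=2).
chips_per_slice = 16
chips_per_host = 4
num_slices = 2

# One Ray Train worker per TPU host within each slice.
workers_per_slice = math.ceil(chips_per_slice / chips_per_host)  # 4
num_workers = num_slices * workers_per_slice                     # 8 workers across both slices
resources_per_worker = {"TPU": chips_per_host}                   # 4 TPU chips per worker

print(num_workers, resources_per_worker)  # 8 {'TPU': 4}
```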
liulehui
left a comment
thank you!!
Force-pushed from 9ba6928 to 6ad813f.
liulehui
left a comment
🫶
This pull request has been automatically marked as stale because it has not had recent activity. You can always ask for help on our discussion forum or Ray's public Slack channel. If you'd like to keep this open, just leave any comment, and the stale label will be removed.
Force-pushed from 1e28858 to 5e03ad6.
cc: @matthewdeng @liulehui I've resolved the outstanding comments and this PR is ready for another review. I'm fixing the comments on #59136 now so that it can be merged first, and then this PR will just contain the Ray Train changes.
matthewdeng
left a comment
Thanks! A few minor remaining comments.
dayshah
left a comment
Have a nit, but it's non-blocking; can address in a follow-up if it makes sense @ryanaoleary
        return max(1, math.ceil(num_workers / workers_per_slice))
    except Exception:
        # Fallback to 1 if calculation fails.
        return 1
Why 1 on a failed calculation or invalid inputs? I feel like it makes more sense to raise, but maybe I'm missing something here.
@dayshah I had it default to 1 because in the Ray Train code we call validate_tpu_config and get_tpu_worker_resources on the TPU inputs, which should prevent this from ever raising, since we validate that num_workers and the topology / accelerator type are compatible.
My thought was that if for some reason this call did fail in the controller (which is the only place it's called), it would be better to return a default of 1 slice rather than crash the controller. Since I don't expect this case to happen anyway due to the validation, I can update this util to raise like you suggest; that does make more sense.
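A minimal sketch of what the raise-based version of this helper could look like, per the discussion above. The function name and argument names here are taken from the quoted snippet or are hypothetical; the exact signature in the PR may differ.

```python
import math


def _num_slices_for_workers(num_workers: int, workers_per_slice: int) -> int:
    """Hypothetical reworked helper: fail loudly on invalid inputs instead of
    silently falling back to a single slice."""
    if num_workers <= 0 or workers_per_slice <= 0:
        raise ValueError(
            "Expected positive num_workers and workers_per_slice, got "
            f"num_workers={num_workers}, workers_per_slice={workers_per_slice}."
        )
    return math.ceil(num_workers / workers_per_slice)
```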
Description
This PR adds support in the `JaxTrainer` to schedule across multiple TPU slices using the `ray.util.tpu` public utilities.

To support this, this PR adds new `AcceleratorConfig`s to the V2 scaling config, which consolidate the accelerator-related fields for TPU and GPU. When `TPUAcceleratorConfig` is specified, the `JaxTrainer` utilizes a `SlicePlacementGroup` to atomically reserve `num_slices` TPU slices of the desired topology, auto-detecting the required values for `num_workers` and `resources_per_worker` when unspecified.

TODO: I'll add some manual testing and usage examples in the comments.
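Pending the manual-testing examples mentioned in the TODO, here is a rough sketch of what usage might look like. The import paths, the `accelerator_config` field name, and the `TPUAcceleratorConfig` constructor arguments (`topology`, `num_slices`) are assumptions based on the names used in this description, not a confirmed API.

```python
# Rough usage sketch; import paths and constructor arguments are assumptions
# based on the names in this PR description, not a confirmed API.
from ray.train import ScalingConfig                        # assumed import path
from ray.train.v2.api.config import TPUAcceleratorConfig   # assumed import path
from ray.train.v2.jax import JaxTrainer                    # assumed import path


def train_loop_per_worker(config):
    import jax
    # Each worker initializes JAX and sees the TPU devices on its host.
    print(jax.devices())


trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(
        # Replaces the deprecated `use_gpu` / `use_tpu` fields.
        accelerator_config=TPUAcceleratorConfig(
            topology="4x4",   # assumed field: desired TPU slice topology
            num_slices=2,     # reserve two slices atomically
        ),
        # `num_workers` and `resources_per_worker` are auto-detected from the
        # topology and number of slices when left unspecified.
    ),
)
result = trainer.fit()
```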
Related issues
#55162
Additional information