
Conversation

@ryanaoleary
Contributor

Description

This PR adds support to the JaxTrainer for scheduling across multiple TPU slices using the ray.util.tpu public utilities.

To support this, the PR adds new AcceleratorConfigs to the V2 scaling config, which consolidate the accelerator-related fields for TPU and GPU. When TPUAcceleratorConfig is specified, the JaxTrainer uses a SlicePlacementGroup to atomically reserve num_slices TPU slices of the desired topology, auto-detecting the required values for num_workers and resources_per_worker when they are unspecified.
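
As a rough illustration, here is a minimal usage sketch of what this could look like; the import paths and the accelerator_config / TPUAcceleratorConfig argument names below are assumptions inferred from this description, not the exact merged API.

# Minimal usage sketch -- import paths and argument names are assumptions
# inferred from the PR description, not the exact merged API.
from ray.train.v2.api.config import ScalingConfig               # assumed path
from ray.train.v2.jax import JaxTrainer, TPUAcceleratorConfig   # assumed path


def train_loop_per_worker():
    import jax

    # The trainer initializes jax.distributed across every host in every
    # reserved slice, so this reports the global device count.
    print(jax.device_count())


trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(
        # num_workers and resources_per_worker are omitted; per the description
        # above they are auto-detected from the TPU config below.
        accelerator_config=TPUAcceleratorConfig(  # assumed field/argument names
            num_slices=2,                # reserve two TPU slices atomically
            topology="4x4",              # desired per-slice topology
            accelerator_type="TPU-V6E",  # assumed accelerator type string
        ),
    ),
)
result = trainer.fit()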

TODO: I'll add some manual testing and usage examples in the comments.

Related issues

#55162


@ryanaoleary ryanaoleary requested review from a team as code owners November 14, 2025 10:14

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces multi-slice TPU support for JaxTrainer by refactoring the accelerator configuration and leveraging ray.util.tpu.SlicePlacementGroup. The changes include a new AcceleratorConfig API, which provides a cleaner way to specify GPU and TPU resources, and deprecates older fields like use_gpu and use_tpu. The logic for reserving TPU slices and determining worker configurations is now encapsulated within SlicePlacementGroup, which correctly handles multi-slice reservations and auto-detects num_workers and resources_per_worker. The test suite has been significantly improved to cover various single-host, multi-host, and multi-slice scenarios. Overall, this is a well-structured and comprehensive update that greatly enhances TPU support in Ray Train. I have a couple of suggestions to address a potential runtime error and a syntax issue.
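
For reference, a rough sketch of the reservation step described above; only ray.util.tpu and SlicePlacementGroup are named in this PR, so the module path, constructor arguments, and usage below are assumptions rather than the exact API.

# Rough sketch of reserving TPU slices directly. Only ray.util.tpu and
# SlicePlacementGroup are named in this PR; argument names are assumptions.
import ray
from ray.util import tpu  # assumed module path for the public TPU utilities

ray.init()

# Atomically reserve num_slices slices of the requested topology; the trainer
# then derives num_workers and resources_per_worker from this reservation.
slice_pg = tpu.SlicePlacementGroup(  # assumed constructor signature
    num_slices=2,
    topology="4x4",
    accelerator_type="TPU-V6E",
)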

@ray-gardener ray-gardener bot added the train and community-contribution labels Nov 14, 2025
@liulehui liulehui requested a review from xyuzh November 14, 2025 21:45
@liulehui liulehui left a comment

thank you!!

@liulehui liulehui left a comment

🫶

@github-actions

This pull request has been automatically marked as stale because it has not had
any activity for 14 days. It will be closed in another 14 days if no further activity occurs.
Thank you for your contributions.

You can always ask for help on our discussion forum or Ray's public slack channel.

If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@github-actions github-actions bot added the stale label Dec 18, 2025
@ryanaoleary
Contributor Author

cc: @matthewdeng @liulehui I've resolved the outstanding comments and this PR is ready for another review. I'm fixing the comments on #59136 now so that it can be merged first, and then this PR will just contain the Ray Train changes.

@github-actions github-actions bot added the unstale label and removed the stale label Dec 19, 2025
@ryanaoleary
Contributor Author

@liulehui Thank you for the reviews!! I think the CI failure for test_jax_distributed_shutdown_timeout should be fixed with 1756d33.

@matthewdeng matthewdeng left a comment

Thanks! A few minor remaining comments.

Signed-off-by: ryanaoleary <ryanaoleary@google.com>
@matthewdeng matthewdeng enabled auto-merge (squash) January 6, 2026 02:26
@github-actions github-actions bot added the go label Jan 6, 2026
@github-actions github-actions bot disabled auto-merge January 6, 2026 03:47
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
@ryanaoleary ryanaoleary requested a review from a team as a code owner January 6, 2026 05:14
@matthewdeng matthewdeng enabled auto-merge (squash) January 6, 2026 23:54
@dayshah dayshah left a comment

I have a nit, but it's non-blocking; we can address it in a follow-up if that makes sense @ryanaoleary

    return max(1, math.ceil(num_workers / workers_per_slice))
except Exception:
    # Fallback to 1 if calculation fails.
    return 1
Why 1 on a failed calculation or invalid inputs? I feel like it makes more sense to raise, but maybe I'm missing something here.


@dayshah I had it default to 1 because the Ray Train code calls validate_tpu_config and get_tpu_worker_resources on the TPU inputs, which should prevent this from ever raising since we validate that num_workers and the topology / accelerator type are compatible.

My thought was that if this call did somehow fail in the controller (the only place it's called), it would be better to return a default of 1 slice than to crash the controller. Since the validation should keep this case from ever happening anyway, I can update this util to raise like you suggest; that does make more sense.
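
For concreteness, a sketch of that follow-up: the helper raises on invalid inputs instead of silently falling back to one slice. num_workers and workers_per_slice come from the quoted snippet; the function name here is hypothetical.

# Sketch of the discussed follow-up: raise on invalid inputs rather than
# falling back to a single slice. The function name is hypothetical.
import math


def infer_num_slices(num_workers: int, workers_per_slice: int) -> int:
    if num_workers <= 0 or workers_per_slice <= 0:
        raise ValueError(
            "num_workers and workers_per_slice must be positive, got "
            f"num_workers={num_workers}, workers_per_slice={workers_per_slice}"
        )
    return math.ceil(num_workers / workers_per_slice)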

@matthewdeng matthewdeng merged commit 236d074 into ray-project:master Jan 7, 2026
7 checks passed
AYou0207 pushed a commit to AYou0207/ray that referenced this pull request Jan 13, 2026