[train][docs] update Jax doc to include GPU and multislice TPU support #60593
base: master
Conversation
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Code Review
The pull request successfully updates the documentation and JaxTrainer implementation to include GPU and multislice TPU support. The changes are consistent across the documentation files and the Python code, providing clearer explanations and examples for users. The removal of outdated JAX environment variables and the correction of dataset shard access in the examples are positive improvements.
| For GPU training, `ScalingConfig` is similar to other frameworks. Key fields include:
|
| * :class:`num_workers <ray.train.ScalingConfig>`: The number of distributed training worker processes.
| * :class:`use_gpu <ray.train.ScalingConfig>`: Whether each worker should use a GPU (or CPU).
The phrase "(or CPU)" is redundant here. If use_gpu is True, it means GPU. If False, it implies CPU. Removing it will make the description more concise.
Suggested change:
- * :class:`use_gpu <ray.train.ScalingConfig>`: Whether each worker should use a GPU (or CPU).
+ * :class:`use_gpu <ray.train.ScalingConfig>`: Whether each worker should use a GPU.
| Together, these configurations provide a declarative API for defining your entire distributed JAX
| training environment, allowing Ray Train to handle the complex task of launching and coordinating
| workers across a TPU slice.
| For GPU training, `ScalingConfig` is similar to other frameworks. Key fields include:
I would edit this to not assume that the user knows how to set up other frameworks.
|
| * `use_tpu`: This is a new field added in Ray 2.49.0 to the V2 `ScalingConfig`. This boolean flag explicitly tells Ray Train to initialize the JAX backend for TPU execution.
| * `topology`: This is a new field added in Ray 2.49.0 to the V2 `ScalingConfig`. Topology is a string defining the physical arrangement of the TPU chips (e.g., "4x4"). This is required for multi-host training and ensures Ray places workers correctly across the slice. For a list of supported TPU topologies by generation,
| * :class:`use_tpu <ray.train.ScalingConfig>`: It's a new field added in Ray 2.49.0 to the V2 `ScalingConfig`. This boolean flag tells Ray Train to initialize the JAX backend for TPU execution.
I prefer the previous wording ("This is...")
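To make this thread concrete, here is a minimal sketch of a TPU config using the two V2 fields quoted above, assuming the V2 `ScalingConfig` accepts `use_tpu` and `topology` as described in this hunk. The worker count, the commented-out `JaxTrainer` import path, and `train_fn` are illustrative assumptions, not a verified API reference:

```python
from ray.train import ScalingConfig

# Assumes the Ray Train V2 ScalingConfig accepts the `use_tpu` and `topology`
# fields described in the hunk above (added in Ray 2.49.0).
tpu_scaling_config = ScalingConfig(
    num_workers=4,    # e.g. one Ray Train worker per TPU host in the slice
    use_tpu=True,     # initialize the JAX backend for TPU execution
    topology="4x4",   # physical chip layout of the slice
)

# Hypothetical trainer construction; the import path is an assumption.
# from ray.train.v2.jax import JaxTrainer
# trainer = JaxTrainer(train_fn, scaling_config=tpu_scaling_config)
# result = trainer.fit()
```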
| Together, these configurations provide a declarative API for defining your entire distributed JAX
| training environment, allowing Ray Train to handle the complex task of launching and coordinating
| workers across a TPU slice.
| For GPU training, `ScalingConfig` is similar to other frameworks. Key fields include:
A big thing that's missing here is what the relationship is between workers and resources, i.e. should one worker map to one GPU, node, or something else?
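For what it's worth, the usual Ray Train mapping is one worker process per GPU; a sketch along those lines (the worker count of 4 is just an example):

```python
from ray.train import ScalingConfig

# One Ray Train worker process per GPU: 4 workers x 1 GPU each = a 4-GPU job,
# which Ray may place on one node or spread across several depending on availability.
gpu_scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"GPU": 1},  # explicit here; 1 GPU per worker is also the default when use_gpu=True
)
```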
| # If you want to use GPUs, specify the GPU scaling config like below.
| # gpu_scaling_config = ScalingConfig(
| #     use_gpu=True,
| #     num_workers=4,
| #     resources_per_worker={"GPU": 1},
| # )
nit: Right now `gpu_scaling_config` and `tpu_scaling_config` have different names, so simply uncommenting this wouldn't enable GPU training; you'd also have to update `scaling_config=...`.
Either you can:
- Name these back to just `scaling_config`, or:
- Uncomment this and add logic/a comment on switching between these two to enable GPU training.
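A sketch of the second option (keeping both configs and toggling between them). The `USE_GPU` flag is hypothetical, and the `use_tpu`/`topology` fields are assumed to exist as described in the hunks above:

```python
from ray.train import ScalingConfig

# Illustrative toggle between the GPU and TPU configs discussed above,
# so only `scaling_config` ever gets passed to the trainer.
USE_GPU = False  # flip to True to run the GPU path instead of the TPU path

scaling_config = (
    ScalingConfig(num_workers=4, use_gpu=True, resources_per_worker={"GPU": 1})
    if USE_GPU
    else ScalingConfig(num_workers=4, use_tpu=True, topology="4x4")
)
# trainer = JaxTrainer(train_fn, scaling_config=scaling_config)
```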
| # gpu_scaling_config = ScalingConfig(
| #     num_workers=4,
| #     use_gpu=True,
| #     resources_per_worker={"GPU": 1},
Do we want to show resources_per_worker for this since it matches the default?
Description
We added GPU (#58322) and multislice TPU (#58629) support for JaxTrainer; this PR updates the corresponding docs.
Additional information
make develop && make local