
Conversation

@liulehui
Contributor

Description

We added GPU (#58322) and multislice TPU (#58629) support for JaxTrainer; this PR updates the corresponding docs.

Additional information

  1. Tested with `make develop && make local`

Signed-off-by: Lehui Liu <lehui@anyscale.com>
@liulehui requested a review from a team as a code owner on January 29, 2026 at 23:15
Contributor

@gemini-code-assist bot left a comment


Code Review

The pull request successfully updates the documentation and JaxTrainer implementation to include GPU and multislice TPU support. The changes are consistent across the documentation files and the Python code, providing clearer explanations and examples for users. The removal of outdated JAX environment variables and the correction of dataset shard access in the examples are positive improvements.

For GPU training, `ScalingConfig` is similar to other frameworks. Key fields include:

* :class:`num_workers <ray.train.ScalingConfig>`: The number of distributed training worker processes.
* :class:`use_gpu <ray.train.ScalingConfig>`: Whether each worker should use a GPU (or CPU).
Contributor


medium

The phrase "(or CPU)" is redundant here. If use_gpu is True, it means GPU. If False, it implies CPU. Removing it will make the description more concise.

Suggested change
* :class:`use_gpu <ray.train.ScalingConfig>`: Whether each worker should use a GPU (or CPU).
* :class:`use_gpu <ray.train.ScalingConfig>`: Whether each worker should use a GPU.
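
To make the quoted fields concrete, here is a minimal GPU sketch; the field names come from the docs quoted above, while the worker count is illustrative:

from ray.train import ScalingConfig

# Launch four training worker processes, each using a GPU.
gpu_scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
)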

Together, these configurations provide a declarative API for defining your entire distributed JAX
training environment, allowing Ray Train to handle the complex task of launching and coordinating
workers across a TPU slice.
For GPU training, `ScalingConfig` is similar to other frameworks. Key fields include:
Contributor


I would edit this to not assume that the user knows how to set up other frameworks.


* `use_tpu`: This is a new field added in Ray 2.49.0 to the V2 `ScalingConfig`. This boolean flag explicitly tells Ray Train to initialize the JAX backend for TPU execution.
* `topology`: This is a new field added in Ray 2.49.0 to the V2 `ScalingConfig`. Topology is a string defining the physical arrangement of the TPU chips (e.g., "4x4"). This is required for multi-host training and ensures Ray places workers correctly across the slice. For a list of supported TPU topologies by generation,
* :class:`use_tpu <ray.train.ScalingConfig>`: It's a new field added in Ray 2.49.0 to the V2 `ScalingConfig`. This boolean flag tells Ray Train to initialize the JAX backend for TPU execution.
Contributor


I prefer the previous wording ("This is...")
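
For comparison, a hedged sketch of the TPU side using the two new fields quoted above; the worker count and topology string are illustrative and should be set to match the actual slice:

from ray.train import ScalingConfig

# use_tpu and topology are the new V2 ScalingConfig fields described above.
# num_workers and the "4x4" topology are placeholders; match them to your slice.
tpu_scaling_config = ScalingConfig(
    num_workers=4,
    use_tpu=True,
    topology="4x4",
)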

Together, these configurations provide a declarative API for defining your entire distributed JAX
training environment, allowing Ray Train to handle the complex task of launching and coordinating
workers across a TPU slice.
For GPU training, `ScalingConfig` is similar to other frameworks. Key fields include:
Contributor


A big thing that's missing here is the relationship between workers and resources, i.e., should one worker map to one GPU, one node, or something else?
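
To make the worker-to-resource question concrete: in the GPU example below, num_workers=4 with resources_per_worker={"GPU": 1} schedules four worker processes, each holding exactly one GPU, so one worker maps to one GPU rather than to one node. A sketch under that assumption; the JaxTrainer import path and parameter name are assumed to follow the other Ray Train trainers and should be checked against the docs updated in this PR:

from ray.train import ScalingConfig
# Import path assumed for illustration; check the JaxTrainer docs in this PR.
from ray.train.v2.jax import JaxTrainer

def train_func():
    # Per-worker JAX training loop; each worker sees exactly one GPU.
    ...

trainer = JaxTrainer(
    # Parameter name assumed to match other Ray Train trainers.
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(
        num_workers=4,                     # four worker processes
        use_gpu=True,
        resources_per_worker={"GPU": 1},   # one GPU per worker, not per node
    ),
)
result = trainer.fit()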

Comment on lines +66 to +71
# If you want to use GPUs, specify the GPU scaling config like below.
# gpu_scaling_config = ScalingConfig(
# use_gpu=True,
# num_workers=4,
# resources_per_worker={"GPU": 1},
# )
Contributor


nit: Right now gpu_scaling_config and tpu_scaling_config have different names, so simply uncommenting this wouldn't enable GPU training; you'd also have to update scaling_config=....

You can either:

  1. Rename these back to just scaling_config, or
  2. Uncomment this and add logic/a comment on switching between the two to enable GPU training (see the sketch below).
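
If option 2 is preferred, one possible shape for the switch (the boolean flag is made up for illustration):

from ray.train import ScalingConfig

train_on_gpu = False  # Hypothetical flag; flip to True to train on GPUs instead of TPUs.

tpu_scaling_config = ScalingConfig(num_workers=4, use_tpu=True, topology="4x4")
gpu_scaling_config = ScalingConfig(num_workers=4, use_gpu=True, resources_per_worker={"GPU": 1})

# Pass the selected config to JaxTrainer(scaling_config=...).
scaling_config = gpu_scaling_config if train_on_gpu else tpu_scaling_config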

# gpu_scaling_config = ScalingConfig(
# num_workers=4,
# use_gpu=True,
# resources_per_worker={"GPU": 1},
Contributor


Do we want to show resources_per_worker for this since it matches the default?

@ray-gardener bot added the docs (An issue or change related to documentation) and train (Ray Train Related Issue) labels on Jan 30, 2026