Merged

49 commits
e629113  [Train] Add TPU multi-slice support to JaxTrainer (ryanaoleary, Nov 14, 2025)
e66a1e9  Update python/ray/util/tpu.py (ryanaoleary, Nov 14, 2025)
4b1fdf0  Update python/ray/train/v2/_internal/execution/worker_group/worker_gr… (ryanaoleary, Nov 14, 2025)
671c2a0  update test code (ryanaoleary, Nov 14, 2025)
9db7f14  Add import (ryanaoleary, Nov 14, 2025)
6566c03  Add cleanup to abort (ryanaoleary, Nov 14, 2025)
d4a8f20  Remove nested configs and set env vars (ryanaoleary, Dec 3, 2025)
8f44feb  Fix merge (ryanaoleary, Dec 3, 2025)
ee70ef8  Fix bugbot comments (ryanaoleary, Dec 3, 2025)
adcd473  Format and add default for num_workers when None (ryanaoleary, Dec 3, 2025)
b579400  Default resources per worker to 1 (ryanaoleary, Dec 3, 2025)
95baf58  Check for accelerator type before calling slice placement group (ryanaoleary, Dec 3, 2025)
5a8cc49  Specify SlicePlacementGroup is for TPUs (ryanaoleary, Dec 18, 2025)
3f3a203  Add TODO for PG cleaner (ryanaoleary, Dec 18, 2025)
cfafb00  Add back use_gpu arg (ryanaoleary, Dec 18, 2025)
0f02ba5  Make num_workers required for TPUs and add some tests (ryanaoleary, Dec 18, 2025)
e462b86  Change to Optional[dict] (ryanaoleary, Dec 18, 2025)
b3e4342  Fix import in docstring (ryanaoleary, Dec 18, 2025)
d95fe7e  remove head_pgs var (ryanaoleary, Dec 18, 2025)
316bbb1  Move placement group logic to unified helper function (ryanaoleary, Dec 19, 2025)
36d8888  Bound slice ID calculation (ryanaoleary, Dec 19, 2025)
25746d4  fix merge (ryanaoleary, Dec 19, 2025)
bb2ead6  Handle edge case pointed out by bugbot (ryanaoleary, Dec 19, 2025)
bcbdecc  Add defensive check for topology and accelerator type on_start (ryanaoleary, Dec 19, 2025)
471945a  Avoid resource leaks (ryanaoleary, Dec 20, 2025)
91aba10  Check for negative num_slices (ryanaoleary, Dec 20, 2025)
a636b48  Move SlicePlacementGroup to WorkerGroupState (ryanaoleary, Dec 31, 2025)
d61d924  Delete TPUReservationCallback (ryanaoleary, Dec 31, 2025)
be0be9b  Remove num_slices arg and calculate it from num_workers (ryanaoleary, Dec 31, 2025)
3a1951f  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Dec 31, 2025)
0b38fb4  Remove num_slices from JaxTrainer (ryanaoleary, Dec 31, 2025)
69a3a08  Remove double placement group cleanup (ryanaoleary, Dec 31, 2025)
6d5f1e2  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Dec 31, 2025)
988eed5  Add new TPU util and move num_slices to WorkerGroupContext (ryanaoleary, Jan 3, 2026)
c986fe6  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 3, 2026)
b3aa2ad  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 5, 2026)
1756d33  Check before accessing v2 worker_group fields (ryanaoleary, Jan 5, 2026)
db4203f  Fix tests, remove config.py change, and add _validate_tpu_config (ryanaoleary, Jan 6, 2026)
d0e80d3  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 6, 2026)
dfcc004  Remove unnecessary test from test_config (ryanaoleary, Jan 6, 2026)
d9ee6bd  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 6, 2026)
50f628f  Add TPU util that we added to utility.rst (ryanaoleary, Jan 6, 2026)
f120306  Make health check less aggressive to reduce CI flakiness (ryanaoleary, Jan 6, 2026)
3348db3  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 6, 2026)
815ce63  Fix fixture causing CI error (ryanaoleary, Jan 6, 2026)
f3f7b2f  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 6, 2026)
46b5d1f  Add missing fixture (ryanaoleary, Jan 6, 2026)
fabef4e  Trying to fix test startup error due to fixture (only happens in CI) (ryanaoleary, Jan 6, 2026)
b2bf780  Merge branch 'master' into jax-tpu-multi-slice (ryanaoleary, Jan 6, 2026)
1 change: 1 addition & 0 deletions doc/source/ray-core/api/utility.rst
@@ -16,6 +16,7 @@ Utility
     ray.util.tpu.get_current_pod_name
     ray.util.tpu.get_num_tpu_chips_on_node
     ray.util.tpu.get_tpu_coordinator_env_vars
+    ray.util.tpu.get_tpu_num_slices_for_workers
     ray.util.tpu.get_tpu_version_from_type
     ray.util.tpu.get_tpu_worker_resources
163 changes: 97 additions & 66 deletions python/ray/tests/test_tpu.py
@@ -5,7 +5,6 @@

 import ray
 from ray._private.accelerators import TPUAcceleratorManager, tpu
-from ray.tests.conftest import _ray_start_cluster
 from ray.util.tpu import SlicePlacementGroup

@@ -170,77 +169,77 @@ def ray_start_cpu():


 @pytest.fixture
-def ray_tpu_cluster():
+def ray_tpu_cluster(ray_start_cluster):
     """
     Simulates a Ray cluster with two multi-host TPU v4-16 slices.
     """
     pod_type = "v4-16"
     topology = "2x2x2"

-    with _ray_start_cluster() as cluster:
-        slice_0_env_common = {
-            "TPU_NAME": "test-slice-0",
-            "TPU_ACCELERATOR_TYPE": pod_type,
-            "TPU_TOPOLOGY": topology,
-        }
-        slice_0_head_labels = {
-            "ray.io/tpu-slice-name": "test-slice-0",
-            "ray.io/tpu-worker-id": "0",
-            "ray.io/tpu-pod-type": pod_type,
-            "ray.io/tpu-topology": topology,
-        }
-        slice_0_worker_labels = {
-            "ray.io/tpu-slice-name": "test-slice-0",
-            "ray.io/tpu-worker-id": "1",
-            "ray.io/tpu-pod-type": pod_type,
-            "ray.io/tpu-topology": topology,
-        }
-        cluster.add_node(
-            num_cpus=2,
-            resources={"TPU": 4, f"TPU-{pod_type}-head": 1},
-            env_vars={**slice_0_env_common, "TPU_WORKER_ID": "0"},
-            labels=slice_0_head_labels,
-        )
-        cluster.add_node(
-            num_cpus=2,
-            resources={"TPU": 4},
-            env_vars={**slice_0_env_common, "TPU_WORKER_ID": "1"},
-            labels=slice_0_worker_labels,
-        )
-
-        slice_1_env_common = {
-            "TPU_NAME": "test-slice-1",
-            "TPU_ACCELERATOR_TYPE": pod_type,
-            "TPU_TOPOLOGY": topology,
-        }
-        slice_1_head_labels = {
-            "ray.io/tpu-slice-name": "test-slice-1",
-            "ray.io/tpu-worker-id": "0",
-            "ray.io/tpu-pod-type": pod_type,
-            "ray.io/tpu-topology": topology,
-        }
-        slice_1_worker_labels = {
-            "ray.io/tpu-slice-name": "test-slice-1",
-            "ray.io/tpu-worker-id": "1",
-            "ray.io/tpu-pod-type": pod_type,
-            "ray.io/tpu-topology": topology,
-        }
-        cluster.add_node(
-            num_cpus=2,
-            resources={"TPU": 4, f"TPU-{pod_type}-head": 1},
-            env_vars={**slice_1_env_common, "TPU_WORKER_ID": "0"},
-            labels=slice_1_head_labels,
-        )
-        cluster.add_node(
-            num_cpus=2,
-            resources={"TPU": 4},
-            env_vars={**slice_1_env_common, "TPU_WORKER_ID": "1"},
-            labels=slice_1_worker_labels,
-        )
-
-        ray.init(address=cluster.address)
-        yield cluster
-        ray.shutdown()
+    cluster = ray_start_cluster
+    slice_0_env_common = {
+        "TPU_NAME": "test-slice-0",
+        "TPU_ACCELERATOR_TYPE": pod_type,
+        "TPU_TOPOLOGY": topology,
+    }
+    slice_0_head_labels = {
+        "ray.io/tpu-slice-name": "test-slice-0",
+        "ray.io/tpu-worker-id": "0",
+        "ray.io/tpu-pod-type": pod_type,
+        "ray.io/tpu-topology": topology,
+    }
+    slice_0_worker_labels = {
+        "ray.io/tpu-slice-name": "test-slice-0",
+        "ray.io/tpu-worker-id": "1",
+        "ray.io/tpu-pod-type": pod_type,
+        "ray.io/tpu-topology": topology,
+    }
+    cluster.add_node(
+        num_cpus=2,
+        resources={"TPU": 4, f"TPU-{pod_type}-head": 1},
+        env_vars={**slice_0_env_common, "TPU_WORKER_ID": "0"},
+        labels=slice_0_head_labels,
+    )
+    cluster.add_node(
+        num_cpus=2,
+        resources={"TPU": 4},
+        env_vars={**slice_0_env_common, "TPU_WORKER_ID": "1"},
+        labels=slice_0_worker_labels,
+    )
+
+    slice_1_env_common = {
+        "TPU_NAME": "test-slice-1",
+        "TPU_ACCELERATOR_TYPE": pod_type,
+        "TPU_TOPOLOGY": topology,
+    }
+    slice_1_head_labels = {
+        "ray.io/tpu-slice-name": "test-slice-1",
+        "ray.io/tpu-worker-id": "0",
+        "ray.io/tpu-pod-type": pod_type,
+        "ray.io/tpu-topology": topology,
+    }
+    slice_1_worker_labels = {
+        "ray.io/tpu-slice-name": "test-slice-1",
+        "ray.io/tpu-worker-id": "1",
+        "ray.io/tpu-pod-type": pod_type,
+        "ray.io/tpu-topology": topology,
+    }
+    cluster.add_node(
+        num_cpus=2,
+        resources={"TPU": 4, f"TPU-{pod_type}-head": 1},
+        env_vars={**slice_1_env_common, "TPU_WORKER_ID": "0"},
+        labels=slice_1_head_labels,
+    )
+    cluster.add_node(
+        num_cpus=2,
+        resources={"TPU": 4},
+        env_vars={**slice_1_env_common, "TPU_WORKER_ID": "1"},
+        labels=slice_1_worker_labels,
+    )
+
+    ray.init(address=cluster.address)
+    yield cluster
+    ray.shutdown()


def test_fetch_tpu_slice_name_from_pg(ray_tpu_cluster):
@@ -367,5 +366,37 @@ def test_get_tpu_version_invalid(invalid_type):
         ray.util.tpu.get_tpu_version_from_type(invalid_type)


+@pytest.mark.parametrize(
+    "topology, accelerator_type, num_workers, resources_per_worker, expected_slices",
+    [
+        # "2x2x1" has 4 chips; with 4 workers requesting TPU: 1 each, expect num_slices=1.
+        ("2x2x1", "TPU-V4", 4, {"TPU": 1}, 1),
+        # "2x2x1" has 4 chips; with 8 workers requesting TPU: 1 each, expect num_slices=2.
+        ("2x2x1", "v4", 8, {"TPU": 1}, 2),
+        # "2x2x2" has 8 chips and 2 hosts; defaulting to 1 TPU worker per host
+        # and requesting 4 workers, expect num_slices=2.
+        ("2x2x2", "TPU-V4", 4, None, 2),
+        # "2x2x4" has 16 chips and 4 hosts; defaulting to 1 TPU worker per host
+        # and requesting 4 workers, expect num_slices=1.
+        ("2x2x4", "TPU-V4", 4, None, 1),
+        # 0 workers requested -> fall back to 1 slice.
+        ("2x2x1", "v4", 0, None, 1),
+        # Invalid topology or accelerator type -> fall back to 1 slice.
+        ("", "v4", 4, {"TPU": 1}, 1),
+        ("2x2x1", "", 4, {"TPU": 1}, 1),
+    ],
+)
+def test_get_tpu_num_slices_for_workers(
+    topology, accelerator_type, num_workers, resources_per_worker, expected_slices
+):
+    num_slices = ray.util.tpu.get_tpu_num_slices_for_workers(
+        topology=topology,
+        accelerator_type=accelerator_type,
+        num_workers=num_workers,
+        resources_per_worker=resources_per_worker,
+    )
+    assert num_slices == expected_slices


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))
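
The expected values in the parametrized cases above follow a simple chip-counting rule. The sketch below is an illustrative reconstruction of that rule, not Ray's actual implementation of get_tpu_num_slices_for_workers; the 4-chips-per-host constant is an assumption that matches the v4-style hosts used in these tests.

import math


def estimate_num_slices(topology, accelerator_type, num_workers, resources_per_worker):
    # Fall back to a single slice for missing inputs or zero workers,
    # mirroring the fallback cases in the parametrized test above.
    if not topology or not accelerator_type or not num_workers or num_workers <= 0:
        return 1
    try:
        dims = [int(d) for d in topology.split("x")]
    except ValueError:
        return 1
    chips_per_slice = math.prod(dims)  # e.g. "2x2x2" -> 8 chips per slice
    if resources_per_worker and resources_per_worker.get("TPU"):
        # Workers request chips directly: pack the total into whole slices.
        chips_needed = num_workers * resources_per_worker["TPU"]
        return max(1, math.ceil(chips_needed / chips_per_slice))
    # Default: one worker per TPU host, assuming 4 chips per host (v4-style).
    hosts_per_slice = max(1, chips_per_slice // 4)
    return max(1, math.ceil(num_workers / hosts_per_slice))

For example, estimate_num_slices("2x2x2", "TPU-V4", 4, None) returns 2, matching the expected_slices value in the corresponding test case.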
2 changes: 0 additions & 2 deletions python/ray/train/v2/_internal/callbacks/__init__.py
@@ -2,15 +2,13 @@
 from .backend_setup import BackendSetupCallback
 from .datasets import DatasetsCallback
 from .state_manager import StateManagerCallback
-from .tpu_reservation_callback import TPUReservationCallback
 from .working_dir_setup import WorkingDirectorySetupCallback

 __all__ = [
     "AcceleratorSetupCallback",
     "BackendSetupCallback",
     "DatasetsCallback",
     "StateManagerCallback",
-    "TPUReservationCallback",
     "WorkingDirectorySetupCallback",
 ]

This file was deleted: python/ray/train/v2/_internal/callbacks/tpu_reservation_callback.py.

python/ray/train/v2/_internal/execution/controller/controller.py
@@ -75,6 +75,7 @@
 if TYPE_CHECKING:
     from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint

+from ray.util.tpu import get_tpu_num_slices_for_workers

 logger = logging.getLogger(__name__)

@@ -327,13 +328,24 @@ def _start_worker_group(
         except Exception as e:
             return ControllerError(e)

+        # Calculate num_slices for the worker group if using TPU.
+        num_slices = 1
+        if scaling_config.use_tpu:
+            num_slices = get_tpu_num_slices_for_workers(
+                topology=scaling_config.topology,
+                accelerator_type=scaling_config.accelerator_type,
+                num_workers=num_workers,
+                resources_per_worker=resources_per_worker,
+            )
+
         worker_group_context = WorkerGroupContext(
             run_attempt_id=self._get_run_attempt_id(),
             train_fn_ref=self._train_fn_ref,
             num_workers=num_workers,
             resources_per_worker=resources_per_worker,
             placement_strategy=placement_strategy,
             label_selector=label_selector,
+            num_slices=num_slices,
         )
         try:
             self._worker_group = self.worker_group_cls.create(
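
For context, here is a hedged sketch of how this derivation surfaces to users. The JaxTrainer import path, the train_loop_per_worker keyword (assumed by analogy with other Ray trainers), and the exact ScalingConfig fields are assumptions based on what this diff reads from scaling_config (use_tpu, topology, accelerator_type), not confirmed public API.

from ray.train import ScalingConfig  # assumed public path for the config
from ray.train.v2.jax import JaxTrainer  # assumed import path


def train_fn():
    # Placeholder training function; each worker would run JAX code here.
    import jax
    print(jax.device_count())


# With a v4 "2x2x2" topology (8 chips per slice) and 4 workers requesting
# 4 TPU chips each, 16 chips are needed, so the controller derives
# num_slices=2 on its own rather than taking it as an explicit argument.
trainer = JaxTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(
        num_workers=4,  # required when use_tpu=True, per this PR
        use_tpu=True,
        topology="2x2x2",
        accelerator_type="TPU-V4",
        resources_per_worker={"TPU": 4},
    ),
)
result = trainer.fit()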
30 changes: 25 additions & 5 deletions python/ray/train/v2/_internal/execution/worker_group/state.py
@@ -1,13 +1,14 @@
 import logging
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional

 import ray
 from ray.actor import ActorHandle
 from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
 from ray.train.v2._internal.execution.worker_group.worker import Worker
 from ray.train.v2._internal.util import time_monotonic
 from ray.util.placement_group import PlacementGroup, remove_placement_group
+from ray.util.tpu import SlicePlacementGroup

 logger = logging.getLogger(__name__)

@@ -28,16 +29,21 @@ class WorkerGroupState:
     placement_group: PlacementGroup
     workers: List[Worker]
     sync_actor: ActorHandle
+    slice_placement_group: Optional[SlicePlacementGroup] = None

     @property
     def num_workers(self) -> int:
         return len(self.workers)

     def shutdown(self):
         _shutdown_workers(self.workers)
-        _shutdown_placement_group(self.placement_group)
         _shutdown_sync_actor(self.sync_actor)
+
+        if self.slice_placement_group:
+            self.slice_placement_group.shutdown()
+        else:
+            _shutdown_placement_group(self.placement_group)


class WorkerGroupStateBuilder:
"""Builder for WorkerGroupState.
@@ -58,13 +64,20 @@ def __init__(self):
         self.placement_group = None
         self.workers = None
         self.sync_actor = None
+        self.slice_placement_group = None

     def with_placement_group(
         self, placement_group: PlacementGroup
     ) -> "WorkerGroupStateBuilder":
         self.placement_group = placement_group
         return self

+    def with_slice_placement_group(
+        self, slice_placement_group: SlicePlacementGroup
+    ) -> "WorkerGroupStateBuilder":
+        self.slice_placement_group = slice_placement_group
+        return self
+
     def with_workers(self, workers: List[Worker]) -> "WorkerGroupStateBuilder":
         self.workers = workers
         return self
@@ -91,19 +104,26 @@ def build(self) -> WorkerGroupState:
             placement_group=self.placement_group,
             workers=self.workers,
             sync_actor=self.sync_actor,
+            slice_placement_group=self.slice_placement_group,
         )

     def shutdown(self):
         if self.workers:
             _shutdown_workers(self.workers)
             self.workers = None
-        if self.placement_group:
-            _shutdown_placement_group(self.placement_group)
-            self.placement_group = None

         if self.sync_actor:
             _shutdown_sync_actor(self.sync_actor)
             self.sync_actor = None

+        if self.slice_placement_group:
+            self.slice_placement_group.shutdown()
+            self.slice_placement_group = None
+            self.placement_group = None
+        elif self.placement_group:
+            _shutdown_placement_group(self.placement_group)
+            self.placement_group = None


def _shutdown_workers(workers: List[Worker], patience_s: float = 5):
# Run the worker shutdown logic on each of the workers. This should
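
Taken together, the builder now threads an optional slice placement group through to the final state, and shutdown gives it precedence so the underlying placement group is not removed twice. A minimal sketch of the intended flow follows, assuming pg, slice_pg, workers, and sync_actor are placeholder objects created elsewhere, and that a with_sync_actor method exists alongside the methods shown in this diff.

# Illustrative sketch only; pg, slice_pg, workers, and sync_actor are
# placeholders, and with_sync_actor is an assumed builder method.
builder = (
    WorkerGroupStateBuilder()
    .with_placement_group(pg)  # placement group backing the worker bundles
    .with_workers(workers)
    .with_sync_actor(sync_actor)
)
if slice_pg is not None:  # only set when scheduling onto TPU slices
    builder = builder.with_slice_placement_group(slice_pg)

state = builder.build()
# On teardown, SlicePlacementGroup.shutdown() owns cleanup of the placement
# group it manages, so WorkerGroupState.shutdown() skips the plain
# remove_placement_group call in that case and avoids a double removal.
state.shutdown()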