Skip to content

[JAX] Extend colocated_cpu_devices to accept Mesh besides devices #29387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions jax/experimental/colocated_python/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,43 @@
from __future__ import annotations

import collections
from typing import Any, Callable, Sequence, Type
from typing import Any, Callable, Sequence, Type, overload

import jax
from jax._src import api_util
from jax._src import util
from jax.experimental.colocated_python.func import make_callable
from jax.experimental.colocated_python.obj import wrap_class
import numpy as np


@overload
def colocated_cpu_devices(
devices: Sequence[jax.Device],
devices_or_mesh: Sequence[jax.Device],
) -> Sequence[jax.Device]:
"""Finds CPU devices colocated with the given devices."""
if not isinstance(devices, tuple):
devices = tuple(devices)
...


@overload
def colocated_cpu_devices(
devices_or_mesh: jax.sharding.Mesh,
) -> jax.sharding.Mesh:
...


def colocated_cpu_devices(devices_or_mesh):
"""Finds devices or a mesh that has CPU devices colocated with the given devices or mesh."""
if isinstance(devices_or_mesh, jax.sharding.Mesh):
return _colocated_cpu_mesh_cached(devices_or_mesh)

if not isinstance(devices_or_mesh, tuple):
devices_or_mesh = tuple(devices_or_mesh)
try:
return _colocated_cpu_devices_cached(devices)
return _colocated_cpu_devices_cached(devices_or_mesh)
except (ValueError, AttributeError):
return _colocated_cpu_devices_cached_fallback_to_cpu_backend(devices)
return _colocated_cpu_devices_cached_fallback_to_cpu_backend(
devices_or_mesh
)


@util.cache(max_size=1024, trace_context_in_key=False)
Expand Down Expand Up @@ -78,6 +96,20 @@ def _colocated_cpu_devices_cached_fallback_to_cpu_backend(
]


@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh:
"""Returns a CPU mesh that is similar to the given mesh but has colocated CPU devices."""
# Finding colocated CPU devices reuses the cache of `colocated_cpu_devices`
# called with devices. `_colocated_cpu_mesh` itself is also cached to avoid
# creating a new `Mesh` object repeatedly.
flat_cpu_devices = colocated_cpu_devices(tuple(mesh.devices.flat))
return jax.sharding.Mesh(
np.array(flat_cpu_devices).reshape(mesh.axis_sizes),
mesh.axis_names,
axis_types=mesh.axis_types,
)


def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]:
"""Executes the given Python function on the same devices as the arguments."""
return make_callable(
Expand Down
14 changes: 14 additions & 0 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ def setUp(self):
" requires NumPy 2.0.0 or later"
)

def testColocatedCpuDevices(self):
mesh = jax.sharding.Mesh(
np.array(jax.local_devices()[:1]).reshape((1, 1)), ("x", "y")
)
cpu_mesh1 = colocated_python.colocated_cpu_devices(mesh)

cpu_devices = colocated_python.colocated_cpu_devices(
jax.local_devices()[:1]
)
cpu_mesh2 = jax.sharding.Mesh(
np.array(cpu_devices).reshape((1, 1)), ("x", "y")
)
self.assertEqual(cpu_mesh1, cpu_mesh2)

def testMakeColocatedPythonProgram(self):
def add_one(x):
return x + 1
Expand Down
Loading