Add docs for multi-process run in Kubernetes #28317


Merged: 1 commit, May 15, 2025
2 changes: 1 addition & 1 deletion .github/workflows/k8s.yaml
```diff
@@ -1,4 +1,4 @@
-name: Distributed run using K8s Jobset
+name: Multi-process run using K8s
 
 on:
   push:
```
85 changes: 85 additions & 0 deletions docs/multi_process.md
@@ -307,6 +307,83 @@ what it prints:

Woohoo, look at all those TPU cores!

### Kubernetes Example

Running multi-controller JAX on a Kubernetes cluster is almost identical in spirit to the GPU and TPU examples above: every pod runs the same Python program, JAX discovers its peers, and the cluster behaves like one giant machine.

1. **Container image** - start from a JAX-enabled image, e.g. one of the public JAX AI images on Google Artifact Registry ([TPU][google-artifact-tpu] / [GPU][google-artifact-gpu]) or NVIDIA ([NGC][nvidia-ngc] / [JAX-Toolbox][nvidia-jax-toolbox]).

2. **Workload type** - use either a [JobSet][k8s-jobset] or an [indexed Job][k8s-indexed-job]. Each replica corresponds to one JAX process.

3. **Service Account** - JAX needs permission to list the pods that belong to the job so that processes discover their peers. A minimal RBAC setup is provided in [examples/k8s/svc-acct.yaml][rbac-svc-acct].
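
The service-account wiring from step 3 can be sketched as follows. This is an illustrative approximation of what [examples/k8s/svc-acct.yaml][rbac-svc-acct] sets up (the resource names here are placeholders; use the file from the repo as the authoritative version):

```yaml
apiVersion: v1
kind: ServiceAccount
metadata:
  name: jax-job-sa
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: jax-job-role
rules:
# JAX's Kubernetes cluster detection lists the pods belonging to the job
# in order to find its peers, so the account needs read access to pods.
- apiGroups: [""]
  resources: ["pods"]
  verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: jax-job-rolebinding
subjects:
- kind: ServiceAccount
  name: jax-job-sa
roleRef:
  kind: Role
  name: jax-job-role
  apiGroup: rbac.authorization.k8s.io
```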

Below is a [minimal JobSet][minimal-jobset] that launches two replicas. Replace the placeholders -
image, GPU count, and any private registry secrets - with values that match your environment.

```yaml
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jaxjob
spec:
  replicatedJobs:
  - name: workers
    template:
      spec:
        parallelism: 2
        completions: 2
        backoffLimit: 0
        template:
          spec:
            serviceAccountName: jax-job-sa  # kubectl apply -f svc-acct.yaml
            restartPolicy: Never
            imagePullSecrets:
            # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/
            - name: null
            containers:
            - name: main
              image: null  # e.g. ghcr.io/nvidia/jax:jax
              imagePullPolicy: Always
              resources:
                limits:
                  cpu: 1
                  # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/
                  nvidia.com/gpu: null
              command:
              - python
              args:
              - -c
              - |
                import jax
                jax.distributed.initialize()
                print(jax.devices())
                print(jax.local_devices())
                assert jax.process_count() > 1
                assert len(jax.devices()) > len(jax.local_devices())
```
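
If you would rather use a plain [indexed Job][k8s-indexed-job] (step 2 above) than install the JobSet CRD, the equivalent skeleton looks roughly like this. This is a sketch with the same placeholder conventions as the manifest above; the rest of this example assumes the JobSet manifest:

```yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: jaxjob
spec:
  completionMode: Indexed  # gives each pod a stable completion index
  completions: 2
  parallelism: 2
  backoffLimit: 0
  template:
    spec:
      serviceAccountName: jax-job-sa
      restartPolicy: Never
      containers:
      - name: main
        image: null  # same placeholder as above
        resources:
          limits:
            nvidia.com/gpu: null
```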

Apply the manifest and watch the pods complete:

```bash
$ kubectl apply -f example.yaml
$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob
NAME READY STATUS RESTARTS AGE
jaxjob-workers-0-0-xpx8l 0/1 Completed 0 8m32s
jaxjob-workers-0-1-ddkq8 0/1 Completed 0 8m32s
```

When the job finishes, inspect the logs to confirm that every process saw all accelerators:

```bash
$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=0)]
[CudaDevice(id=0), CudaDevice(id=1)]
[CudaDevice(id=1)]
```

Every pod should have the same set of global devices and a different set of local devices. At this point, you can replace the inline script with your real JAX program.
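
For example, once every process has called `jax.distributed.initialize()`, each one can contribute its local shard to a single global array. The following is a minimal sketch (the mesh layout and values are illustrative; with a single process it degenerates to a plain device transfer):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning every device in the cluster. On the JobSet above
# that is all GPUs across both pods; on one process it is just the local
# devices.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))

# Each process supplies only the rows for its own devices; JAX assembles
# the global array without gathering the data onto any single host.
local = np.arange(len(jax.local_devices()), dtype=np.float32)
local += 100.0 * jax.process_index()
global_arr = jax.make_array_from_process_local_data(sharding, local)

# Computations on globally sharded arrays should go through jit; the
# result of the reduction is replicated on every process.
total = jax.jit(jnp.sum)(global_arr)
print(total)
```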

Once the processes are set up, we can start building global {class}`jax.Array`s
and running computations. The remaining Python code examples in this tutorial
are meant to be run on all processes simultaneously, after running
@@ -580,3 +657,11 @@ assert (np.all(
[distributed_arrays]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
[gpu_machines]: https://cloud.google.com/compute/docs/gpus
[unified_sharding]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
[google-artifact-tpu]: https://console.cloud.google.com/artifacts/docker/cloud-tpu-images/us/jax-ai-image/tpu
[google-artifact-gpu]: https://console.cloud.google.com/artifacts/docker/deeplearning-images/us-central1/jax-ai-image/gpu
[nvidia-ngc]: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax
[nvidia-jax-toolbox]: https://github.com/NVIDIA/JAX-Toolbox
[k8s-jobset]: https://github.com/kubernetes-sigs/jobset
[k8s-indexed-job]: https://kubernetes.io/docs/concepts/workloads/controllers/job/#parallel-jobs
[rbac-svc-acct]: https://github.com/jax-ml/jax/blob/main/examples/k8s/svc-acct.yaml
[minimal-jobset]: https://github.com/jax-ml/jax/blob/main/examples/k8s/example.yaml
4 changes: 3 additions & 1 deletion jax/_src/clusters/k8s_cluster.py
```diff
@@ -69,7 +69,9 @@ def _handle_api_exception(cls):
         "this job does not have the permission for pod introspection. Please "
         "either grant the default SA permission to read pod info, or create a "
         "dedicated service account with the permission and associated with "
-        "the job. For more details, see <PLACERHOLDER_LINK>.",
+        "the job. For an example on setting up the service account, see the "
+        "example/k8s directory in the JAX repo. For more details, please refer to "
+        "https://docs.jax.dev/en/latest/multi_process.html#kubernetes-example",
         width=80
       ))
       raise RuntimeError('\n'.join(err_msg)) from e
```