Skip to content

Commit 92c6b0e

Browse files
Support overridable jax coordinator port and modify default port for Cloud TPU jax.distributed.initialize()
PiperOrigin-RevId: 826594509
1 parent 9c9626c commit 92c6b0e

File tree

6 files changed

+33
-20
lines changed

6 files changed

+33
-20
lines changed

jax/_src/clusters/cloud_tpu_cluster.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
# We use an arbitrarily chosen port for the coordinator since we cannot
2828
# rely on communication to choose one in real time.
29-
coordinator_port = '8476'
29+
coordinator_port = '8482'
3030

3131
metadata_response_code_success = 200
3232

@@ -95,7 +95,7 @@ def is_env_present(cls) -> bool:
9595
return False
9696

9797
@classmethod
98-
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
98+
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
9999
# For both GCE via QueuedResources and GKE via JobSet, the
100100
# Megascale coordinator address is set as the host with process id = 0,
101101
# so can be used as the jax distributed system coordinator.
@@ -108,7 +108,8 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str:
108108
coordinator_address = coordinator_address.split(':')[0]
109109
logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
110110
cls.wait_for_coordinator(coordinator_address, timeout_secs)
111-
return f'{coordinator_address}:{coordinator_port}'
111+
port = override_coordinator_port or coordinator_port
112+
return f'{coordinator_address}:{port}'
112113

113114
@classmethod
114115
def wait_for_coordinator(cls, coordinator_address, timeout_secs):

jax/_src/clusters/cluster.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections.abc import Sequence
18+
import os
1819
import logging
1920
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
2021

@@ -69,7 +70,8 @@ def auto_detect_unset_distributed_params(cls,
6970
if env:
7071
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
7172
if coordinator_address is None:
72-
coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout)
73+
coordinator_port = os.environ.get("JAX_COORDINATOR_PORT")
74+
coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout, override_coordinator_port=coordinator_port)
7375
if num_processes is None:
7476
num_processes = env.get_process_count()
7577
if process_id is None:
@@ -95,7 +97,7 @@ def is_env_present(cls) -> bool:
9597
raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")
9698

9799
@classmethod
98-
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
100+
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
99101
"""Returns address and port used by JAX to bootstrap.
100102
101103
Process id 0 will open a tcp socket at "hostname:port" where

jax/_src/clusters/k8s_cluster.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _controller(cls):
173173
)
174174

175175
@classmethod
176-
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
176+
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
177177
controller = cls._controller()
178178
job = cls._job()
179179
pod = cls._pod()
@@ -254,9 +254,10 @@ def wait_for_host(hostname):
254254

255255
wait_for_host(coordinator_hostname)
256256

257+
port = override_coordinator_port or cls._coordinator_port
257258
return '{hostname}:{port}'.format(
258259
hostname=coordinator_hostname,
259-
port=cls._coordinator_port
260+
port=port
260261
)
261262

262263
else:

jax/_src/clusters/mpi4py_cluster.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def is_env_present(cls) -> bool:
3333
return find_spec("mpi4py") is not None
3434

3535
@classmethod
36-
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
36+
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
3737

3838
# Using mpi4py, figure out rank 0 and its hostname.
3939
# Then broadcast the hostname and port.
@@ -49,8 +49,11 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str:
4949
# Order all the hostnames, and find unique ones
5050
hostname = socket.gethostname()
5151

52-
# Apparently, we want to pick a port in an ephemeral range...
53-
port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)
52+
if override_coordinator_port:
53+
port_id = override_coordinator_port
54+
else:
55+
# Apparently, we want to pick a port in an ephemeral range...
56+
port_id = str(hash(hostname) % 2**12 + (65535 - 2**12 + 1))
5457

5558
hostname = f'{hostname}:{port_id}'
5659

jax/_src/clusters/ompi_cluster.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,20 @@ def is_env_present(cls) -> bool:
3333
return _ORTE_URI in os.environ
3434

3535
@classmethod
36-
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
36+
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
3737
# Examples of orte_uri:
3838
# 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
3939
# 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
4040
orte_uri = os.environ[_ORTE_URI]
41-
job_id_str = orte_uri.split('.', maxsplit=1)[0]
42-
# The jobid is always a multiple of 2^12, let's divide it by 2^12
43-
# to reduce likelihood of port conflict between jobs
44-
job_id = int(job_id_str) // 2**12
45-
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
46-
port = job_id % 2**12 + (65535 - 2**12 + 1)
41+
if override_coordinator_port:
42+
port = override_coordinator_port
43+
else:
44+
job_id_str = orte_uri.split('.', maxsplit=1)[0]
45+
# The jobid is always a multiple of 2^12, let's divide it by 2^12
46+
# to reduce likelihood of port conflict between jobs
47+
job_id = int(job_id_str) // 2**12
48+
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
49+
port = str(job_id % 2**12 + (65535 - 2**12 + 1))
4750
launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
4851
if launcher_ip_match is None:
4952
raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')

jax/_src/clusters/slurm_cluster.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ def is_env_present(cls) -> bool:
3434
(_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT, _PROCESS_ID, _LOCAL_PROCESS_ID))
3535

3636
@classmethod
37-
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
38-
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
39-
port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)
37+
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
38+
if override_coordinator_port:
39+
port = override_coordinator_port
40+
else:
41+
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
42+
port = str(int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1))
4043

4144
# Parse the first hostname of the job
4245
# If we are looking for 'node001',

0 commit comments

Comments
 (0)