Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# We use an arbitrarily chosen port for the coordinator since we cannot
# rely on communication to choose one in real time.
coordinator_port = '8476'
coordinator_port = '8482'

metadata_response_code_success = 200

Expand Down Expand Up @@ -95,7 +95,7 @@ def is_env_present(cls) -> bool:
return False

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
# For both GCE via QueuedResources and GKE via JobSet, the
# Megascale coordinator address is set as the host with process id = 0,
# so can be used as the jax distributed system coordinator.
Expand All @@ -108,7 +108,8 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str:
coordinator_address = coordinator_address.split(':')[0]
logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
cls.wait_for_coordinator(coordinator_address, timeout_secs)
return f'{coordinator_address}:{coordinator_port}'
port = override_coordinator_port or coordinator_port
return f'{coordinator_address}:{port}'

@classmethod
def wait_for_coordinator(cls, coordinator_address, timeout_secs):
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from collections.abc import Sequence
import os
import logging
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

Expand Down Expand Up @@ -69,7 +70,8 @@ def auto_detect_unset_distributed_params(cls,
if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout)
coordinator_port = os.environ.get("JAX_COORDINATOR_PORT")
coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout, override_coordinator_port=coordinator_port)
if num_processes is None:
num_processes = env.get_process_count()
if process_id is None:
Expand All @@ -95,7 +97,7 @@ def is_env_present(cls) -> bool:
raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
"""Returns address and port used by JAX to bootstrap.

Process id 0 will open a tcp socket at "hostname:port" where
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/clusters/k8s_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _controller(cls):
)

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
controller = cls._controller()
job = cls._job()
pod = cls._pod()
Expand Down Expand Up @@ -254,9 +254,10 @@ def wait_for_host(hostname):

wait_for_host(coordinator_hostname)

port = override_coordinator_port or cls._coordinator_port
return '{hostname}:{port}'.format(
hostname=coordinator_hostname,
port=cls._coordinator_port
port=port
)

else:
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def is_env_present(cls) -> bool:
return find_spec("mpi4py") is not None

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:

# Using mpi4py, figure out rank 0 and it's hostname.
# Then broadcast the hostname and port.
Expand All @@ -49,8 +49,11 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str:
# Order all the hostnames, and find unique ones
hostname = socket.gethostname()

# Apparently, we want to pick a port in an ephemeral range...
port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)
if override_coordinator_port:
port_id = override_coordinator_port
else:
# Apparently, we want to pick a port in an ephemeral range...
port_id = str(hash(hostname) % 2**12 + (65535 - 2**12 + 1))

hostname = f'{hostname}:{port_id}'

Expand Down
17 changes: 10 additions & 7 deletions jax/_src/clusters/ompi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@ def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
# Examples of orte_uri:
# 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
# 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
orte_uri = os.environ[_ORTE_URI]
job_id_str = orte_uri.split('.', maxsplit=1)[0]
# The jobid is always a multiple of 2^12, let's divide it by 2^12
# to reduce likelihood of port conflict between jobs
job_id = int(job_id_str) // 2**12
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
port = job_id % 2**12 + (65535 - 2**12 + 1)
if override_coordinator_port:
port = override_coordinator_port
else:
job_id_str = orte_uri.split('.', maxsplit=1)[0]
# The jobid is always a multiple of 2^12, let's divide it by 2^12
# to reduce likelihood of port conflict between jobs
job_id = int(job_id_str) // 2**12
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
port = str(job_id % 2**12 + (65535 - 2**12 + 1))
launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
if launcher_ip_match is None:
raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/clusters/slurm_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ def is_env_present(cls) -> bool:
(_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT, _PROCESS_ID, _LOCAL_PROCESS_ID))

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
if override_coordinator_port:
port = override_coordinator_port
else:
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
port = str(int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1))

# Parse the first hostname of the job
# If we are looking for 'node001',
Expand Down
Loading