Skip to content

Expose local/global ExchangeTopologies timeouts for PJRT CPU client. #29384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jaxlib/_jax/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,8 @@ def get_tfrt_cpu_client(
num_nodes: int = ...,
collectives: CpuCollectives | None = ...,
num_devices: int | None = ...,
get_local_topology_timeout_minutes: int | None = ...,
get_global_topology_timeout_minutes: int | None = ...,
) -> Client: ...
def get_mock_gpu_client(
asynchronous: bool = ...,
Expand Down
17 changes: 15 additions & 2 deletions jaxlib/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ NB_MODULE(_jax, m) {
std::shared_ptr<DistributedRuntimeClient> distributed_client,
int node_id, int num_nodes,
std::shared_ptr<xla::cpu::CpuCollectives> collectives,
std::optional<int> num_devices) -> nb_class_ptr<PyClient> {
std::optional<int> num_devices,
std::optional<int> get_local_topology_timeout_minutes,
std::optional<int> get_global_topology_timeout_minutes)
-> nb_class_ptr<PyClient> {
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
{
nb::gil_scoped_release gil_release;
Expand All @@ -357,6 +360,14 @@ NB_MODULE(_jax, m) {
ifrt_options.process_id = node_id;
ifrt_options.num_processes = num_nodes;
}
if (get_local_topology_timeout_minutes.has_value()) {
ifrt_options.get_local_topology_timeout =
absl::Minutes(*get_local_topology_timeout_minutes);
}
if (get_global_topology_timeout_minutes.has_value()) {
ifrt_options.get_global_topology_timeout =
absl::Minutes(*get_global_topology_timeout_minutes);
}
ifrt_client =
ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options)));
}
Expand All @@ -366,7 +377,9 @@ NB_MODULE(_jax, m) {
nb::arg("node_id") = 0, nb::arg("num_nodes") = 1,
nb::arg("collectives").none() =
std::shared_ptr<xla::cpu::CpuCollectives>(),
nb::arg("num_devices").none() = std::nullopt);
nb::arg("num_devices").none() = std::nullopt,
nb::arg("get_local_topology_timeout_minutes").none() = std::nullopt,
nb::arg("get_global_topology_timeout_minutes").none() = std::nullopt);
m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
absl::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);
return pjrt_api.ok();
Expand Down
6 changes: 5 additions & 1 deletion jaxlib/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
_version = 351
_version = 352

# An internal increasing version number for protecting jaxlib code against
# ifrt changes.
Expand All @@ -68,6 +68,8 @@ def make_cpu_client(
num_nodes=1,
collectives=None,
num_devices=None,
get_local_topology_timeout_minutes=None,
get_global_topology_timeout_minutes=None,
) -> Client:
register_custom_call_handler('cpu', _xla.register_custom_call_target)
register_custom_type_id_handler('cpu', _xla.register_custom_type_id)
Expand All @@ -78,6 +80,8 @@ def make_cpu_client(
num_nodes=num_nodes,
collectives=collectives,
num_devices=num_devices,
get_local_topology_timeout_minutes=get_local_topology_timeout_minutes,
get_global_topology_timeout_minutes=get_global_topology_timeout_minutes,
)


Expand Down
2 changes: 2 additions & 0 deletions jaxlib/xla_client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def make_cpu_client(
num_nodes: int = ...,
collectives: _xla.CpuCollectives | None = ...,
num_devices: int | None = ...,
get_local_topology_timeout_minutes: int | None = ...,
get_global_topology_timeout_minutes: int | None = ...,
) -> Client: ...
def make_gpu_client(
distributed_client: DistributedRuntimeClient | None = ...,
Expand Down
Loading