
feat: slurm detector + multigpu single process #891


Merged · 9 commits · Mar 13, 2025
7 changes: 5 additions & 2 deletions docs/make.jl
@@ -32,8 +32,11 @@ pages = [
"Getting Started" => "introduction/index.md",
"Configuration" => "introduction/configuration.md",
],
"Tutorials" =>
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
"Tutorials" => [
"Overview" => "tutorials/index.md",
"Profiling" => "tutorials/profiling.md",
"Distributed" => "tutorials/multihost.md",
],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -63,6 +63,7 @@ export default defineConfig({
items: [
{text: "Overview", link: "/tutorials/"},
{text: "Profiling", link: "/tutorials/profiling"},
{text: "Distributed", link: "/tutorials/multihost"},
],
},
{
@@ -122,6 +123,7 @@ export default defineConfig({
items: [
{ text: "Overview", link: "/tutorials/" },
{ text: "Profiling", link: "/tutorials/profiling" },
{ text: "Distributed", link: "/tutorials/multihost" },
],
},
"/api/": {
2 changes: 1 addition & 1 deletion docs/src/api/sharding.md
@@ -2,7 +2,7 @@
CollapsedDocStrings = true
```

# Sharding API
# [Sharding API](@id sharding-api)

The `Reactant.Sharding` module provides a high-level API to construct MLIR operations with
support for sharding.
1 change: 1 addition & 0 deletions docs/src/tutorials/index.md
@@ -1,5 +1,6 @@
# Tutorials

- [Profiling](@ref profiling).
- [Multi-Host Environments](@ref distributed).

We are currently working on adding more tutorials to Reactant! Please check back soon!
82 changes: 82 additions & 0 deletions docs/src/tutorials/multihost.md
@@ -0,0 +1,82 @@
# [Multi-Host Environments](@ref distributed)

!!! tip "Use XLA IFRT Runtime"

While PJRT does support some minimal distributed capabilities on CUDA GPUs, distributed
support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
"xla_runtime" preference to "IFRT". This can be done with:

```julia
using Preferences, UUIDs

Preferences.set_preferences!(
UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
"xla_runtime" => "IFRT"
)
```
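
This writes the preference into the active project's `LocalPreferences.toml`; the
resulting entry should look roughly like this (a sketch, assuming a default project
layout):

```toml
[Reactant]
xla_runtime = "IFRT"
```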

At the top of your code, just after loading Reactant and before running any Reactant-related
operations, run `Reactant.Distributed.initialize()`.
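
A minimal script skeleton (a sketch; the full worked example follows below):

```julia
using Reactant

# Must be called before any other Reactant operation
Reactant.Distributed.initialize()

# ... build and run your Reactant code here ...
```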

!!! tip "Enable debug logging for debugging"

Reactant emits a lot of useful debugging information when setting up the Distributed
Runtime. This can be printed by setting the env var `JULIA_DEBUG` to contain
`Reactant`.
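
For example, when launching from a shell (the sbatch script below uses the same
pattern):

```bash
export JULIA_DEBUG="Reactant"
```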

After this, simply set up your code with [`Reactant.Sharding`](@ref sharding-api), and it
will run on multiple devices across multiple nodes.

## Example Slurm Script for Multi-Host Matrix Multiplication

::: code-group

```bash [main.sbatch]
#!/bin/bash -l
#
#SBATCH --job-name=matmul-sharding-reactant
#SBATCH --time=00:20:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=72
#SBATCH --account=<account>
#SBATCH --constraint=gpu

export JULIA_DEBUG="Reactant,Reactant_jll"

srun --preserve-env bash ./matmul.sh
```

```bash [matmul.sh]
#!/bin/bash -l

# Important: otherwise XLA might hang indefinitely
unset no_proxy http_proxy https_proxy NO_PROXY HTTP_PROXY HTTPS_PROXY

julia --project=. --threads=auto matmul_sharded.jl
```

```julia [matmul_sharded.jl]
using Reactant

Reactant.Distributed.initialize(; single_gpu_per_process=false)

@assert length(Reactant.devices()) >= 2

N = min((length(Reactant.devices()) ÷ 2) * 2, 8)

mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

x = reshape(collect(Float32, 1:64), 8, 8)
y = reshape(collect(Float32, 1:64), 8, 8)

x_ra = Reactant.to_rarray(x; sharding)
y_ra = Reactant.to_rarray(y; sharding)

res = @jit x_ra * y_ra
display(res)
```

:::

**Review comment (Member):** Maybe add a test like `@assert isapprox(res, x * y)` or something like that?

**Reply (Collaborator, Author):** That will fail for now, since we don't all-gather data across the processes, so only the data from addressable devices is accessible. Example output from running on 2 nodes:

```
┌ Warning: Not all devices are addressable. Currently we only fill in the data for addressable devices. Remaining slices of data in `data` are left untouched.
└ @ Reactant.XLA.IFRT ~/reactant/Reactant.jl/src/xla/IFRT/Array.jl:156
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
 1524.0  3636.0  5748.0  7860.0   9972.0  12084.0  14196.0  16308.0
 1560.0  3736.0  5912.0  8088.0  10264.0  12440.0  14616.0  16792.0
 1596.0  3836.0  6076.0  8316.0  10556.0  12796.0  15036.0  17276.0
 1632.0  3936.0  6240.0  8544.0  10848.0  13152.0  15456.0  17760.0
┌ Warning: Not all devices are addressable. Currently we only fill in the data for addressable devices. Remaining slices of data in `data` are left untouched.
└ @ Reactant.XLA.IFRT ~/reactant/Reactant.jl/src/xla/IFRT/Array.jl:156
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
 1380.0  3236.0  5092.0  6948.0  8804.0  10660.0  12516.0  14372.0
 1416.0  3336.0  5256.0  7176.0  9096.0  11016.0  12936.0  14856.0
 1452.0  3436.0  5420.0  7404.0  9388.0  11372.0  13356.0  15340.0
 1488.0  3536.0  5584.0  7632.0  9680.0  11728.0  13776.0  15824.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
```

I am trying to figure out a nicer way to all-gather the data.
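
To run the example, submit the batch script to Slurm (assuming the three files above are
in the current project directory):

```bash
sbatch main.sbatch
```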
110 changes: 103 additions & 7 deletions src/Distributed.jl
@@ -8,10 +8,16 @@ function initialize(;
coordinator_address::Union{Nothing,String}=nothing,
num_processes::Union{Nothing,Integer}=nothing,
process_id::Union{Nothing,Integer}=nothing,
single_gpu_per_process::Bool=true,
local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
initialization_timeout_in_seconds::Integer=300,
kwargs...,
)
if isinteractive()
@warn "Reactant.Distributed.initialize() should not be called in interactive mode. \
Use Reactant.Distributed.initialize() in a script instead."
end

@assert !initialized[] "`Distributed.initialize` has already been called"

(coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(;
@@ -20,6 +26,7 @@
process_id,
local_gpu_device_ids,
initialization_timeout_in_seconds,
single_gpu_per_process,
)

@debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids
@@ -43,6 +50,8 @@ struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end

struct MPIEnvDetector <: AbstractClusterEnvDetector end

struct SlurmEnvDetector <: AbstractClusterEnvDetector end

# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py

is_env_present(::AbstractClusterEnvDetector) = false
@@ -53,12 +62,19 @@ function get_process_id end
function get_local_process_id end

function auto_detect_unset_distributed_params(;
detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()],
detector_list=[
SlurmEnvDetector(),
OpenMPIORTEEnvDetector(),
MPIEnvDetector(),
# Keep this at the end since parsing for this is a bit flaky
OpenMPIPMIXEnvDetector(),
],
coordinator_address::Union{Nothing,String}=nothing,
num_processes::Union{Nothing,Integer}=nothing,
process_id::Union{Nothing,Integer}=nothing,
local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
initialization_timeout_in_seconds::Integer=300,
single_gpu_per_process::Bool=true,
)
if all(
Base.Fix2(!==, nothing),
@@ -91,7 +107,7 @@ function auto_detect_unset_distributed_params(;
process_id = get_process_id(detector)
end

if local_gpu_device_ids === nothing
if local_gpu_device_ids === nothing && single_gpu_per_process
local_gpu_device_ids = [get_local_process_id(detector)]
end

@@ -108,16 +124,18 @@ const _PMIX_SERVER_URI = (
"PMIX_SERVER_URI41",
"PMIX_SERVER_URI21",
)
const _PMIX_NAMESPACE = "PMIX_NAMESPACE"
const _PRTERUN = "PRTE_LAUNCHED"
const _PMIX_VERSION = "PMIX_VERSION"
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"

is_env_present(::OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI)
is_env_present(::OpenMPIPMIXEnvDetector) = any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
is_env_present(::OpenMPIPMIXEnvDetector) = haskey(ENV, _PMIX_NAMESPACE)

function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
orte_uri = ENV[_ORTE_URI]

job_id = parse(Int, split(orte_uri, '.'; limit=2)[1])
port = job_id % 2^12 + (65535 - 2^12 + 1)

@@ -132,11 +150,48 @@ function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
return "$(launcher_ip):$(port)"
end

function _throw_pmix_env_error(msg)
msg = msg * " Open an issue on Reactant with the relevant PMIX Environment Variables \
(you might want to obfuscate identifiable variables from this log \
before opening an issue)\n\n"
for (var, val) in [var => val for (var, val) in ENV if startswith(var, "PMIX")]
msg *= " * $var => $val.\n"
end
return error(msg)
end

function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer)
varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
pmix_uri = ENV[_PMIX_SERVER_URI[varname]]
pmix_version = parse(VersionNumber, ENV[_PMIX_VERSION])
pmix_uri = ENV[_PMIX_SERVER_URI[findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)]]
@debug "PMIX VERSION: $(pmix_version)"
if v"5" ≤ pmix_version < v"6"
return get_coordinator_address_pmixv5(pmix_uri)
elseif v"2" ≤ pmix_version < v"4"
return get_coordinator_address_pmixv2_or_3(pmix_uri)
else
_throw_pmix_env_error("Unsupported PMIX version: $(pmix_version).")
end
end

function get_coordinator_address_pmixv2_or_3(pmix_uri)
pre_semicolon = first(split(pmix_uri, ";"))
if startswith(pre_semicolon, "pmix-server.")
job_id = parse(Int, first(split(last(split(pre_semicolon, '.'; limit=2)))))
elseif contains(pre_semicolon, ".")
job_id = parse(Int, first(split(pre_semicolon, '.')))
else
_throw_pmix_env_error("Could not parse coordinator address from Open MPI \
environment.")
end
return get_coordinator_address_from_pmix_uri(pmix_uri, job_id)
end

job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1])
function get_coordinator_address_pmixv5(pmix_uri)
job_id = parse(Int, first(split(last(split(pmix_uri, '-'; limit=3)), "@"; limit=2)))
return get_coordinator_address_from_pmix_uri(pmix_uri, job_id)
end

function get_coordinator_address_from_pmix_uri(pmix_uri, job_id)
port = job_id % 2^12 + (65535 - 2^12 + 1)

launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)
@@ -159,4 +214,45 @@ function get_local_process_id(::AbstractOMPIClusterEnvDetector)
return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID])
end

# SlurmEnvDetector
# Based on https://github.com/jax-ml/jax/blob/d89835acbacec938971400d6fa54ea6dd5efe76c/jax/_src/clusters/slurm_cluster.py#L3
const _SLURM_JOB_ID = "SLURM_JOB_ID"
const _SLURM_NODELIST = "SLURM_STEP_NODELIST"
const _SLURM_PROCESS_COUNT = "SLURM_NTASKS"
const _SLURM_PROCESS_ID = "SLURM_PROCID"
const _SLURM_LOCAL_PROCESS_ID = "SLURM_LOCALID"
const _SLURM_NUM_NODES = "SLURM_STEP_NUM_NODES"

is_env_present(::SlurmEnvDetector) = haskey(ENV, _SLURM_JOB_ID)

function get_coordinator_address(::SlurmEnvDetector, ::Integer)
port = parse(Int, ENV[_SLURM_JOB_ID]) % 2^12 + (65535 - 2^12 + 1)

# Parse the first hostname of the job
# If we are looking for 'node001',
# node_list potential formats are 'node001', 'node001,host2',
# 'node[001-0015],host2', and 'node[001,007-015],host2'.
node_list = ENV[_SLURM_NODELIST]
ind = findfirst(Base.Fix2(in, (',', '[')), node_list)
ind = isnothing(ind) ? length(node_list) + 1 : ind

if ind == length(node_list) + 1 || node_list[ind] == ','
# 'node001' or 'node001,host2'
return "$(node_list[1:ind-1]):$(port)"
else
# 'node[001-0015],host2' or 'node[001,007-015],host2'
prefix = node_list[1:(ind - 1)]
suffix = node_list[(ind + 1):end]
ind2 = findfirst(Base.Fix2(in, (',', '-')), suffix)
ind2 = isnothing(ind2) ? length(suffix) : ind2
return "$(prefix)$(suffix[1:ind2-1]):$(port)"
end
end

get_process_count(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_COUNT])

get_process_id(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_ID])

get_local_process_id(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_LOCAL_PROCESS_ID])

end
6 changes: 4 additions & 2 deletions src/xla/Distributed.jl
@@ -129,7 +129,7 @@ function update!(
coordinator_address::String,
num_processes::Int,
process_id::Int,
local_gpu_device_ids::Vector{Int},
local_gpu_device_ids::Union{Nothing,Vector{Int}},
coordinator_bind_address::Union{Nothing,String}=nothing,
cluster_register_timeout_in_minutes::Integer=60,
rpc_timeout_in_seconds::Integer=120,
@@ -141,7 +141,9 @@
@assert 0 ≤ process_id < num_processes

state.coordinator_address = coordinator_address
state.local_gpu_device_ids = local_gpu_device_ids
if local_gpu_device_ids !== nothing
state.local_gpu_device_ids = local_gpu_device_ids
end
state.process_id = process_id
state.num_processes = num_processes

5 changes: 3 additions & 2 deletions src/xla/IFRT/Array.jl
@@ -138,7 +138,9 @@ function XLA.buffer_on_cpu(::Array)
end

function XLA.to_host(buffer::Array, data, reactant_sharding)
if length(XLA.devices(XLA.sharding(buffer))) == 1
reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)

if reactant_sharding isa Reactant.Sharding.NoSharding
GC.@preserve buffer data begin
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
@@ -147,7 +149,6 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
return data
end

reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)
@assert reactant_sharding isa Reactant.Sharding.HloSharding
client = XLA.client(buffer)
all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids)