feat: slurm detector + multigpu single process #891
Example scripts:

SLURM submission script:

```bash
#!/bin/bash -l
#
#SBATCH --job-name=smoke-test-sharding-reactant
#SBATCH --time=00:20:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=72
#SBATCH --account=<account>
#SBATCH --constraint=gpu
#SBATCH --partition=debug
#SBATCH --output slurm_logs/output-%j.out
#SBATCH --error slurm_logs/error-%j.out
#SBATCH --exclusive

export JULIA_DEBUG="Reactant,Reactant_jll"

srun --preserve-env bash ./matmul.sh
```

`matmul.sh` (launched by `srun` on each node):

```bash
#!/bin/bash -l

unset no_proxy http_proxy https_proxy NO_PROXY HTTP_PROXY HTTPS_PROXY

cd <path/to/dir>

<julia> \
    --project=. \
    --threads=auto \
    test_matmul_sharded.jl
```

`test_matmul_sharded.jl`:

```julia
using Reactant

Reactant.Distributed.initialize(; single_gpu_per_process=false)

@assert length(Reactant.devices()) >= 2

N = min((length(Reactant.devices()) ÷ 2) * 2, 8)
mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

x = reshape(collect(Float32, 1:64), 8, 8)
y = reshape(collect(Float32, 1:64), 8, 8)

x_ra = Reactant.to_rarray(x; sharding)
y_ra = Reactant.to_rarray(y; sharding)

res = @jit x_ra * y_ra
display(res)
```
Can you add the script explicitly as a doc? At minimum to facilitate easier intra-team communication right now.
On Leonardo, with 4 GPUs:

```julia
julia> Reactant.Distributed.initialize(; single_gpu_per_process=false)
ERROR: BoundsError: attempt to access 2-element Vector{SubString{String}} at index [3]
Stacktrace:
  [1] throw_boundserror(A::Vector{SubString{String}}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] get_coordinator_address(::Reactant.Distributed.OpenMPIPMIXEnvDetector, ::Int64)
    @ Reactant.Distributed ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:149
  [4] auto_detect_unset_distributed_params(; detector_list::Vector{…}, coordinator_address::Nothing, num_processes::Nothing, process_id::Nothing, local_gpu_device_ids::Nothing, initialization_timeout_in_seconds::Int64, single_gpu_per_process::Bool)
    @ Reactant.Distributed ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:91
  [5] auto_detect_unset_distributed_params
    @ ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:59 [inlined]
  [6] initialize(; coordinator_address::Nothing, num_processes::Nothing, process_id::Nothing, single_gpu_per_process::Bool, local_gpu_device_ids::Nothing, initialization_timeout_in_seconds::Int64, kwargs::@Kwargs{})
    @ Reactant.Distributed ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:18
  [7] top-level scope
    @ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
```julia
julia> @assert length(Reactant.devices()) >= 2
2025-03-12 20:15:54.953843: I external/xla/xla/service/service.cc:152] XLA service 0x22ddc90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-12 20:15:54.953865: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-12 20:15:54.953869: I external/xla/xla/service/service.cc:160] StreamExecutor device (1): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-12 20:15:54.953873: I external/xla/xla/service/service.cc:160] StreamExecutor device (2): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-12 20:15:54.953876: I external/xla/xla/service/service.cc:160] StreamExecutor device (3): NVIDIA A100-SXM-64GB, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741806954.956636 3388504 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741806954.956688 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 0 for BFCAllocator.
I0000 00:00:1741806954.957151 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 1 for BFCAllocator.
I0000 00:00:1741806954.957163 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 2 for BFCAllocator.
I0000 00:00:1741806954.957175 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 3 for BFCAllocator.
I0000 00:00:1741806954.957186 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741806954.957197 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1741806954.957208 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1741806954.957218 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 3 for CollectiveBFCAllocator.
I0000 00:00:1741806955.144081 3388504 cuda_dnn.cc:529] Loaded cuDNN version 90400
```

But at least I got

```julia
julia> res = @jit x_ra * y_ra
2025-03-12 20:16:24.968321: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
8×8 ConcretePJRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
 1380.0  3236.0  5092.0  6948.0   8804.0  10660.0  12516.0  14372.0
 1416.0  3336.0  5256.0  7176.0   9096.0  11016.0  12936.0  14856.0
 1452.0  3436.0  5420.0  7404.0   9388.0  11372.0  13356.0  15340.0
 1488.0  3536.0  5584.0  7632.0   9680.0  11728.0  13776.0  15824.0
 1524.0  3636.0  5748.0  7860.0   9972.0  12084.0  14196.0  16308.0
 1560.0  3736.0  5912.0  8088.0  10264.0  12440.0  14616.0  16792.0
 1596.0  3836.0  6076.0  8316.0  10556.0  12796.0  15036.0  17276.0
 1632.0  3936.0  6240.0  8544.0  10848.0  13152.0  15456.0  17760.0
```
What do the env vars look like? It is possible I screwed up the parsing for those.
```julia
julia> get.((ENV,), _PMIX_SERVER_URI, nothing)
("pmix-server.945570;tcp4://127.0.0.1:57215", "pmix-server.945570;tcp4://127.0.0.1:57215", nothing, nothing, "pmix-server.945570;tcp4://127.0.0.1:57215")
```
```julia
x_ra = Reactant.to_rarray(x; sharding)
y_ra = Reactant.to_rarray(y; sharding)

res = @jit x_ra * y_ra
```
Maybe add a test like `@assert isapprox(res, x * y)` or something like that?
That will fail for now, since we don't all-gather data across the processes, so only the data from the addressable devices is accessible. Example output from running on 2 nodes:

```
┌ Warning: Not all devices are addressable. Currently we only fill in the data for addressable devices. Remaining slices of data in `data` are left untouched.
└ @ Reactant.XLA.IFRT ~/reactant/Reactant.jl/src/xla/IFRT/Array.jl:156
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
 1524.0  3636.0  5748.0  7860.0   9972.0  12084.0  14196.0  16308.0
 1560.0  3736.0  5912.0  8088.0  10264.0  12440.0  14616.0  16792.0
 1596.0  3836.0  6076.0  8316.0  10556.0  12796.0  15036.0  17276.0
 1632.0  3936.0  6240.0  8544.0  10848.0  13152.0  15456.0  17760.0
┌ Warning: Not all devices are addressable. Currently we only fill in the data for addressable devices. Remaining slices of data in `data` are left untouched.
└ @ Reactant.XLA.IFRT ~/reactant/Reactant.jl/src/xla/IFRT/Array.jl:156
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
 1380.0  3236.0  5092.0  6948.0   8804.0  10660.0  12516.0  14372.0
 1416.0  3336.0  5256.0  7176.0   9096.0  11016.0  12936.0  14856.0
 1452.0  3436.0  5420.0  7404.0   9388.0  11372.0  13356.0  15340.0
 1488.0  3536.0  5584.0  7632.0   9680.0  11728.0  13776.0  15824.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
```
I am trying to figure out a nicer way to allgather the data
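As a stopgap, a per-process check along these lines could compare only the locally filled rows against the host result (a sketch only; it assumes the concrete result can be materialized with `Array` and that the slices belonging to non-addressable devices are left as zeros, as in the output above):

```julia
# Sketch of a per-process sanity check; assumes `Array(res)` gives the local view
# and that rows owned by non-addressable devices are left as zeros.
res_local = Array(res)
ref = x * y
filled_rows = [i for i in 1:size(res_local, 1) if any(!iszero, res_local[i, :])]
@assert isapprox(res_local[filled_rows, :], ref[filled_rows, :])
```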
Marking it as draft for now. I want to finish testing on Leonardo before merging.
Till I get access to Leonardo, @giordano can you share the following env vars?

xref jax-ml/jax#14576: there might be a mismatch between the PMIX version I implemented the parser for and the one on Leonardo.
With the latest round of fixes, older PMIX versions should be supported, and if they are not, we will print all the relevant env vars needed to implement support for that version.
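A rough Julia-side equivalent of the shell one-liner in the next comment (a sketch only, not the code added in this PR) would be:

```julia
# Print every PMIX/OMPI/PRTE-related environment variable for debugging.
for k in sort!(collect(keys(ENV)))
    if startswith(k, "PMIX_") || startswith(k, "OMPI_") || startswith(k, "PRTE_")
        println(k, "=", ENV[k])
    end
end
```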
```sh
$ env | grep -E '^(OMPI|PRTE|PMIX)_'
PMIX_SYSTEM_TMPDIR=/tmp
PMIX_RANK=0
PMIX_SERVER_TMPDIR=/var/spool/slurmd/pmix.13836604.0/
PMIX_PTL_MODULE=tcp,usock
PMIX_HOSTNAME=lrdn2985.leonardo.local
PMIX_BFROP_BUFFER_TYPE=PMIX_BFROP_BUFFER_NON_DESC
PMIX_DSTORE_ESH_BASE_PATH=/var/spool/slurmd/pmix.13836604.0//pmix_dstor_ds12_1564944
PMIX_SERVER_URI3=pmix-server.1564944;tcp4://127.0.0.1:57439
PMIX_SERVER_URI2=pmix-server.1564944;tcp4://127.0.0.1:57439
PMIX_VERSION=3.1.5
PMIX_SERVER_URI21=pmix-server.1564944;tcp4://127.0.0.1:57439
PMIX_DSTORE_21_BASE_PATH=/var/spool/slurmd/pmix.13836604.0//pmix_dstor_ds21_1564944
PMIX_SECURITY_MODE=native
PMIX_NAMESPACE=slurm.pmix.13836604.0
PMIX_GDS_MODULE=ds21,ds12,hash
```
Great. v2, v3, and v5 are currently supported. I couldn't find examples for v4, so I am throwing an error in that case.
Is it expected that `Reactant.Distributed.initialize(; single_gpu_per_process=false)` takes a lot of time? It has been running for me for 3 minutes already, I thought it'd be almost immediate.
Then it's stalling. Is there any warning being displayed for the proxy variables?
No, and I had checked earlier they aren't set at all.
Oof, okay... let me wait till the access is confirmed and then deal with that.
Oh wait, running in the REPL won't work. It is probably waiting indefinitely for the other processes to connect. Can you try running it via slurm/mpirun? I will add a check for interactive sessions and warn accordingly.
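Something along these lines could back that warning (a sketch, not necessarily what ends up in the PR):

```julia
# Hypothetical sketch of an interactive-session guard for distributed initialization;
# the actual check added in this PR may differ.
if isinteractive()
    msg = "Reactant.Distributed.initialize() appears to be running in an interactive " *
          "session. Multi-process initialization is normally launched via srun/mpirun " *
          "and may block while waiting for the other processes to connect."
    @warn msg
end
```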
Add a check that will warn if users are calling `Reactant.Distributed.initialize` from an interactive session.
Ok, thanks for the clarification. Some progress:
Add a Preferences file with:
Ok, that worked, thanks. stderr:

stdout:

I presume one rank has
No, the first one has the top 4 rows and the other one has the bottom 4 rows. But this is good, it confirms that our setup works on Leonardo as well 🎉. Let's wait for the tests and then we should be good to go here.
Last bit needed to run on ALPs. By default we run a single GPU per process, but if `single_gpu_per_process` is set to `false`, we will use all GPUs accessible locally in that process.
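For example, a minimal usage sketch of the two modes:

```julia
using Reactant

# Default: each process drives a single GPU (e.g. one SLURM task per GPU).
# Reactant.Distributed.initialize()

# One process per node driving all locally visible GPUs:
Reactant.Distributed.initialize(; single_gpu_per_process=false)
```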