
feat: slurm detector + multigpu single process #891


Merged

avik-pal merged 9 commits into main from ap/slurm_runner on Mar 13, 2025

Conversation

avik-pal
Collaborator

Last bit needed to run on Alps. By default we run a single GPU per process, but if `single_gpu_per_process` is set to `false`, we use all GPUs locally accessible to that process.
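For illustration, a minimal sketch of the two modes (assuming a plain `initialize()` call is valid, since all keyword arguments shown in the stack traces below have defaults):

using Reactant

# Default: this process drives a single GPU
# (e.g. launched with one Slurm task per GPU)
Reactant.Distributed.initialize()

# One task per node: this process uses every GPU it can see locally
Reactant.Distributed.initialize(; single_gpu_per_process=false)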

avik-pal requested review from giordano and wsmoses on March 12, 2025 18:57
@avik-pal
Collaborator Author

Example Scripts:

Slurm submission script:

#!/bin/bash -l
#
#SBATCH --job-name=smoke-test-sharding-reactant
#SBATCH --time=00:20:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=72
#SBATCH --account=<account>
#SBATCH --constraint=gpu
#SBATCH --partition=debug
#SBATCH --output slurm_logs/output-%j.out
#SBATCH --error slurm_logs/error-%j.out
#SBATCH --exclusive

export JULIA_DEBUG="Reactant,Reactant_jll"

srun --preserve-env bash ./matmul.sh

matmul.sh:

#!/bin/bash -l

unset no_proxy http_proxy https_proxy NO_PROXY HTTP_PROXY HTTPS_PROXY

cd <path/to/dir>

<julia> \
        --project=. \
        --threads=auto \
        test_matmul_sharded.jl

test_matmul_sharded.jl:

using Reactant

Reactant.Distributed.initialize(; single_gpu_per_process=false)

@assert length(Reactant.devices()) >= 2

N = min((length(Reactant.devices()) ÷ 2) * 2, 8)

mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))

x = reshape(collect(Float32, 1:64), 8, 8)
y = reshape(collect(Float32, 1:64), 8, 8)

x_ra = Reactant.to_rarray(x; sharding)
y_ra = Reactant.to_rarray(y; sharding)

res = @jit x_ra * y_ra

display(res)

@wsmoses
Member

wsmoses commented Mar 12, 2025

Can you add the script explicitly as a doc? At a minimum, to facilitate easier intra-team communication right now.

@giordano
Member

giordano commented Mar 12, 2025

On Leonardo, with 4 GPUs:

julia> Reactant.Distributed.initialize(; single_gpu_per_process=false)
ERROR: BoundsError: attempt to access 2-element Vector{SubString{String}} at index [3]
Stacktrace:
 [1] throw_boundserror(A::Vector{SubString{String}}, I::Tuple{Int64})
   @ Base ./essentials.jl:14
 [2] getindex
   @ ./essentials.jl:916 [inlined]
 [3] get_coordinator_address(::Reactant.Distributed.OpenMPIPMIXEnvDetector, ::Int64)
   @ Reactant.Distributed ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:149
 [4] auto_detect_unset_distributed_params(; detector_list::Vector{…}, coordinator_address::Nothing, num_processes::Nothing, process_id::Nothing, local_gpu_device_ids::Nothing, initialization_timeout_in_seconds::Int64, single_gpu_per_process::Bool)
   @ Reactant.Distributed ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:91
 [5] auto_detect_unset_distributed_params
   @ ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:59 [inlined]
 [6] initialize(; coordinator_address::Nothing, num_processes::Nothing, process_id::Nothing, single_gpu_per_process::Bool, local_gpu_device_ids::Nothing, initialization_timeout_in_seconds::Int64, kwargs::@Kwargs{})
   @ Reactant.Distributed ~/.julia/packages/Reactant/xIT8Z/src/Distributed.jl:18
 [7] top-level scope
   @ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> @assert length(Reactant.devices()) >= 2
2025-03-12 20:15:54.953843: I external/xla/xla/service/service.cc:152] XLA service 0x22ddc90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-12 20:15:54.953865: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-12 20:15:54.953869: I external/xla/xla/service/service.cc:160]   StreamExecutor device (1): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-12 20:15:54.953873: I external/xla/xla/service/service.cc:160]   StreamExecutor device (2): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-12 20:15:54.953876: I external/xla/xla/service/service.cc:160]   StreamExecutor device (3): NVIDIA A100-SXM-64GB, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741806954.956636 3388504 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741806954.956688 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 0 for BFCAllocator.
I0000 00:00:1741806954.957151 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 1 for BFCAllocator.
I0000 00:00:1741806954.957163 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 2 for BFCAllocator.
I0000 00:00:1741806954.957175 3388504 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 3 for BFCAllocator.
I0000 00:00:1741806954.957186 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741806954.957197 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1741806954.957208 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1741806954.957218 3388504 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 3 for CollectiveBFCAllocator.
I0000 00:00:1741806955.144081 3388504 cuda_dnn.cc:529] Loaded cuDNN version 90400

But at least I got res correctly:

julia> res = @jit x_ra * y_ra
2025-03-12 20:16:24.968321: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
8×8 ConcretePJRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
 1380.0  3236.0  5092.0  6948.0   8804.0  10660.0  12516.0  14372.0
 1416.0  3336.0  5256.0  7176.0   9096.0  11016.0  12936.0  14856.0
 1452.0  3436.0  5420.0  7404.0   9388.0  11372.0  13356.0  15340.0
 1488.0  3536.0  5584.0  7632.0   9680.0  11728.0  13776.0  15824.0
 1524.0  3636.0  5748.0  7860.0   9972.0  12084.0  14196.0  16308.0
 1560.0  3736.0  5912.0  8088.0  10264.0  12440.0  14616.0  16792.0
 1596.0  3836.0  6076.0  8316.0  10556.0  12796.0  15036.0  17276.0
 1632.0  3936.0  6240.0  8544.0  10848.0  13152.0  15456.0  17760.0

@avik-pal
Collaborator Author

What do the env vars:

_PMIX_SERVER_URI = (
    "PMIX_SERVER_URI2",
    "PMIX_SERVER_URI3",
    "PMIX_SERVER_URI4",
    "PMIX_SERVER_URI41",
    "PMIX_SERVER_URI21",
)

look like? It is possible I screwed up the parsing for those.

@giordano
Member

julia> get.((ENV,), _PMIX_SERVER_URI, nothing)
("pmix-server.945570;tcp4://127.0.0.1:57215", "pmix-server.945570;tcp4://127.0.0.1:57215", nothing, nothing, "pmix-server.945570;tcp4://127.0.0.1:57215")

x_ra = Reactant.to_rarray(x; sharding)
y_ra = Reactant.to_rarray(y; sharding)

res = @jit x_ra * y_ra
Member

Maybe add a test like `@assert isapprox(res, x * y)` or something like that?

Collaborator Author

That will fail for now, since we don't all-gather data across the processes, so only the data from addressable devices is accessible:

Example output from running on 2 nodes

┌ Warning: Not all devices are addressable. Currently we only fill in the data for addressable devices. Remaining slices of data in `data` are left untouched.
└ @ Reactant.XLA.IFRT ~/reactant/Reactant.jl/src/xla/IFRT/Array.jl:156
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
 1524.0  3636.0  5748.0  7860.0   9972.0  12084.0  14196.0  16308.0
 1560.0  3736.0  5912.0  8088.0  10264.0  12440.0  14616.0  16792.0
 1596.0  3836.0  6076.0  8316.0  10556.0  12796.0  15036.0  17276.0
 1632.0  3936.0  6240.0  8544.0  10848.0  13152.0  15456.0  17760.0
┌ Warning: Not all devices are addressable. Currently we only fill in the data for addressable devices. Remaining slices of data in `data` are left untouched.
└ @ Reactant.XLA.IFRT ~/reactant/Reactant.jl/src/xla/IFRT/Array.jl:156
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
 1380.0  3236.0  5092.0  6948.0  8804.0  10660.0  12516.0  14372.0
 1416.0  3336.0  5256.0  7176.0  9096.0  11016.0  12936.0  14856.0
 1452.0  3436.0  5420.0  7404.0  9388.0  11372.0  13356.0  15340.0
 1488.0  3536.0  5584.0  7632.0  9680.0  11728.0  13776.0  15824.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0

I am trying to figure out a nicer way to all-gather the data.

avik-pal marked this pull request as draft on March 12, 2025 20:18
@avik-pal
Collaborator Author

Marking it as draft for now. I want to finish testing on Leonardo before merging.

@avik-pal
Collaborator Author

Till I get access to Leonardo, @giordano can you share the following env vars?

OMPI_VERSION=5.0.0rc10
OMPI_TOOL_NAME=mpirun
PRTE_LAUNCHED=1
PMIX_NAMESPACE=prterun-%{hostname}-%{num_job_id}@1
PMIX_RANK=0
PMIX_SERVER_URI41=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI4=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI3=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI2=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI21=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_TMPDIR=/tmp/prte.%{hostname}.%{uid}/dvm.%{num_job_id}
PMIX_SYSTEM_TMPDIR=/tmp
OMPI_COMM_WORLD_SIZE=1
OMPI_WORLD_SIZE=1
OMPI_MCA_num_procs=1
OMPI_COMM_WORLD_RANK=0
OMPI_COMM_WORLD_LOCAL_RANK=0
OMPI_COMM_WORLD_NODE_RANK=0
PMIX_HOSTNAME=%{hostname}

xref jax-ml/jax#14576

There might be a mismatch between the version of PMIx I implemented the parser for and the one on Leonardo.

@avik-pal
Collaborator Author

With the latest round of fixes, older PMIx versions should be supported; if they are not, we will print all the relevant env vars needed to implement support for that version.
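A minimal sketch of what that diagnostic dump could look like from Julia (illustrative only; it mirrors the env | grep command shown further down):

# Illustrative: print every PMIx/Open MPI/PRRTE-related environment variable so an
# unsupported PMIx version can be reported and then supported.
for (k, v) in ENV
    if startswith(k, "PMIX_") || startswith(k, "OMPI_") || startswith(k, "PRTE_")
        println(k, "=", v)
    end
end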

@giordano
Member

$ env | grep -E '^(OMPI|PRTE|PMIX)_'
PMIX_SYSTEM_TMPDIR=/tmp
PMIX_RANK=0
PMIX_SERVER_TMPDIR=/var/spool/slurmd/pmix.13836604.0/
PMIX_PTL_MODULE=tcp,usock
PMIX_HOSTNAME=lrdn2985.leonardo.local
PMIX_BFROP_BUFFER_TYPE=PMIX_BFROP_BUFFER_NON_DESC
PMIX_DSTORE_ESH_BASE_PATH=/var/spool/slurmd/pmix.13836604.0//pmix_dstor_ds12_1564944
PMIX_SERVER_URI3=pmix-server.1564944;tcp4://127.0.0.1:57439
PMIX_SERVER_URI2=pmix-server.1564944;tcp4://127.0.0.1:57439
PMIX_VERSION=3.1.5
PMIX_SERVER_URI21=pmix-server.1564944;tcp4://127.0.0.1:57439
PMIX_DSTORE_21_BASE_PATH=/var/spool/slurmd/pmix.13836604.0//pmix_dstor_ds21_1564944
PMIX_SECURITY_MODE=native
PMIX_NAMESPACE=slurm.pmix.13836604.0
PMIX_GDS_MODULE=ds21,ds12,hash

@avik-pal
Collaborator Author

Great. v2, v3, and v5 are currently supported. I couldn't find examples for v4, so I am throwing an error in that case.

@giordano
Member

giordano commented Mar 12, 2025

Is it expected that `Reactant.Distributed.initialize(; single_gpu_per_process=false)` takes a lot of time? It has been running for me for 3 minutes already; I thought it'd be almost immediate. It's hanging at

julia> Reactant.Distributed.initialize(; single_gpu_per_process=false)
2025-03-12 23:40:37.117640: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:38] gRPC insecure server credentials are used.
2025-03-12 23:40:37.117742: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:637] Initializing CoordinationService
2025-03-12 23:40:37.121176: I external/xla/xla/pjrt/distributed/service.cc:75] Coordination service is enabled.
E0312 23:40:37.122932534  775807 socket_utils_common_posix.cc:327] setsockopt(TCP_USER_TIMEOUT) Protocol not available
2025-03-12 23:40:37.123030: I external/xla/xla/pjrt/distributed/service.cc:105] Jax service listening on [::]:62965
2025-03-12 23:40:37.123134: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.

@avik-pal
Collaborator Author

avik-pal commented Mar 12, 2025 via email

@giordano
Member

is there any warning being displayed for the proxy variables?

No, and I had checked earlier they aren't set at all.

@avik-pal
Collaborator Author

oof okay... let me wait till the access is confirmed and then deal with that

@avik-pal
Collaborator Author

avik-pal commented Mar 13, 2025

Is it expected that `Reactant.Distributed.initialize(; single_gpu_per_process=false)` takes a lot of time? It has been running for me for 3 minutes already; I thought it'd be almost immediate. It's hanging at

julia> Reactant.Distributed.initialize(; single_gpu_per_process=false)
2025-03-12 23:40:37.117640: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:38] gRPC insecure server credentials are used.
2025-03-12 23:40:37.117742: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:637] Initializing CoordinationService
2025-03-12 23:40:37.121176: I external/xla/xla/pjrt/distributed/service.cc:75] Coordination service is enabled.
E0312 23:40:37.122932534  775807 socket_utils_common_posix.cc:327] setsockopt(TCP_USER_TIMEOUT) Protocol not available
2025-03-12 23:40:37.123030: I external/xla/xla/pjrt/distributed/service.cc:105] Jax service listening on [::]:62965
2025-03-12 23:40:37.123134: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.

Oh wait, running in the REPL won't work. It is probably waiting indefinitely for the other processes to connect. Can you try running via slurm/mpirun? I will add a check for interactive sessions and warn accordingly.

@avik-pal
Collaborator Author

Add a check that will warn if users are calling Distributed.initialize() in an interactive mode
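A minimal sketch of such a check (the exact message and where it hooks into initialize are illustrative):

# Illustrative: warn when initialize() is called from the REPL or another interactive
# session, since it would otherwise block waiting for the remaining processes to connect.
function warn_if_interactive()
    Base.isinteractive() || return nothing
    @warn "Distributed.initialize() called from an interactive session; it will block until all other processes connect. Launch the script via srun/mpirun instead."
    return nothing
end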

@giordano
Member

Ok, thanks for the clarification. Some progress:

ERROR: LoadError: AssertionError: `PJRT.MakeCPUClient` does not support num_nodes > 1
Stacktrace:
ERROR: LoadError: AssertionError: `PJRT.MakeCPUClient` does not support num_nodes > 1
Stacktrace:
 [1] MakeCPUClient(; node_id::Int64, num_nodes::Int64, asynchronous::Bool, distributed_runtime_client::Reactant.XLA.DistributedRuntimeClient)
   @ Reactant.XLA.PJRT ~/.julia/packages/Reactant/k1KX2/src/xla/PJRT/Client.jl:127
 [1] MakeCPUClient(; node_id::Int64, num_nodes::Int64, asynchronous::Bool, distributed_runtime_client::Reactant.XLA.DistributedRuntimeClient)
   @ Reactant.XLA.PJRT ~/.julia/packages/Reactant/k1KX2/src/xla/PJRT/Client.jl:127
 [2] CPUClient(; checkcount::Bool, kwargs::@Kwargs{node_id::Int64, num_nodes::Int64, distributed_runtime_client::Reactant.XLA.DistributedRuntimeClient, asynchronous::Bool})
   @ Reactant.XLA.PJRT ~/.julia/packages/Reactant/k1KX2/src/xla/PJRT/Client.jl:111
 [3] initialize_default_clients!(state::Reactant.XLA.PJRTBackendState)
   @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:182
 [4] getproperty
   @ ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:56 [inlined]
 [5] default_backend
   @ ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:77 [inlined]
 [6] devices()
   @ Reactant ~/.julia/packages/Reactant/k1KX2/src/Devices.jl:9
 [7] top-level scope
   @ ~/tmp/test_matmul_sharded.jl:7
in expression starting at /leonardo/home/userexternal/mgiordan/tmp/test_matmul_sharded.jl:7
 [2] CPUClient(; checkcount::Bool, kwargs::@Kwargs{node_id::Int64, num_nodes::Int64, distributed_runtime_client::Reactant.XLA.DistributedRuntimeClient, asynchronous::Bool})
   @ Reactant.XLA.PJRT ~/.julia/packages/Reactant/k1KX2/src/xla/PJRT/Client.jl:111
 [3] initialize_default_clients!(state::Reactant.XLA.PJRTBackendState)
   @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:182
 [4] getproperty
   @ ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:56 [inlined]
 [5] default_backend
   @ ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:77 [inlined]
 [6] devices()
   @ Reactant ~/.julia/packages/Reactant/k1KX2/src/Devices.jl:9
 [7] top-level scope
   @ ~/tmp/test_matmul_sharded.jl:7
in expression starting at /leonardo/home/userexternal/mgiordan/tmp/test_matmul_sharded.jl:7
I0000 00:00:1741832775.794173 3178587 coordination_service_agent.cc:619] Coordination agent has initiated Shutdown().
I0000 00:00:1741832775.798749  284890 coordination_service_agent.cc:619] Coordination agent has initiated Shutdown().
2025-03-13 03:26:15.799395: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:2113] Shutdown barrier in coordination service has passed.
I0000 00:00:1741832775.799423  284890 coordination_service_agent.cc:640] Coordination agent has successfully shut down.
I0000 00:00:1741832775.799501 3178587 coordination_service_agent.cc:640] Coordination agent has successfully shut down.
I0000 00:00:1741832775.800352  284992 coordination_service_agent.cc:449] Cancelling error polling because the service or the agent is shutting down.
I0000 00:00:1741832775.801138 3178682 coordination_service_agent.cc:449] Cancelling error polling because the service or the agent is shutting down.
2025-03-13 03:26:15.801635: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:1069] /job:jax_worker/replica:0/task:0 has disconnected from coordination service.
2025-03-13 03:26:15.801657: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:1069] /job:jax_worker/replica:0/task:1 has disconnected from coordination service.
srun: error: lrdn0283: task 1: Exited with exit code 1
srun: error: lrdn0272: task 0: Exited with exit code 1

@avik-pal
Collaborator Author

Add a Preferences file (LocalPreferences.toml) with:

[Reactant]
xla_runtime = "IFRT"

@giordano
Member

Ok, that worked, thanks. stderr:

┌ Debug: REACTANT_XLA_RUNTIME: 
│   REACTANT_XLA_RUNTIME = "IFRT"
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:138
┌ Debug: REACTANT_XLA_RUNTIME: 
│   REACTANT_XLA_RUNTIME = "IFRT"
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/XLA.jl:138
[ Info: initialising....
[ Info: initialising....
┌ Debug: Detected cluster environment
│   detector = Reactant.Distributed.SlurmEnvDetector()
└ @ Reactant.Distributed ~/.julia/packages/Reactant/k1KX2/src/Distributed.jl:89
┌ Debug: Detected Reactant distributed params
│   coordinator_address = "lrdn0634:63058"
│   num_processes = 2
│   process_id = 1
│   local_gpu_device_ids = nothing
└ @ Reactant.Distributed ~/.julia/packages/Reactant/k1KX2/src/Distributed.jl:27
┌ Debug: Detected cluster environment
│   detector = Reactant.Distributed.SlurmEnvDetector()
└ @ Reactant.Distributed ~/.julia/packages/Reactant/k1KX2/src/Distributed.jl:89
┌ Debug: Detected Reactant distributed params
│   coordinator_address = "lrdn0634:63058"
│   num_processes = 2
│   process_id = 0
│   local_gpu_device_ids = nothing
└ @ Reactant.Distributed ~/.julia/packages/Reactant/k1KX2/src/Distributed.jl:27
2025-03-13 03:33:57.388328: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.
┌ Debug: [PID 1] Connecting to Reactant distributed service on lrdn0634:63058
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/Distributed.jl:195
┌ Debug: [PID 0] Starting Reactant distributed service on [::]:63058
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/Distributed.jl:163
2025-03-13 03:33:57.505302: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:38] gRPC insecure server credentials are used.
2025-03-13 03:33:57.506365: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:637] Initializing CoordinationService
2025-03-13 03:33:57.510050: I external/xla/xla/pjrt/distributed/service.cc:75] Coordination service is enabled.
E0313 03:33:57.512232932 1652112 socket_utils_common_posix.cc:327] setsockopt(TCP_USER_TIMEOUT) Protocol not available
2025-03-13 03:33:57.512333: I external/xla/xla/pjrt/distributed/service.cc:105] Jax service listening on [::]:63058
2025-03-13 03:33:57.513377: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.
┌ Debug: [PID 0] Connecting to Reactant distributed service on lrdn0634:63058
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/Distributed.jl:195
2025-03-13 03:33:58.401411: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:378] /job:jax_worker/replica:0/task:1 has connected to coordination service. Incarnation: 2398013051362712241
2025-03-13 03:33:58.401463: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:378] /job:jax_worker/replica:0/task:0 has connected to coordination service. Incarnation: 6528365760427963783
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741833238.401672 1409208 coordination_service_agent.cc:369] Coordination agent has successfully connected.
I0000 00:00:1741833238.402252 1409208 coordination_service_agent.cc:442] Polling for error from coordination service. This is a long-running RPC that will return only if an error is encountered or cancelled (e.g. due to shutdown).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741833238.401813 1652112 coordination_service_agent.cc:369] Coordination agent has successfully connected.
2025-03-13 03:33:58.402740: I external/xla/xla/pjrt/distributed/client.cc:121] Connected to distributed JAX controller
I0000 00:00:1741833238.402542 1652112 coordination_service_agent.cc:442] Polling for error from coordination service. This is a long-running RPC that will return only if an error is encountered or cancelled (e.g. due to shutdown).
┌ Debug: [PID 1] Connected to Reactant distributed service on lrdn0634:63058
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/Distributed.jl:198
2025-03-13 03:33:58.402866: I external/xla/xla/pjrt/distributed/client.cc:121] Connected to distributed JAX controller
┌ Debug: [PID 0] Connected to Reactant distributed service on lrdn0634:63058
└ @ Reactant.XLA ~/.julia/packages/Reactant/k1KX2/src/xla/Distributed.jl:198
┌ Debug: New Global State
│   Reactant.XLA.global_state = Reactant.XLA.State(1, 2, nothing, nothing, Reactant.XLA.DistributedRuntimeClient(Ptr{Nothing} @0x0000000001c2ed30), "lrdn0634:63058", "[::]:63058")
└ @ Reactant.Distributed ~/.julia/packages/Reactant/k1KX2/src/Distributed.jl:33
[ Info: initialised
┌ Debug: New Global State
│   Reactant.XLA.global_state = Reactant.XLA.State(0, 2, nothing, Reactant.XLA.DistributedRuntimeService(Ptr{Nothing} @0x0000000003334710), Reactant.XLA.DistributedRuntimeClient(Ptr{Nothing} @0x0000000003116e30), "lrdn0634:63058", "[::]:63058")
└ @ Reactant.Distributed ~/.julia/packages/Reactant/k1KX2/src/Distributed.jl:33
[ Info: initialised
I0000 00:00:1741833238.512944 1652112 pjrt_client.cc:524] PjRt-IFRT device count: total=2, addressable=1
I0000 00:00:1741833238.512962 1652112 pjrt_client.cc:528] Addressable PjRt-IFRT device: CpuDevice(id=0)
I0000 00:00:1741833238.513059 1409208 pjrt_client.cc:524] PjRt-IFRT device count: total=2, addressable=1
I0000 00:00:1741833238.513785 1409208 pjrt_client.cc:528] Addressable PjRt-IFRT device: CpuDevice(id=131072)
2025-03-13 03:33:58.889873: I external/xla/xla/service/service.cc:152] XLA service 0x33292c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-13 03:33:58.890141: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-13 03:33:58.890144: I external/xla/xla/service/service.cc:160]   StreamExecutor device (1): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-13 03:33:58.890147: I external/xla/xla/service/service.cc:160]   StreamExecutor device (2): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-13 03:33:58.890148: I external/xla/xla/service/service.cc:160]   StreamExecutor device (3): NVIDIA A100-SXM-64GB, Compute Capability 8.0
I0000 00:00:1741833238.893159 1652112 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741833238.893459 1652112 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 0 for BFCAllocator.
I0000 00:00:1741833238.893733 1652112 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 1 for BFCAllocator.
I0000 00:00:1741833238.893745 1652112 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 2 for BFCAllocator.
I0000 00:00:1741833238.893754 1652112 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 3 for BFCAllocator.
I0000 00:00:1741833238.893764 1652112 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741833238.893773 1652112 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1741833238.893781 1652112 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1741833238.893790 1652112 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 3 for CollectiveBFCAllocator.
2025-03-13 03:33:58.896516: I external/xla/xla/service/service.cc:152] XLA service 0x3d85340 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-13 03:33:58.896777: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-13 03:33:58.896780: I external/xla/xla/service/service.cc:160]   StreamExecutor device (1): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-13 03:33:58.896782: I external/xla/xla/service/service.cc:160]   StreamExecutor device (2): NVIDIA A100-SXM-64GB, Compute Capability 8.0
2025-03-13 03:33:58.896784: I external/xla/xla/service/service.cc:160]   StreamExecutor device (3): NVIDIA A100-SXM-64GB, Compute Capability 8.0
I0000 00:00:1741833238.899774 1409208 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741833238.900059 1409208 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 0 for BFCAllocator.
I0000 00:00:1741833238.900532 1409208 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 1 for BFCAllocator.
I0000 00:00:1741833238.900544 1409208 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 2 for BFCAllocator.
I0000 00:00:1741833238.900555 1409208 gpu_helpers.cc:136] XLA backend allocating 51074678784 bytes on device 3 for BFCAllocator.
I0000 00:00:1741833238.900564 1409208 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741833238.900574 1409208 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1741833238.900583 1409208 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1741833238.900591 1409208 gpu_helpers.cc:177] XLA backend will use up to 17024892928 bytes on device 3 for CollectiveBFCAllocator.
I0000 00:00:1741833239.085642 1652112 cuda_dnn.cc:529] Loaded cuDNN version 90400
I0000 00:00:1741833239.088441 1409208 cuda_dnn.cc:529] Loaded cuDNN version 90400
I0000 00:00:1741833239.092438 1652112 pjrt_client.cc:524] PjRt-IFRT device count: total=8, addressable=4
I0000 00:00:1741833239.092452 1652112 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=0)
I0000 00:00:1741833239.092455 1652112 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=1)
I0000 00:00:1741833239.092456 1652112 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=2)
I0000 00:00:1741833239.092457 1652112 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=3)
I0000 00:00:1741833239.092570 1409208 pjrt_client.cc:524] PjRt-IFRT device count: total=8, addressable=4
I0000 00:00:1741833239.092580 1409208 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=4)
I0000 00:00:1741833239.092582 1409208 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=5)
I0000 00:00:1741833239.092584 1409208 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=6)
I0000 00:00:1741833239.092585 1409208 pjrt_client.cc:528] Addressable PjRt-IFRT device: CudaDevice(id=7)
[ Info: JIT compiling and running....
[ Info: JIT compiling and running....
2025-03-13 03:34:21.981289: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
2025-03-13 03:34:22.075672: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
[ Info: done
[ Info: done

stdout:

length(Reactant.devices()) = 8
length(Reactant.devices()) = 8
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
8×8 ConcreteIFRTArray{Float32,2} with sharding Reactant.Sharding.HloSharding{2, 2}:
 1380.0  3236.0  5092.0  6948.0  8804.0  10660.0  12516.0  14372.0
 1416.0  3336.0  5256.0  7176.0  9096.0  11016.0  12936.0  14856.0
 1452.0  3436.0  5420.0  7404.0  9388.0  11372.0  13356.0  15340.0
 1488.0  3536.0  5584.0  7632.0  9680.0  11728.0  13776.0  15824.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0     0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
    0.0     0.0     0.0     0.0      0.0      0.0      0.0      0.0
 1524.0  3636.0  5748.0  7860.0   9972.0  12084.0  14196.0  16308.0
 1560.0  3736.0  5912.0  8088.0  10264.0  12440.0  14616.0  16792.0
 1596.0  3836.0  6076.0  8316.0  10556.0  12796.0  15036.0  17276.0
 1632.0  3936.0  6240.0  8544.0  10848.0  13152.0  15456.0  17760.0

I presume one rank has `res` set to all zeros, and only the main one has the full matrix?

@avik-pal
Collaborator Author

I presume one rank has `res` set to all zeros, and only the main one has the full matrix?

No, the first one has the top 4 rows and the other one has the bottom 4 rows.

But this is good, it confirms that our setup works on Leonardo as well 🎉. Let's wait for the tests and then we should be good to go here.

avik-pal marked this pull request as ready for review on March 13, 2025 02:41
avik-pal merged commit dcbf3f9 into main on Mar 13, 2025
55 of 56 checks passed
avik-pal deleted the ap/slurm_runner branch on March 13, 2025 03:29