@@ -8,23 +8,42 @@ install_cuda_aarch64() {
88 # CUDA_MAJOR_VERSION: cu128 --> 12
99 CUDA_MAJOR_VERSION=${CU_VERSION: 2: 2}
1010 dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
11- # nccl version must match libtorch_cuda.so was built with https://github.com/pytorch/pytorch/blob/main/.ci/docker/ci_commit_pins/nccl-cu12.txt
12- dnf -y install cuda-compiler-${CU_VER} .aarch64 \
11+
12+ # nccl version must match the one libtorch_cuda.so was built with
13+ if [[ ${CU_VERSION: 0: 4} == " cu12" ]]; then
14+ # cu12: https://github.com/pytorch/pytorch/blob/main/.ci/docker/ci_commit_pins/nccl-cu12.txt
15+ if [[ ${CU_VERSION} == " cu128" ]]; then
16+ nccl_version=" 2.26.2-1"
17+ elif [[ ${CU_VERSION} == " cu126" ]]; then
18+ nccl_version=" 2.24.3-1"
19+ else
20+ # cu129 support has been removed from upstream pytorch
21+ echo " Unsupported CUDA version: ${CU_VERSION} "
22+ exit 1
23+ fi
24+ elif [[ ${CU_VERSION: 0: 4} == " cu13" ]]; then
25+ # cu13: https://github.com/pytorch/pytorch/blob/main/.ci/docker/ci_commit_pins/nccl-cu13.txt
26+ nccl_version=" 2.27.7-1"
27+ fi
28+
29+ dnf --nogpgcheck -y install cuda-compiler-${CU_VER} .aarch64 \
1330 cuda-libraries-${CU_VER} .aarch64 \
1431 cuda-libraries-devel-${CU_VER} .aarch64 \
15- libnccl-2.27.3-1 +cuda${CU_DOT_VER} libnccl-devel-2.27.3-1 +cuda${CU_DOT_VER} libnccl-static-2.27.3-1 +cuda${CU_DOT_VER}
32+ libnccl-${nccl_version} +cuda${CU_DOT_VER} libnccl-devel-${nccl_version} +cuda${CU_DOT_VER} libnccl-static-${nccl_version} +cuda${CU_DOT_VER}
1633 dnf clean all
17-
18- nvshmem_version=3.3.9
34+ # nvshmem version is from https://github.com/pytorch/pytorch/blob/f9fa138a3910bd1de1e7acb95265fa040672a952/.ci/docker/common/install_cuda.sh#L67
35+ nvshmem_version=3.3.24
1936 nvshmem_path=" https://developer.download.nvidia.com/compute/redist/nvshmem/${nvshmem_version} /builds/cuda${CUDA_MAJOR_VERSION} /txz/agnostic/aarch64"
20- nvshmem_filename=" libnvshmem_cuda12-linux-sbsa-${nvshmem_version} .tar.gz"
21- curl -L ${nvshmem_path} /${nvshmem_filename} -o nvshmem.tar.gz
22- tar -xzf nvshmem.tar.gz
23- cp -a libnvshmem/lib/* /usr/local/cuda/lib64/
24- cp -a libnvshmem/include/* /usr/local/cuda/include/
25- rm -rf nvshmem.tar.gz nvshmem
37+ nvshmem_prefix=" libnvshmem-linux-sbsa-${nvshmem_version} _cuda${CUDA_MAJOR_VERSION} -archive"
38+ nvshmem_tarname=" ${nvshmem_prefix} .tar.xz"
39+ curl -L ${nvshmem_path} /${nvshmem_tarname} -o nvshmem.tar.xz
40+ tar -xJf nvshmem.tar.xz
41+ cp -a ${nvshmem_prefix} /lib/* /usr/local/cuda/lib64/
42+ cp -a ${nvshmem_prefix} /include/* /usr/local/cuda/include/
43+ rm -rf nvshmem.tar.xz ${nvshmem_prefix}
2644 echo " nvshmem ${nvshmem_version} for cuda ${CUDA_MAJOR_VERSION} installed successfully"
2745
46+ export PATH=/usr/local/cuda/bin:$PATH
2847 export LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/include:/usr/lib64:$LD_LIBRARY_PATH
2948 ls -lart /usr/local/
3049 nvcc --version
0 commit comments