# Patch summary (commentary only; everything from the first "diff --git" onward is the unmodified patch text).
#
# Files touched: Dockerfile.rocm and cmake/utils.cmake.
#
# Dockerfile.rocm
#   - Replaces the two comment lines that listed the ROCm 5.7 / ROCm 6.0 base image names with an
#     ARG instruction declaring ROCm_5_7_BASE and ROCm_6_0_BASE with those same image names as defaults.
#   - The two `[ "$BASE_IMAGE" = ... ]` tests (hipify patch step, numpy dist-info removal) now compare
#     against $ROCm_5_7_BASE / $ROCm_6_0_BASE instead of repeating the literal image names.
#   - The amd_hip_bf16.h patch step, previously unconditional, now runs only when
#     $BASE_IMAGE equals $ROCm_6_0_BASE.
#   - The three `cp build/lib.linux-x86_64-cpython-39/vllm/<name>.abi3.so vllm/` commands are replaced by
#     one `cp .../vllm/*.so vllm/` whose directory suffix comes from VLLM_PYTHON_VERSION, computed as
#     str(sys.version_info.major) + str(sys.version_info.minor).
#
# cmake/utils.cmake (hunk header places it in macro override_gpu_arches)
#   - Adds set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100").
#   - The foreach that fills ${GPU_ARCHES} now iterates VLLM_ROCM_SUPPORTED_ARCHS instead of
#     CMAKE_HIP_ARCHITECTURES; entries are still kept only if IN_LIST _GPU_SUPPORTED_ARCHES_LIST.
#
# NOTE(review): this copy has its line breaks collapsed (e.g. the first hunk header declares 9 old /
#   8 new lines, yet the hunk body sits on a single physical line) -- presumably the newlines must be
#   restored before the patch can be applied; confirm against the original commit.
# NOTE(review): the unchanged context comment in the cmake hunk still says "supported + detected
#   architectures", but the new loop no longer reads CMAKE_HIP_ARCHITECTURES -- confirm that ignoring
#   the detected list is intended and whether that comment should be updated.
# NOTE(review): VLLM_PYTHON_VERSION is computed with `python -c` while the neighbouring build step
#   calls `python3 setup.py install` -- confirm both names resolve to the same interpreter in the image.
# NOTE(review): if the set() really is inside a macro() body, VLLM_ROCM_SUPPORTED_ARCHS will also be
#   visible in the caller's scope after the macro runs -- confirm that is acceptable.
#
diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 724fa1673c3b3..6bda696859c8b 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -7,9 +7,8 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" RUN echo "Base image is $BASE_IMAGE" -# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" -# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - +ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ + ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ARG FA_GFX_ARCHS="gfx90a;gfx942" RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" @@ -68,7 +67,7 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ && git checkout ${FA_BRANCH} \ && git submodule update --init \ && export GPU_ARCHS=${FA_GFX_ARCHS} \ - && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ + && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \ patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ && python3 setup.py install \ && cd ..; \ @@ -76,7 +75,7 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. 
# Manually removed it so that later steps of numpy upgrade can continue -RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ +RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi # build triton @@ -107,11 +106,11 @@ ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ - && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ + && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \ + patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \ && python3 setup.py install \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \ + && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \ + && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \ && cd .. diff --git a/cmake/utils.cmake b/cmake/utils.cmake index f3c1286dd8498..071e16336dfa2 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -155,8 +155,11 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) # Find the intersection of the supported + detected architectures to # set the module architecture flags. # + + set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") + + set(${GPU_ARCHES}) - foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES}) + foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS}) if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) list(APPEND ${GPU_ARCHES} ${_ARCH}) endif()