@@ -7,9 +7,8 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
7
7
8
8
RUN echo "Base image is $BASE_IMAGE"
9
9
10
- # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
11
- # BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
12
-
10
+ ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
11
+ ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
13
12
14
13
ARG FA_GFX_ARCHS="gfx90a;gfx942"
15
14
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
@@ -68,15 +67,15 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
68
67
&& git checkout ${FA_BRANCH} \
69
68
&& git submodule update --init \
70
69
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
71
- && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 " ]; then \
70
+ && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE " ]; then \
72
71
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
73
72
&& python3 setup.py install \
74
73
&& cd ..; \
75
74
fi
76
75
77
76
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
78
77
# Manually removed it so that later steps of numpy upgrade can continue
79
- RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 " ]; then \
78
+ RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE " ]; then \
80
79
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
81
80
82
81
# build triton
@@ -107,11 +106,11 @@ ENV CCACHE_DIR=/root/.cache/ccache
107
106
RUN --mount=type=cache,target=/root/.cache/ccache \
108
107
--mount=type=cache,target=/root/.cache/pip \
109
108
pip install -U -r requirements-rocm.txt \
110
- && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
109
+ && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
110
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
111
111
&& python3 setup.py install \
112
- && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
113
- && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
114
- && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
112
+ && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
113
+ && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
115
114
&& cd ..
116
115
117
116
0 commit comments