# syntax=docker/dockerfile:1

# Default ROCm 6.1 base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# Tested and supported base rocm/pytorch images
ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
    ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
    ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
ARG BUILD_FA="1"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="ae7928c"

# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848"
### Base image build stage
FROM $BASE_IMAGE AS base

# Import arg(s) defined before this build stage
# (ARGs declared before FROM are only visible to FROM lines unless redeclared)
ARG PYTORCH_ROCM_ARCH

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
32
33
RUN apt-get update && apt-get install -y \
33
34
curl \
34
35
ca-certificates \
@@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
39
40
build-essential \
40
41
wget \
41
42
unzip \
42
- nvidia-cuda-toolkit \
43
43
tmux \
44
44
ccache \
45
45
&& rm -rf /var/lib/apt/lists/*
46
46
# When launching the container, mount the code directory to /vllm-workspace
ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT}

RUN pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.4.0 on ROCm: pick the nightly index matching the
# ROCm version detected from the /opt/rocm-X.Y install directory.
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-5.7"*) \
            pip uninstall -y torch \
            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
                --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
        *"rocm-6.0"*) \
            pip uninstall -y torch \
            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
                --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
        *"rocm-6.1"*) \
            pip uninstall -y torch \
            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
                --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
        *) ;; esac
55
70
56
71
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV CCACHE_DIR=/root/.cache/ccache


### AMD-SMI build stage
FROM base AS build_amdsmi
# Build amdsmi wheel always; use WORKDIR instead of `RUN cd` (hadolint DL3003).
# This stage only exists to produce /install, so changing its workdir is safe.
WORKDIR /opt/rocm/share/amd_smi
RUN pip wheel . --wheel-dir=/install
85
+
### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
# NOTE: the checkout ref is quoted WITHOUT a leading space —
# `git checkout " ${FA_BRANCH}"` would try to resolve the ref " ae7928c" and fail.
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_FA" = "1" ]; then \
        mkdir -p libs \
        && cd libs \
        && git clone https://github.com/ROCm/flash-attention.git \
        && cd flash-attention \
        && git checkout "${FA_BRANCH}" \
        && git submodule update --init \
        && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
            *"rocm-5.7"*) \
                export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
                && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
            *) ;; esac \
        && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi
75
### Triton wheel build stage
FROM base AS build_triton
ARG BUILD_TRITON
ARG TRITON_BRANCH
# Build triton wheel if `BUILD_TRITON = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_TRITON" = "1" ]; then \
        mkdir -p libs \
        && cd libs \
        && git clone https://github.com/OpenAI/triton.git \
        && cd triton \
        && git checkout "${TRITON_BRANCH}" \
        && cd python \
        && python3 setup.py bdist_wheel --dist-dir=/install; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi
91
129
### Final vLLM build stage
FROM base AS final
# Import the vLLM development directory from the build context
COPY . .

# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
        *"/opt/conda/envs/py_3.9"*) \
            rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
        *) ;; esac

# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --upgrade numba scipy huggingface-hub[cli]
97
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
ENV TOKENIZERS_PARALLELISM=false
102
# Build and install vLLM itself (editable install via `setup.py develop`),
# applying the ROCm-version-specific patches first.
RUN --mount=type=cache,target=${CCACHE_DIR} \
    --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-6.0"*) \
            patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
        *"rocm-6.1"*) \
            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
            wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
            && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
            # Prevent interference if torch bundles its own HIP runtime
            && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
        *) ;; esac \
    && python3 setup.py clean --all \
    && python3 setup.py develop
169
+
# Copy amdsmi wheel into final image
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
    mkdir -p libs \
    && cp /install/*.whl libs \
    # Preemptively uninstall to avoid same-version no-installs
    && pip uninstall -y amdsmi;

# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
    mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
        && pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
    mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
        && pip uninstall -y flash-attn; fi

# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
    if ls libs/*.whl; then \
        pip install libs/*.whl; fi

CMD ["/bin/bash"]
0 commit comments