Commit cb40d2d

mawong-amd authored and prashantgupta24 committed
[Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes (vllm-project#5422)
1 parent 03eacd4 commit cb40d2d

File tree

15 files changed: +259 -122 lines


CMakeLists.txt

Lines changed: 6 additions & 14 deletions
@@ -32,8 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # versions are derived from Dockerfile.rocm
 #
 set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
-set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
-set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
 
 #
 # Try to find python package with an executable that exactly matches
@@ -98,18 +97,11 @@ elseif(HIP_FOUND)
   # .hip extension automatically, HIP must be enabled explicitly.
   enable_language(HIP)
 
-  # ROCm 5.x
-  if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
-      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
-      "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
-  endif()
-
-  # ROCm 6.x
-  if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
-      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
-      "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
+  # ROCm 5.X and 6.X
+  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
+      "expected for ROCm build, saw ${Torch_VERSION} instead.")
   endif()
 else()
   message(FATAL_ERROR "Can't find CUDA or HIP installation.")
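The two per-major-version checks collapse into one because ROCm 5.7, 6.0, and 6.1 builds now all target the same torch 2.4.0 nightly (Dockerfile.rocm below installs torch==2.4.0.dev20240612 for every supported ROCm). To confirm what the warning will compare against, you can print the installed torch version before configuring; a minimal sketch, with example output assuming the rocm6.1 nightly wheel is installed:

$ python3 -c 'import torch; print(torch.__version__)'
2.4.0.dev20240612+rocm6.1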

Dockerfile.rocm

Lines changed: 145 additions & 64 deletions
@@ -1,34 +1,35 @@
-# default base image
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
-FROM $BASE_IMAGE
-
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
-RUN echo "Base image is $BASE_IMAGE"
-
-ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-    ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
+# Default ROCm 6.1 base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+
+# Tested and supported base rocm/pytorch images
+ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
+    ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
+    ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+
+# Default ROCm ARCHes to build vLLM for.
+ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
+
+# Whether to build CK-based flash-attention
+# If 0, will not build flash attention
+# This is useful for gfx target where flash-attention is not supported
+# (i.e. those that do not appear in `FA_GFX_ARCHS`)
+# Triton FA is used by default on ROCm now so this is unnecessary.
+ARG BUILD_FA="1"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
-RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
-
 ARG FA_BRANCH="ae7928c"
-RUN echo "FA_BRANCH is $FA_BRANCH"
 
-# whether to build flash-attention
-# if 0, will not build flash attention
-# this is useful for gfx target where flash-attention is not supported
-# In that case, we need to use the python reference attention implementation in vllm
-ARG BUILD_FA="1"
-
-# whether to build triton on rocm
+# Whether to build triton on rocm
 ARG BUILD_TRITON="1"
+ARG TRITON_BRANCH="0ef1848"
 
-# Install some basic utilities
-RUN apt-get update && apt-get install python3 python3-pip -y
+### Base image build stage
+FROM $BASE_IMAGE AS base
+
+# Import arg(s) defined before this build stage
+ARG PYTORCH_ROCM_ARCH
 
 # Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
 RUN apt-get update && apt-get install -y \
     curl \
     ca-certificates \
@@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
     build-essential \
     wget \
     unzip \
-    nvidia-cuda-toolkit \
     tmux \
     ccache \
     && rm -rf /var/lib/apt/lists/*
 
-### Mount Point ###
-# When launching the container, mount the code directory to /app
+# When launching the container, mount the code directory to /vllm-workspace
 ARG APP_MOUNT=/vllm-workspace
-VOLUME [ ${APP_MOUNT} ]
 WORKDIR ${APP_MOUNT}
 
-RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+RUN pip install --upgrade pip
+# Remove sccache so it doesn't interfere with ccache
+# TODO: implement sccache support across components
+RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
+# Install torch == 2.4.0 on ROCm
+RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-5.7"*) \
+            pip uninstall -y torch \
+            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+                --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
+        *"rocm-6.0"*) \
+            pip uninstall -y torch \
+            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+                --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
+        *"rocm-6.1"*) \
+            pip uninstall -y torch \
+            && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+                --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
+        *) ;; esac
 
 ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
 ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
 ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
 
-# Install ROCm flash-attention
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    mkdir libs \
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+ENV CCACHE_DIR=/root/.cache/ccache
+
+
+### AMD-SMI build stage
+FROM base AS build_amdsmi
+# Build amdsmi wheel always
+RUN cd /opt/rocm/share/amd_smi \
+    && pip wheel . --wheel-dir=/install
+
+
+### Flash-Attention wheel build stage
+FROM base AS build_fa
+ARG BUILD_FA
+ARG FA_GFX_ARCHS
+ARG FA_BRANCH
+# Build ROCm flash-attention wheel if `BUILD_FA = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    if [ "$BUILD_FA" = "1" ]; then \
+    mkdir -p libs \
     && cd libs \
     && git clone https://github.com/ROCm/flash-attention.git \
     && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
+    && git checkout "${FA_BRANCH}" \
     && git submodule update --init \
-    && export GPU_ARCHS=${FA_GFX_ARCHS} \
-    && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \
-        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
-    && python3 setup.py install \
-    && cd ..; \
+    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-5.7"*) \
+            export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
+            && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
+        *) ;; esac \
+    && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+    # Create an empty directory otherwise as later build stages expect one
+    else mkdir -p /install; \
     fi
 
-# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
-# Manually removed it so that later steps of numpy upgrade can continue
-RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
-    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
 
-# build triton
-RUN if [ "$BUILD_TRITON" = "1" ]; then \
+### Triton wheel build stage
+FROM base AS build_triton
+ARG BUILD_TRITON
+ARG TRITON_BRANCH
+# Build triton wheel if `BUILD_TRITON = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    if [ "$BUILD_TRITON" = "1" ]; then \
    mkdir -p libs \
    && cd libs \
-    && pip uninstall -y triton \
-    && git clone https://github.com/ROCm/triton.git \
-    && cd triton/python \
-    && pip3 install . \
-    && cd ../..; \
+    && git clone https://github.com/OpenAI/triton.git \
+    && cd triton \
+    && git checkout "${TRITON_BRANCH}" \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=/install; \
+    # Create an empty directory otherwise as later build stages expect one
+    else mkdir -p /install; \
    fi
 
-WORKDIR /vllm-workspace
+
+### Final vLLM build stage
+FROM base AS final
+# Import the vLLM development directory from the build context
 COPY . .
 
-#RUN python3 -m pip install pynvml # to be removed eventually
-RUN python3 -m pip install --upgrade pip numba
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually remove it so that later steps of numpy upgrade can continue
+RUN case "$(which python3)" in \
+        *"/opt/conda/envs/py_3.9"*) \
+            rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
+        *) ;; esac
+
+# Package upgrades for useful functionality or to avoid dependency issues
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install --upgrade numba scipy huggingface-hub[cli]
 
-# make sure punica kernels are built (for LoRA)
+# Make sure punica kernels are built (for LoRA)
 ENV VLLM_INSTALL_PUNICA_KERNELS=1
 # Workaround for ray >= 2.10.0
 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+# Silences the HF Tokenizers warning
+ENV TOKENIZERS_PARALLELISM=false
 
-ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
-
-ENV CCACHE_DIR=/root/.cache/ccache
-RUN --mount=type=cache,target=/root/.cache/ccache \
+RUN --mount=type=cache,target=${CCACHE_DIR} \
     --mount=type=cache,target=/root/.cache/pip \
     pip install -U -r requirements-rocm.txt \
-    && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
-        patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
-    && python3 setup.py install \
-    && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
-    && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
-    && cd ..
+    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-6.0"*) \
+            patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
+        *"rocm-6.1"*) \
+            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
+            wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
+            && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
+            # Prevent interference if torch bundles its own HIP runtime
+            && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
+        *) ;; esac \
+    && python3 setup.py clean --all \
+    && python3 setup.py develop
+
+# Copy amdsmi wheel into final image
+RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
+    mkdir -p libs \
+    && cp /install/*.whl libs \
+    # Preemptively uninstall to avoid same-version no-installs
+    && pip uninstall -y amdsmi;
 
+# Copy triton wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
+    mkdir -p libs \
+    && if ls /install/*.whl; then \
+        cp /install/*.whl libs \
+        # Preemptively uninstall to avoid same-version no-installs
+        && pip uninstall -y triton; fi
+
+# Copy flash-attn wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
+    mkdir -p libs \
+    && if ls /install/*.whl; then \
+        cp /install/*.whl libs \
+        # Preemptively uninstall to avoid same-version no-installs
+        && pip uninstall -y flash-attn; fi
+
+# Install wheels that were built to the final image
+RUN --mount=type=cache,target=/root/.cache/pip \
+    if ls libs/*.whl; then \
+        pip install libs/*.whl; fi
 
 CMD ["/bin/bash"]
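The new cache mounts (RUN --mount=type=cache,...) require BuildKit, so the image should be built with DOCKER_BUILDKIT=1 or docker buildx. A hedged usage sketch of the new build arguments; the image tag and run flags are illustrative, not part of this commit:

$ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm \
      --build-arg BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" \
      --build-arg PYTORCH_ROCM_ARCH="gfx90a;gfx942" \
      -t vllm-rocm .
$ docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add video vllm-rocm

Passing --build-arg BUILD_FA="0" (or BUILD_TRITON="0") skips the corresponding wheel stage; the stage still creates an empty /install directory so the copy steps in the final stage succeed.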

cmake/utils.cmake

Lines changed: 12 additions & 8 deletions
@@ -147,27 +147,31 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
   if (${GPU_LANG} STREQUAL "HIP")
     #
     # `GPU_ARCHES` controls the `--offload-arch` flags.
-    # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
-    # via the `PYTORCH_ROCM_ARCH` env variable.
     #
-
+    # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
+    # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
+    # "rocm_agent_enumerator" in "enable_language(HIP)"
+    # (in file Modules/CMakeDetermineHIPCompiler.cmake)
+    #
+    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+      set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
+    else()
+      set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
+    endif()
     #
     # Find the intersection of the supported + detected architectures to
     # set the module architecture flags.
     #
-
-    set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
-
     set(${GPU_ARCHES})
-    foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS})
+    foreach (_ARCH ${HIP_ARCHITECTURES})
       if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
         list(APPEND ${GPU_ARCHES} ${_ARCH})
       endif()
     endforeach()
 
     if(NOT ${GPU_ARCHES})
       message(FATAL_ERROR
-        "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
+        "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
        " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
     endif()
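With this lookup, a source build can be narrowed to specific targets without editing any CMake file, since the environment variable now feeds HIP_ARCHITECTURES directly instead of the previously hardcoded VLLM_ROCM_SUPPORTED_ARCHS list. A minimal sketch, assuming a ROCm source build:

$ export PYTORCH_ROCM_ARCH="gfx90a"
$ python setup.py develop   # emits --offload-arch flags for gfx90a only

Left unset, the list falls back to whatever rocm_agent_enumerator detects on the build machine.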

docs/source/getting_started/amd-installation.rst

Lines changed: 3 additions & 3 deletions
@@ -88,7 +88,7 @@ Option 2: Build from source
 - `Pytorch <https://pytorch.org/>`_
 - `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
 
-For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
+For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
 
 Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started <https://pytorch.org/get-started/locally/>`_
 
@@ -126,12 +126,12 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
 
    $ cd vllm
    $ pip install -U -r requirements-rocm.txt
-   $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
+   $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
 
 
 .. tip::
 
    - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
-   - To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention.
+   - To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
    - The ROCm version of pytorch, ideally, should match the ROCm version version.
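The corrected flag name is read at runtime, so switching attention backends requires no rebuild. A hedged sketch (the server invocation is illustrative, not part of this commit):

$ export VLLM_USE_TRITON_FLASH_ATTN=0   # fall back to CK flash-attention
$ python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m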

tests/async_engine/test_openapi_server_ray.py

Lines changed: 2 additions & 2 deletions
@@ -4,15 +4,15 @@
 # and debugging.
 import ray
 
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
 
 
 @pytest.fixture(scope="module")
 def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    ray.init()
     yield
     ray.shutdown()
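Without the working_dir runtime env, Ray workers import the vllm package installed by `python setup.py develop` rather than re-shipping the whole repository to each worker. A hedged invocation sketch from a source checkout:

$ pytest -v tests/async_engine/test_openapi_server_ray.py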
