
Commit 3b1d40f

mawong-amd authored and jimpang committed
[Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update CK FA, add partially supported model notes (vllm-project#6543)
1 parent e976e03 commit 3b1d40f

File tree

12 files changed: +116 -39 lines changed

.buildkite/run-amd-test.sh

Lines changed: 7 additions & 0 deletions
@@ -66,11 +66,18 @@ trap remove_docker_container EXIT
 
 echo "--- Running container"
 
+HF_CACHE="$(realpath ~)/huggingface"
+mkdir -p ${HF_CACHE}
+HF_MOUNT="/root/.cache/huggingface"
+
 docker run \
        --device /dev/kfd --device /dev/dri \
        --network host \
+       --shm-size=16gb \
        --rm \
        -e HF_TOKEN \
+       -v ${HF_CACHE}:${HF_MOUNT} \
+       -e HF_HOME=${HF_MOUNT} \
        --name ${container_name} \
        ${image_name} \
        /bin/bash -c "${@}"
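
Mounting a persistent Hugging Face cache and pointing HF_HOME at it lets repeated CI runs reuse downloaded models instead of re-fetching them. A minimal sketch of the effect inside the container, assuming the mount above is in place (the model id is only illustrative):

import os
from huggingface_hub import snapshot_download

# HF_HOME is set by `-e HF_HOME=${HF_MOUNT}` above, so the hub cache lives on
# the mounted host directory rather than the container's ephemeral filesystem.
print(os.environ.get("HF_HOME"))               # /root/.cache/huggingface

# The first CI run downloads into the mount; later runs hit the cached copy.
path = snapshot_download("facebook/opt-125m")  # illustrative model id
print(path)                                    # resolves under $HF_HOME/hub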

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,8 @@ steps:
   mirror_hardwares: [amd]
   fast_check: true
   commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
   - pytest -v -s basic_correctness/test_basic_correctness.py
   - pytest -v -s basic_correctness/test_cpu_offload.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # versions are derived from Dockerfile.rocm
 #
 set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
 
 #
 # Try to find python package with an executable that exactly matches
@@ -101,7 +101,7 @@ elseif(HIP_FOUND)
   # ROCm 5.X and 6.X
   if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
       NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
+    message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
             "expected for ROCm build, saw ${Torch_VERSION} instead.")
   endif()
 else()

Dockerfile.rocm

Lines changed: 35 additions & 25 deletions
@@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
 # Default ROCm ARCHes to build vLLM for.
 ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
 
-# Whether to build CK-based flash-attention
-# If 0, will not build flash attention
-# This is useful for gfx target where flash-attention is not supported
-# (i.e. those that do not appear in `FA_GFX_ARCHS`)
-# Triton FA is used by default on ROCm now so this is unnecessary.
+# Whether to install CK-based flash-attention
+# If 0, will not install flash-attention
 ARG BUILD_FA="1"
+# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
+# If this succeeds, we use the downloaded wheel and skip building flash-attention.
+# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
+# architectures specified in `FA_GFX_ARCHS`
+ARG TRY_FA_WHEEL="1"
+ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
-ARG FA_BRANCH="ae7928c"
+ARG FA_BRANCH="23a2b1c2"
 
 # Whether to build triton on rocm
 ARG BUILD_TRITON="1"
-ARG TRITON_BRANCH="0ef1848"
+ARG TRITON_BRANCH="e0fc12c"
 
 ### Base image build stage
 FROM $BASE_IMAGE AS base
@@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \
 ARG APP_MOUNT=/vllm-workspace
 WORKDIR ${APP_MOUNT}
 
-RUN pip install --upgrade pip
+RUN python3 -m pip install --upgrade pip
 # Remove sccache so it doesn't interfere with ccache
 # TODO: implement sccache support across components
-RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
+RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
 # Install torch == 2.5.0 on ROCm
 RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
         *"rocm-6.1"*) \
-        pip uninstall -y torch torchaudio torchvision \
-        && pip install --no-cache-dir --pre \
+        python3 -m pip uninstall -y torch torchaudio torchvision \
+        && python3 -m pip install --no-cache-dir --pre \
             torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
             torchvision==0.20.0.dev20240710 \
             --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
@@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
 FROM base AS build_amdsmi
 # Build amdsmi wheel always
 RUN cd /opt/rocm/share/amd_smi \
-    && pip wheel . --wheel-dir=/install
+    && python3 -m pip wheel . --wheel-dir=/install
 
 
 ### Flash-Attention wheel build stage
 FROM base AS build_fa
 ARG BUILD_FA
+ARG TRY_FA_WHEEL
+ARG FA_WHEEL_URL
 ARG FA_GFX_ARCHS
 ARG FA_BRANCH
 # Build ROCm flash-attention wheel if `BUILD_FA = 1`
 RUN --mount=type=cache,target=${CCACHE_DIR} \
     if [ "$BUILD_FA" = "1" ]; then \
-        mkdir -p libs \
-        && cd libs \
-        && git clone https://github.com/ROCm/flash-attention.git \
-        && cd flash-attention \
-        && git checkout "${FA_BRANCH}" \
-        && git submodule update --init \
-        && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+        if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
+            # If a suitable wheel exists, we download it instead of building FA
+            mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
+        else \
+            mkdir -p libs \
+            && cd libs \
+            && git clone https://github.com/ROCm/flash-attention.git \
+            && cd flash-attention \
+            && git checkout "${FA_BRANCH}" \
+            && git submodule update --init \
+            && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+        fi; \
     # Create an empty directory otherwise as later build stages expect one
     else mkdir -p /install; \
     fi
@@ -126,7 +136,7 @@ RUN case "$(which python3)" in \
 
 # Package upgrades for useful functionality or to avoid dependency issues
 RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install --upgrade numba scipy huggingface-hub[cli]
+    python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
 
 # Make sure punica kernels are built (for LoRA)
 ENV VLLM_INSTALL_PUNICA_KERNELS=1
@@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false
 
 RUN --mount=type=cache,target=${CCACHE_DIR} \
     --mount=type=cache,target=/root/.cache/pip \
-    pip install -U -r requirements-rocm.txt \
+    python3 -m pip install -Ur requirements-rocm.txt \
     && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
         *"rocm-6.1"*) \
         # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
@@ -153,27 +163,27 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
     mkdir -p libs \
     && cp /install/*.whl libs \
     # Preemptively uninstall to avoid same-version no-installs
-    && pip uninstall -y amdsmi;
+    && python3 -m pip uninstall -y amdsmi;
 
 # Copy triton wheel(s) into final image if they were built
 RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
     mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
-       && pip uninstall -y triton; fi
+       && python3 -m pip uninstall -y triton; fi
 
 # Copy flash-attn wheel(s) into final image if they were built
 RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
     mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
-       && pip uninstall -y flash-attn; fi
+       && python3 -m pip uninstall -y flash-attn; fi
 
 # Install wheels that were built to the final image
 RUN --mount=type=cache,target=/root/.cache/pip \
     if ls libs/*.whl; then \
-    pip install libs/*.whl; fi
+    python3 -m pip install libs/*.whl; fi
 
 CMD ["/bin/bash"]
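
Since flash-attention can now arrive either as the prebuilt wheel from FA_WHEEL_URL (when TRY_FA_WHEEL=1 succeeds) or as a source build of FA_BRANCH, a quick way to see what actually landed in the final image is to probe for the packages the wheel-copy stages install. A hypothetical smoke test, not part of this commit:

import importlib.util

# Packages installed from the wheels copied into the final stage
# (flash-attn, triton, amdsmi); "missing" is expected if the image was
# built with BUILD_FA=0 or BUILD_TRITON=0.
for mod in ("flash_attn", "triton", "amdsmi"):
    status = "installed" if importlib.util.find_spec(mod) else "missing"
    print(f"{mod}: {status}")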

docs/source/getting_started/amd-installation.rst

Lines changed: 4 additions & 3 deletions
@@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor
 
 Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_
 
-2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_
+2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_
 
-Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
+Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
+Alternatively, wheels intended for vLLM use can be accessed under the releases.
 
 .. note::
-    - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
     - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 
 3. Build vLLM.
@@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
 .. tip::
 
     - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
+    - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
     - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
    - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
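
The switch described in the tip can also be made from Python before vLLM starts, which is convenient in scripts and tests; a minimal sketch assuming a working ROCm install (the model id is only an example):

import os

# Turn off ROCm Triton flash-attention so vLLM falls back to CK
# flash-attention (or naive attention), e.g. when sliding window
# attention is needed in half precision.
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # illustrative model id
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)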

requirements-rocm.txt

Lines changed: 4 additions & 0 deletions
@@ -2,5 +2,9 @@
 -r requirements-common.txt
 
 # Dependencies for AMD GPUs
+awscli
+boto3
+botocore
 ray >= 2.10.0
+peft
 pytest-asyncio
tests/basic_correctness/test_cpu_offload.py

Lines changed: 7 additions & 2 deletions
@@ -1,8 +1,13 @@
+from vllm.utils import is_hip
+
 from ..utils import compare_two_settings
 
 
 def test_cpu_offload():
     compare_two_settings("meta-llama/Llama-2-7b-hf", [],
                          ["--cpu-offload-gb", "4"])
-    compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
-                         [], ["--cpu-offload-gb", "1"])
+    if not is_hip():
+        # compressed-tensors quantization is currently not supported in ROCm.
+        compare_two_settings(
+            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
+            ["--cpu-offload-gb", "1"])

tests/models/test_paligemma.py

Lines changed: 17 additions & 1 deletion
@@ -1,10 +1,12 @@
+import os
 from typing import List, Optional, Tuple, Type
 
 import pytest
 from transformers import AutoTokenizer
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
+from vllm.utils import is_hip
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from .utils import check_logprobs_close
@@ -22,6 +24,12 @@
 
 models = ["google/paligemma-3b-mix-224"]
 
+# ROCm Triton FA can run into compilation issues with these models due to,
+# excessive use of shared memory. Use other backends in the meantime.
+# FIXME (mattwong, gshtrasb, hongxiayan)
+if is_hip():
+    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
+
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                          Optional[SampleLogprobs]],
@@ -130,7 +138,15 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float", "half"])
+@pytest.mark.parametrize("dtype", [
+    pytest.param(
+        "float",
+        marks=pytest.mark.skipif(
+            is_hip(),
+            reason=
+            "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
+    ), "half"
+])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
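
With the skipif marker above, only the half-precision parametrization runs on ROCm. A hypothetical way to exercise just that subset locally:

import pytest

# Select only the "half" dtype cases of the PaliGemma test; the "float"
# parametrization is skipped on ROCm by the marker above in any case.
raise SystemExit(pytest.main(["-v", "-k", "half", "tests/models/test_paligemma.py"]))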

tests/models/test_phi3v.py

Lines changed: 8 additions & 1 deletion
@@ -1,3 +1,4 @@
+import os
 import re
 from typing import List, Optional, Tuple, Type
 
@@ -6,7 +7,7 @@
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
-from vllm.utils import is_cpu
+from vllm.utils import is_cpu, is_hip
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from .utils import check_logprobs_close
@@ -47,6 +48,12 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
 if is_cpu():
     target_dtype = "bfloat16"
 
+# ROCm Triton FA can run into shared memory issues with these models,
+# use other backends in the meantime
+# FIXME (mattwong, gshtrasb, hongxiayan)
+if is_hip():
+    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
+
 
 def run_test(
     hf_runner: Type[HfRunner],

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 8 additions & 0 deletions
@@ -275,6 +275,12 @@ def __init__(
                 triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
+            if self.sliding_window != (-1, -1):
+                logger.warning("ROCm Triton FA does not currently support "
+                               "sliding window attention. If using half "
+                               "precision, please try using the ROCm CK "
+                               "FA backend instead by setting the env var "
+                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
         else:
             # if not using triton, navi3x/navi21/navi10 do not use flash-attn
             # either
@@ -434,6 +440,8 @@ def forward(
                     max_seqlen_k=prefill_meta.max_prefill_seq_len,
                     softmax_scale=self.scale,
                     causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
                 )

         # common code for prefill
