
Commit dafb3d2

hongxiayang authored and jimpang committed
[ROCm] support Radeon™ 7900 series (gfx1100) without using flash-attention (vllm-project#2768)
1 parent 22f88f2 commit dafb3d2

File tree

4 files changed: +60 -5 lines changed


Dockerfile.rocm

Lines changed: 12 additions & 3 deletions
@@ -17,6 +17,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
 ARG FA_BRANCH="3d2b6f5"
 RUN echo "FA_BRANCH is $FA_BRANCH"
 
+# whether to build flash-attention
+# if 0, will not build flash attention
+# this is useful for gfx target where flash-attention is not supported
+# In that case, we need to use the python reference attention implementation in vllm
+ARG BUILD_FA="1"
+
 # Install some basic utilities
 RUN apt-get update && apt-get install python3 python3-pip -y
 
@@ -50,7 +56,8 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
 ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
 
 # Install ROCm flash-attention
-RUN mkdir libs \
+RUN if [ "$BUILD_FA" == "1" ]; then \
+    mkdir libs \
     && cd libs \
     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
     && cd flash-attention \
@@ -60,7 +67,8 @@ RUN mkdir libs \
     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
     && python3 setup.py install \
-    && cd ..
+    && cd ..; \
+    fi
 
 COPY ./ /app/vllm
 
@@ -75,7 +83,8 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && bash patch_xformers.rocm.sh \
+    && if [ "$BUILD_FA" == "1" ]; then \
+    bash patch_xformers.rocm.sh; fi \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..
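
Note that BUILD_FA only controls whether flash-attention is compiled into the image. At run time, vLLM chooses between the flash-attention-backed path and a pure-PyTorch reference path by checking whether the flash_attn package is importable, as the attention.py diff below shows. A minimal standalone sketch of that check (the helper name flash_attn_available is illustrative, not part of vLLM):

import importlib.util

def flash_attn_available() -> bool:
    # True if the flash_attn package is importable, i.e. the image was
    # built with BUILD_FA=1 and the flash-attention install succeeded.
    return importlib.util.find_spec("flash_attn") is not None

# Mirrors check_use_ref_attention() in the attention.py diff below,
# minus the is_hip() guard.
use_ref_attention = not flash_attn_available()
print(f"use_ref_attention = {use_ref_attention}")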

docs/source/getting_started/amd-installation.rst

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,7 @@ Requirements
 
 * OS: Linux
 * Python: 3.8 -- 3.11
-* GPU: MI200s (gfx90a), MI300 (gfx942)
+* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
 * Pytorch 2.0.1/2.1.1/2.2
 * ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
 
@@ -105,6 +105,7 @@ The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later
 * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
 * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
 * `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
+* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 before flash-attention supports this target.
 
 Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
 
setup.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
 
 
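For illustration, a build script could validate a requested GFX target list against this set before compiling. The sketch below is a hypothetical example, not vLLM's actual setup.py logic; PYTORCH_ROCM_ARCH is used here only as a convenient example input.

import os

# Supported ROCm targets after this commit (copied from setup.py above).
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}

def validate_rocm_archs(arch_list: str) -> set:
    # Parse a semicolon-separated list such as "gfx90a;gfx1100" and reject
    # anything outside ROCM_SUPPORTED_ARCHS.
    requested = {a.strip() for a in arch_list.split(";") if a.strip()}
    unsupported = requested - ROCM_SUPPORTED_ARCHS
    if unsupported:
        raise RuntimeError(
            f"Unsupported ROCm arch(s): {sorted(unsupported)}. "
            f"Supported: {sorted(ROCM_SUPPORTED_ARCHS)}")
    return requested

if __name__ == "__main__":
    # Example only: default to the newly added gfx1100 target.
    archs = validate_rocm_archs(os.environ.get("PYTORCH_ROCM_ARCH", "gfx1100"))
    print(f"Building for: {sorted(archs)}")
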
vllm/model_executor/layers/attention.py

Lines changed: 45 additions & 0 deletions
@@ -1,6 +1,7 @@
 """Multi-head attention."""
 from typing import List, Optional
 
+import importlib
 import torch
 import torch.nn as nn
 from xformers import ops as xops
@@ -58,6 +59,40 @@ def __init__(
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+        self.use_ref_attention = self.check_use_ref_attention()
+
+    def check_use_ref_attention(self) -> bool:
+        if not is_hip():
+            return False
+        # For ROCm, check whether flash attention is installed or not.
+        # if not, use_ref_attention needs to be True
+        return importlib.util.find_spec("flash_attn") is None
+
+    def ref_masked_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        seq_len, _, _ = query.shape
+        attn_mask = torch.triu(torch.ones(seq_len,
+                                          seq_len,
+                                          dtype=query.dtype,
+                                          device=query.device),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
+                                                 key).float()
+        attn_weights = attn_weights + attn_mask.float()
+        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+        return out
+
     def forward(
         self,
         query: torch.Tensor,
@@ -137,6 +172,16 @@ def forward(
                         self.alibi_slopes, self.num_kv_heads, batch_size,
                         seq_len, query.dtype)
 
+            if self.use_ref_attention:
+                output = self.ref_masked_attention(
+                    query,
+                    key,
+                    value,
+                )
+                # Using view got RuntimeError: view size is not compatible with input tensor's size and stride
+                # (at least one dimension spans across two contiguous subspaces). Use reshape instead
+                return output.reshape(batch_size, seq_len, hidden_size)
+
             # TODO(woosuk): Too many view operations. Let's try to reduce
             # them in the future for code readability.
             if self.alibi_slopes is None:
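
To make the tensor shapes concrete, here is a standalone sketch of the reference attention path added above, with hypothetical dimensions. The real ref_masked_attention reads num_heads, num_kv_heads, head_size and scale from the attention layer, and is only selected on ROCm when flash_attn is not installed.

import torch

def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float) -> torch.Tensor:
    # query/key/value: [seq_len, num_heads, head_size]
    seq_len = query.shape[0]
    # Upper-triangular mask of large negative values gives causal attention.
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min

    # "qhd,khd->hqk": one [seq_len, seq_len] score matrix per head.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    # "hqk,khd->qhd": back to [seq_len, num_heads, head_size].
    return torch.einsum("hqk,khd->qhd", attn_weights, value)

if __name__ == "__main__":
    seq_len, num_heads, head_size = 8, 4, 64
    q = torch.randn(seq_len, num_heads, head_size)
    k = torch.randn(seq_len, num_heads, head_size)
    v = torch.randn(seq_len, num_heads, head_size)
    out = ref_masked_attention(q, k, v, scale=head_size ** -0.5)
    print(out.shape)  # torch.Size([8, 4, 64])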

0 commit comments
