From 5d228c185556c0a27f591c4b075ac78a4acb1208 Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Sun, 11 Feb 2024 02:14:37 -0500 Subject: [PATCH] =?UTF-8?q?[ROCm]=20support=20Radeon=E2=84=A2=207900=20ser?= =?UTF-8?q?ies=20(gfx1100)=20without=20using=20flash-attention=20(#2768)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile.rocm | 15 +++++-- .../getting_started/amd-installation.rst | 3 +- setup.py | 2 +- vllm/model_executor/layers/attention.py | 45 +++++++++++++++++++ 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index f49b321372ed0..e0ef4a0f4131a 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -17,6 +17,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" ARG FA_BRANCH="3d2b6f5" RUN echo "FA_BRANCH is $FA_BRANCH" +# whether to build flash-attention +# if 0, will not build flash attention +# this is useful for gfx target where flash-attention is not supported +# In that case, we need to use the python reference attention implementation in vllm +ARG BUILD_FA="1" + # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -50,7 +56,8 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: # Install ROCm flash-attention -RUN mkdir libs \ +RUN if [ "$BUILD_FA" == "1" ]; then \ + mkdir libs \ && cd libs \ && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ && cd flash-attention \ @@ -60,7 +67,8 @@ RUN mkdir libs \ && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ && python3 setup.py install \ - && cd .. + && cd ..; \ + fi COPY ./ /app/vllm @@ -75,7 +83,8 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ - && bash patch_xformers.rocm.sh \ + && if [ "$BUILD_FA" == "1" ]; then \ + bash patch_xformers.rocm.sh; fi \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ && cd .. diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 6851ba136351c..5d9fdf4056709 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -12,7 +12,7 @@ Requirements * OS: Linux * Python: 3.8 -- 3.11 -* GPU: MI200s (gfx90a), MI300 (gfx942) +* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) * Pytorch 2.0.1/2.1.1/2.2 * ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9) @@ -105,6 +105,7 @@ The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` * `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo `_. 
The default is `3d2b6f5`
+* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 until flash-attention supports this target.
 
 Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
 
diff --git a/setup.py b/setup.py
index 60efed0720ff1..ea58a1a49e7e3 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,7 @@
 
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
 
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 2ce9d60f08d80..0622a54db1bc0 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -1,6 +1,7 @@
 """Multi-head attention."""
 from typing import List, Optional
 
+import importlib.util
 import torch
 import torch.nn as nn
 from xformers import ops as xops
@@ -58,6 +59,40 @@ def __init__(
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+        self.use_ref_attention = self.check_use_ref_attention()
+
+    def check_use_ref_attention(self) -> bool:
+        if not is_hip():
+            return False
+        # On ROCm, check whether flash-attention is installed.
+        # If it is not, fall back to the reference attention implementation.
+        return importlib.util.find_spec("flash_attn") is None
+
+    def ref_masked_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        seq_len, _, _ = query.shape
+        attn_mask = torch.triu(torch.ones(seq_len,
+                                          seq_len,
+                                          dtype=query.dtype,
+                                          device=query.device),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
+                                                 key).float()
+        attn_weights = attn_weights + attn_mask.float()
+        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+        return out
+
     def forward(
         self,
         query: torch.Tensor,
@@ -137,6 +172,16 @@ def forward(
                     self.alibi_slopes, self.num_kv_heads, batch_size, seq_len,
                     query.dtype)
 
+        if self.use_ref_attention:
+            output = self.ref_masked_attention(
+                query,
+                key,
+                value,
+            )
+            # view() fails here ("view size is not compatible with input
+            # tensor's size and stride"), so use reshape() instead.
+            return output.reshape(batch_size, seq_len, hidden_size)
+
         # TODO(woosuk): Too many view operations. Let's try to reduce
         # them in the future for code readability.
         if self.alibi_slopes is None:
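
For illustration, the sketch below is a standalone, minimal version of the reference attention path added above for builds without flash-attention (e.g. passing ``--build-arg BUILD_FA="0"`` to ``docker build -f Dockerfile.rocm``). It assumes num_kv_heads == num_heads, no ALiBi slopes, and a single prompt sequence; the free function, the toy shapes, and the cross-check against torch.nn.functional.scaled_dot_product_attention are illustrative and are not part of the patch or of vLLM's API.

# Standalone sketch of the reference causal attention used when flash-attention
# is unavailable (assumptions: num_kv_heads == num_heads, no ALiBi, float32).
import torch
import torch.nn.functional as F


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float) -> torch.Tensor:
    # query/key/value: [seq_len, num_heads, head_size]
    seq_len = query.shape[0]
    # Upper-triangular mask: future positions get a large negative bias.
    attn_mask = torch.triu(torch.ones(seq_len,
                                      seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min

    # Scores are accumulated in float32 before the softmax, as in the patch.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn_weights, value)


if __name__ == "__main__":
    seq_len, num_heads, head_size = 8, 4, 64
    q, k, v = (torch.randn(seq_len, num_heads, head_size) for _ in range(3))
    out = ref_masked_attention(q, k, v, scale=head_size**-0.5)

    # Cross-check against PyTorch's built-in causal attention, which expects
    # [num_heads, seq_len, head_size] inputs.
    expected = F.scaled_dot_product_attention(q.transpose(0, 1),
                                              k.transpose(0, 1),
                                              v.transpose(0, 1),
                                              is_causal=True).transpose(0, 1)
    print(out.shape, torch.allclose(out, expected, atol=1e-5))

The whole computation stays in plain PyTorch ops (triu, einsum, softmax), which is why no flash-attention kernel is required on gfx1100; the trade-off is that the full seq_len x seq_len score matrix is materialized per head.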