[ROCm] [Hardware][AMD] Remove xformer patches and ray issue fix #3558

Closed · wants to merge 5 commits
8 changes: 3 additions & 5 deletions Dockerfile.rocm
@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

-ARG FA_BRANCH="3d2b6f5"
+ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH"

# whether to build flash-attention
@@ -98,18 +98,16 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \
COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install xformers==0.0.23 --no-deps
+RUN python3 -m pip install xformers --no-deps

RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
-    && if [ "$BUILD_FA" = "1" ]; then \
-    bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cd ..

RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install --no-cache-dir ray[all]
+RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

CMD ["/bin/bash"]
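
Note on the hunk above: xformers is now installed unpinned (still with --no-deps), and ray[all] is pinned to 2.9.3. A quick sanity check that can be run with python3 inside the built image; only ray's 2.9.3 is guaranteed by this Dockerfile, while the xformers version is deliberately left floating:

    import importlib.metadata as md

    # ray should report exactly 2.9.3 per the pin above; xformers is unpinned,
    # so we only confirm that some version is installed at all.
    for pkg in ("xformers", "ray"):
        try:
            print(pkg, md.version(pkg))
        except md.PackageNotFoundError:
            print(pkg, "is not installed")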
33 changes: 0 additions & 33 deletions patch_xformers.rocm.sh

This file was deleted.

13 changes: 0 additions & 13 deletions rocm_patch/commonpy_xformers-0.0.23.rocm.patch

This file was deleted.

152 changes: 0 additions & 152 deletions rocm_patch/flashpy_xformers-0.0.23.rocm.patch

This file was deleted.

7 changes: 5 additions & 2 deletions vllm/model_executor/input_metadata.py
@@ -2,7 +2,6 @@
from typing import Optional, List, Any, Dict

import torch
-from xformers.ops.fmha.attn_bias import AttentionBias


@dataclass
@@ -82,7 +81,11 @@ def __post_init__(self):
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
-self.attn_bias: Optional[List[AttentionBias]] = None
+try:
+    from xformers.ops.fmha.attn_bias import AttentionBias
+    self.attn_bias: Optional[List[AttentionBias]] = None
+except ImportError:
+    self.attn_bias = None

# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
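The hunk above defers the xformers import so this module can still be imported on ROCm builds that ship without xformers. A minimal self-contained sketch of the same deferred-import pattern; the class and field names here are simplified placeholders, not vLLM's:

    from typing import Any, List, Optional

    class MetadataSketch:
        def __init__(self) -> None:
            # Import lazily so constructing the object works even when
            # xformers is absent (e.g. a ROCm image built without it).
            try:
                from xformers.ops.fmha.attn_bias import AttentionBias
                self.attn_bias: Optional[List[AttentionBias]] = None
            except ImportError:
                self.attn_bias: Optional[List[Any]] = None
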
4 changes: 0 additions & 4 deletions vllm/model_executor/layers/attention/attention.py
@@ -7,7 +7,6 @@

from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
-from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -67,9 +66,6 @@ def _use_flash_attn() -> bool:
logger.info("flash_attn is not found. Using xformers backend.")
return False

-if is_hip():
-    # AMD GPUs.
-    return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
41 changes: 28 additions & 13 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -7,6 +7,7 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
+from vllm.utils import is_hip


class FlashAttentionBackend:
@@ -99,19 +100,33 @@ def forward(
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
-output = flash_attn_varlen_func(
-    q=query,
-    k=key,
-    v=value,
-    cu_seqlens_q=input_metadata.seq_start_loc,
-    cu_seqlens_k=input_metadata.seq_start_loc,
-    max_seqlen_q=input_metadata.max_seq_len,
-    max_seqlen_k=input_metadata.max_seq_len,
-    softmax_scale=self.scale,
-    causal=True,
-    window_size=self.sliding_window,
-    alibi_slopes=self.alibi_slopes,
-)
+if is_hip():
+    # window_size and alibi_slopes not supported
+    output = flash_attn_varlen_func(
+        q=query,
+        k=key,
+        v=value,
+        cu_seqlens_q=input_metadata.seq_start_loc,
+        cu_seqlens_k=input_metadata.seq_start_loc,
+        max_seqlen_q=input_metadata.max_seq_len,
+        max_seqlen_k=input_metadata.max_seq_len,
+        softmax_scale=self.scale,
+        causal=True,
+    )
+else:
+    output = flash_attn_varlen_func(
+        q=query,
+        k=key,
+        v=value,
+        cu_seqlens_q=input_metadata.seq_start_loc,
+        cu_seqlens_k=input_metadata.seq_start_loc,
+        max_seqlen_q=input_metadata.max_seq_len,
+        max_seqlen_k=input_metadata.max_seq_len,
+        softmax_scale=self.scale,
+        causal=True,
+        window_size=self.sliding_window,
+        alibi_slopes=self.alibi_slopes,
+    )
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
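The added branch above calls flash_attn_varlen_func twice, and the two calls differ only in window_size and alibi_slopes, which the ROCm flash-attention build does not accept. An equivalent, more compact shape (a sketch only, not what this PR does; the helper name is hypothetical) would build the optional keyword arguments conditionally:

    from typing import Any, Dict

    def _extra_flash_attn_kwargs(on_rocm: bool, sliding_window: Any,
                                 alibi_slopes: Any) -> Dict[str, Any]:
        # Hypothetical helper: window_size and alibi_slopes are only passed
        # on CUDA, since the ROCm flash-attention build rejects them.
        if on_rocm:
            return {}
        return {"window_size": sliding_window, "alibi_slopes": alibi_slopes}

The forward pass could then make a single flash_attn_varlen_func call and splat in _extra_flash_attn_kwargs(is_hip(), self.sliding_window, self.alibi_slopes).
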
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/attention/backends/xformers.py
@@ -57,6 +57,8 @@ def __init__(
f"Supported head sizes are: {suppored_head_sizes}.")

self.use_ref_attention = _check_use_ref_attention()
+if self.use_ref_attention:
+    print("ref attention used.")

def forward(
self,
@@ -119,7 +121,6 @@ def forward(
value.shape[-1])

if self.use_ref_attention:
-print("ref attention used.")
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(input_metadata.prompt_lens):
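
The two hunks above move the "ref attention used." notice from the forward pass into __init__, so it is emitted once when the backend is constructed rather than on every call. A small sketch (not part of this diff) of routing that one-time notice through vLLM's init_logger, which this PR already touches in attention.py above, instead of a bare print; the helper function is assumed for illustration only:

    from vllm.logger import init_logger

    logger = init_logger(__name__)

    def announce_ref_attention(use_ref_attention: bool) -> None:
        # In the real class this would sit in __init__ right after
        # self.use_ref_attention is assigned.
        if use_ref_attention:
            logger.info("ref attention used.")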