40 changes: 29 additions & 11 deletions vllm/attention/layer.py
@@ -47,6 +47,12 @@
    SlidingWindowSpec,
)

if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9
else:
    on_gfx9 = lambda *args, **kwargs: False


FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
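
(Aside, not part of the diff: a minimal standalone sketch of the import-guard pattern added above, assuming a non-ROCm build; `_is_rocm` is a hypothetical stand-in for `current_platform.is_rocm()`.)

```python
# Sketch only: guard a platform-specific import and fall back to a stub so
# callers can invoke on_gfx9() unconditionally on any platform.
def _is_rocm() -> bool:
    # Stand-in for current_platform.is_rocm(); assume a CUDA/CPU build here.
    return False


if _is_rocm():
    from vllm.platforms.rocm import on_gfx9  # only importable on ROCm builds
else:
    def on_gfx9(*args, **kwargs) -> bool:
        return False  # non-ROCm builds never report a gfx9 GPU


print(on_gfx9())  # False on this hypothetical non-ROCm build
```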
@@ -96,18 +102,29 @@ def maybe_get_vit_flash_attn_backend(
    attn_backend: _Backend,
    use_upstream_fa: bool,
    attn_backend_override: _Backend | None = None,
) -> tuple[_Backend, Callable]:
    if (
        attn_backend != _Backend.FLASH_ATTN
        and attn_backend != _Backend.ROCM_AITER_FA
        and check_upstream_fa_availability(torch.get_default_dtype())
        and attn_backend_override is None
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
) -> tuple[_Backend, Callable | None]:
    if current_platform.is_rocm():
        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            attn_backend = _Backend.ROCM_AITER_FA

        elif (
            check_upstream_fa_availability(torch.get_default_dtype())
            and on_gfx9()
            and attn_backend_override is None
        ):
            attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True
        else:
            return _Backend.TORCH_SDPA, None

    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True
    elif current_platform.is_cuda():
        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True
    else:
        return _Backend.TORCH_SDPA, None

    if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
        if attn_backend == _Backend.ROCM_AITER_FA:
@@ -570,6 +587,7 @@ def forward(
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.is_flash_attn_backend:
            assert self._flash_attn_varlen_func is not None
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
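
(Aside, not part of the diff: the varlen call that the new assertion guards takes cumulative sequence lengths; below is a minimal sketch of how they are built for a fixed-length batch, using illustrative sizes and plain PyTorch with no vLLM imports.)

```python
import torch

# Cumulative sequence lengths for bsz sequences of identical length q_len,
# as expected by flash-attention's varlen interface: entry i is the offset
# where sequence i starts in the packed (bsz * q_len) token dimension.
bsz, q_len = 4, 128
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([  0, 128, 256, 384, 512], dtype=torch.int32)
```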
6 changes: 6 additions & 0 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -429,6 +429,12 @@ def forward(
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
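
(Aside, not part of the diff: the ROCm-only `.contiguous()` calls feed scaled_dot_product_attention dense copies instead of the non-contiguous views left behind by earlier transposes; a hedged sketch with made-up shapes follows, and the same reasoning applies to the identical qwen2_vl.py change below.)

```python
import torch
import torch.nn.functional as F

# (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim); the transpose
# produces a non-contiguous view.
q = torch.randn(1, 16, 8, 64).transpose(1, 2)
k = torch.randn(1, 16, 8, 64).transpose(1, 2)
v = torch.randn(1, 16, 8, 64).transpose(1, 2)
assert not q.is_contiguous()

# Mirror the ROCm branch above: materialize contiguous tensors before SDPA.
q, k, v = (t.contiguous() for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```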
6 changes: 6 additions & 0 deletions vllm/model_executor/models/qwen2_vl.py
@@ -462,6 +462,12 @@ def forward(
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
6 changes: 5 additions & 1 deletion vllm/platforms/rocm.py
@@ -205,12 +205,16 @@ class RocmPlatform(Platform):

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        from importlib.util import find_spec

        from vllm.attention.backends.registry import _Backend

        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            return _Backend.ROCM_AITER_FA
        if on_gfx9():

        if on_gfx9() and find_spec("flash_attn") is not None:
            return _Backend.FLASH_ATTN

        return _Backend.TORCH_SDPA

    @classmethod
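
(Aside, not part of the diff: a simplified restatement of the selection order implemented above, written as a hypothetical standalone function rather than the RocmPlatform API; backend names are returned as plain strings.)

```python
from importlib.util import find_spec


def pick_rocm_vit_backend(use_aiter_mha: bool, is_gfx9: bool) -> str:
    # Prefer the AITER multi-head-attention path on gfx9 when it is enabled.
    if use_aiter_mha and is_gfx9:
        return "ROCM_AITER_FA"
    # Use flash-attention only when the flash_attn package is installed.
    if is_gfx9 and find_spec("flash_attn") is not None:
        return "FLASH_ATTN"
    # Otherwise fall back to PyTorch SDPA.
    return "TORCH_SDPA"


print(pick_rocm_vit_backend(use_aiter_mha=False, is_gfx9=False))  # TORCH_SDPA
```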