
Commit 65d2cf9

JartX and tjtanaa authored
[BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA (#27190)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent d63cd9f commit 65d2cf9

File tree

4 files changed: +46 -12 lines changed

vllm/attention/layer.py

Lines changed: 29 additions & 11 deletions
@@ -47,6 +47,12 @@
     SlidingWindowSpec,
 )
 
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx9
+else:
+    on_gfx9 = lambda *args, **kwargs: False
+
+
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
@@ -96,18 +102,29 @@ def maybe_get_vit_flash_attn_backend(
     attn_backend: _Backend,
     use_upstream_fa: bool,
     attn_backend_override: _Backend | None = None,
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-        and attn_backend_override is None
-    ):
-        attn_backend = _Backend.FLASH_ATTN
-        use_upstream_fa = True
+) -> tuple[_Backend, Callable | None]:
+    if current_platform.is_rocm():
+        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+            attn_backend = _Backend.ROCM_AITER_FA
+
+        elif (
+            check_upstream_fa_availability(torch.get_default_dtype())
+            and on_gfx9()
+            and attn_backend_override is None
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+        else:
+            return _Backend.TORCH_SDPA, None
 
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
-        use_upstream_fa = True
+    elif current_platform.is_cuda():
+        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+            torch.get_default_dtype()
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+    else:
+        return _Backend.TORCH_SDPA, None
 
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
@@ -570,6 +587,7 @@ def forward(
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
         if self.is_flash_attn_backend:
+            assert self._flash_attn_varlen_func is not None
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
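The key behavioral change above is that maybe_get_vit_flash_attn_backend now returns (_Backend.TORCH_SDPA, None) whenever no flash-attention path is viable (non-GFX9 ROCm, or no upstream flash-attn on CUDA), so ViT callers must treat the returned callable as optional — which is what the new assert on self._flash_attn_varlen_func guards. A minimal caller sketch under that assumption; pick_vit_backend and its arguments are illustrative, not an actual vLLM call site:

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


def pick_vit_backend(preferred: _Backend) -> tuple[_Backend, object | None]:
    # After this commit the helper may hand back (_Backend.TORCH_SDPA, None),
    # so the flash-attention callable can no longer be invoked unconditionally.
    backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend=preferred,
        use_upstream_fa=False,
    )
    if flash_attn_varlen_func is None:
        # Caller falls back to torch scaled_dot_product_attention here.
        assert backend == _Backend.TORCH_SDPA
    return backend, flash_attn_varlen_func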

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 6 additions & 0 deletions
@@ -429,6 +429,12 @@ def forward(
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
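The fix in this file (and the identical one in qwen2_vl.py below) is narrow: on ROCm, the q/k/v tensors that reach the TORCH_SDPA branch can be non-contiguous views left over from earlier reshapes, and they are now forced contiguous before the per-entry SDPA loop. A self-contained sketch of the pattern, using torch.version.hip as a stand-in for vLLM's current_platform.is_rocm() and made-up shapes so it runs outside vLLM:

import torch
import torch.nn.functional as F

# Illustrative shapes only: permute + unbind typically yields
# non-contiguous (batch, heads, seq, head_dim) views.
qkv = torch.randn(3, 2, 8, 16, 64)
q, k, v = qkv.permute(0, 1, 3, 2, 4).unbind(0)

# Mirror of the commit's ROCm-only guard before calling SDPA.
if torch.version.hip is not None:
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 16, 8, 64])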

vllm/model_executor/models/qwen2_vl.py

Lines changed: 6 additions & 0 deletions
@@ -462,6 +462,12 @@ def forward(
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]

vllm/platforms/rocm.py

Lines changed: 5 additions & 1 deletion
@@ -205,12 +205,16 @@ class RocmPlatform(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
+        from importlib.util import find_spec
+
         from vllm.attention.backends.registry import _Backend
 
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             return _Backend.ROCM_AITER_FA
-        if on_gfx9():
+
+        if on_gfx9() and find_spec("flash_attn") is not None:
             return _Backend.FLASH_ATTN
+
         return _Backend.TORCH_SDPA
 
     @classmethod
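The new find_spec("flash_attn") check keeps get_vit_attn_backend from advertising FLASH_ATTN on GFX9 hosts where the upstream flash_attn package is not installed, falling back to TORCH_SDPA instead. The availability probe can be reproduced in isolation; the helper name below is illustrative:

from importlib.util import find_spec


def has_upstream_flash_attn() -> bool:
    # find_spec returns None when the package cannot be located,
    # without importing (and initializing) it.
    return find_spec("flash_attn") is not None


# Backend choice mirroring the new ROCm ViT logic: FLASH_ATTN only when
# the package is present, otherwise fall back to TORCH_SDPA.
backend = "FLASH_ATTN" if has_upstream_flash_attn() else "TORCH_SDPA"
print(backend)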
