[BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA #27190
Conversation
Signed-off-by: JartX <sagformas@epdcenter.es>
Code Review
This pull request addresses a bug where ViT FlashAttention was incorrectly enabled on ROCm RDNA3 devices, which lack support for it. The fix introduces platform-specific logic to select the appropriate attention backend for ROCm devices, defaulting to Torch SDPA on non-GFX9 hardware. While the change is correct, my review identifies an opportunity to improve maintainability by removing duplicated logic, making the codebase cleaner and less prone to future inconsistencies.
Signed-off-by: JartX <sagformas@epdcenter.es>
Thanks. Let me help verify :D
vllm/attention/layer.py (Outdated)

    if current_platform.is_rocm():
        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            attn_backend = _Backend.ROCM_AITER_FA
        elif on_gfx9():
This condition was crafted for the ROCm platform (on_gfx9()):

    if (
        attn_backend != _Backend.FLASH_ATTN
        and attn_backend != _Backend.ROCM_AITER_FA
        and check_upstream_fa_availability(torch.get_default_dtype())
    ):

On on_gfx9(), we will always attempt to use flash_attn, but if it is not supported we fall back to _Backend.TORCH_SDPA.
@tjtanaa The first step was simply to check whether the platform was ROCm and on_gfx1x, and if so, return _Backend.TORCH_SDPA, None directly. Would you like to leave it like this with a comment explaining why? I've only seen this affect RDNA3 with the upstream flow. I think that way we could mark the requested changes as resolved. If not, please tell me how the code should look :)
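For readers following the thread, here is a minimal sketch of the early return described above, assuming the on_gfx9() helper and _Backend registry already quoted in this conversation; the wrapper function name is hypothetical and this is not the merged diff.

```python
# Minimal sketch of the early return described above. The helpers
# (current_platform, on_gfx9, _Backend) are the ones quoted in this thread;
# the wrapper function itself is illustrative only.
from vllm.attention.backends.registry import _Backend
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx9


def maybe_force_sdpa_for_vit(attn_backend: "_Backend"):
    # On ROCm without gfx9 (e.g. RDNA3 / gfx11), ViT FlashAttention is not
    # supported, so skip the upstream-FA probe and use Torch SDPA directly.
    if current_platform.is_rocm() and not on_gfx9():
        return _Backend.TORCH_SDPA, None
    return attn_backend, None
```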
vllm/attention/layer.py (Outdated)

        else:
            return _Backend.TORCH_SDPA, None
    else:
        if (
As mentioned above, this condition and its content are also applicable to ROCm on gfx9.
Suggestion (lines 204 to 211 in 80e9452):

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        from vllm.attention.backends.registry import _Backend

        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            return _Backend.ROCM_AITER_FA

        from importlib.util import find_spec
    +   if on_gfx9() and find_spec("flash_attn") is not None:
            return _Backend.FLASH_ATTN

        return _Backend.TORCH_SDPA

CC @DarkLight1337 (as I am not familiar with the latest abstraction of decoupling ViT Attention and LLM Attention)
Signed-off-by: JartX <sagformas@epdcenter.es>
@tjtanaa We don't need to edit rocm.py, it works fine; the problem is in attention/layer.py. Your recommendation fails because FA support is detected on on_gfx11 and the attempt to load FA then fails: check_upstream_fa_availability() returns True. This has to be gated explicitly, which is why on ROCm we filter by GPU architecture; ViT FA is only supported on on_gfx9. I have updated your suggestion to work with RDNA3 and fall back to SDPA :)
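For clarity, a sketch of the updated selection as described above, under the assumption that envs.VLLM_ROCM_USE_AITER*, on_gfx9(), and _Backend behave as in the rocm.py excerpt quoted earlier; the standalone function is illustrative, not the literal diff.

```python
# Sketch of the updated ROCm ViT backend selection described above;
# illustrative only, not the merged code.
from importlib.util import find_spec

from vllm import envs
from vllm.attention.backends.registry import _Backend
from vllm.platforms.rocm import on_gfx9


def rocm_vit_attn_backend() -> "_Backend":
    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        return _Backend.ROCM_AITER_FA
    # find_spec only proves the flash_attn package is installed; on gfx11
    # (RDNA3) the availability probe can pass while actually loading FA
    # fails, so the check is additionally gated on gfx9.
    if on_gfx9() and find_spec("flash_attn") is not None:
        return _Backend.FLASH_ATTN
    return _Backend.TORCH_SDPA
```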
@JartX But can you update the
@tjtanaa mmm okay, I'll go add the filter to fine-tune it hehe 😃
Signed-off-by: JartX <sagformas@epdcenter.es>
@tjtanaa added :)

/gemini summary
Summary of Changes: This pull request resolves a critical bug affecting ROCm RDNA3 devices where FlashAttention was erroneously enabled, causing model loading failures. It refactors the ViT attention backend selection logic to ensure proper detection and fallback mechanisms for ROCm platforms, preventing unsupported FlashAttention usage and defaulting to TORCH_SDPA.
…solve conflict on layer.py attn_backend_override Signed-off-by: JartX <sagformas@epdcenter.es>
Force-pushed from 4d21e1e to 46ed9c6
DarkLight1337 left a comment
Thanks for fixing!
Can you fix DCO?
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: JartX <sagformas@epdcenter.es>
Force-pushed from 64f4906 to 0dabbf1
Done, many thanks @DarkLight1337 :D
Head branch was pushed to by a user without write access
Force-pushed from 46284b7 to c42f2d1
Sorry @DarkLight1337, the pre-commit fix was missing; now everything is okay. Can you re-enable auto-merge, please? Many thanks (L)
@DarkLight1337 @tjtanaa What can we do about the failing tests?
The test passes after retrying, so it should be good to merge.
… qwen3vl ROCm TORCH_SDPA (vllm-project#27190) Signed-off-by: JartX <sagformas@epdcenter.es> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Unfortunately I think this PR breaks AMD CI (and also Qwen2.5-VL). cc @DarkLight1337, @JartX, @tjtanaa, could you help take a look at a proper fix? I think we should update the hacky logic that sets use_upstream_fa in qwen2.5-vl; attempting a fix in #27790.
The refactor introduced in PR #26104 improved the flash-attn selection, but broke the loading of models like Qwen/Qwen3-VL-30B-A3B-Instruct on ROCm RDNA3, which does not support flash-attn for ViT. In this PR, I use the same backend selection for ROCm that is used in https://github.com/vllm-project/vllm/blob/main/vllm/platforms/rocm.py, in the method get_vit_attn_backend.
I'm pinging @tjtanaa to help me check that I haven't broken anything upstream for models other than RDNA3 :)
Thanks a lot!
UPDATE 25-10-25
After verifying that there is an inference bug in the qwen3vl, qwen2.5vl, and qwen2vl models on ROCm with the TORCH_SDPA attention backend, we expanded the PR with the code provided by @tjtanaa in PR #27106.
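As a rough illustration of the TORCH_SDPA part of this update (suggested by the PR title, "contiguous on qwen3vl ROCm TORCH_SDPA"): the idea is to make the query/key/value views contiguous before calling PyTorch SDPA. The function name and tensor layout below are assumptions, not the merged code.

```python
# Illustrative only: force q/k/v contiguity before Torch SDPA on the ROCm
# TORCH_SDPA path; names and shapes are assumptions, not the merged diff.
import torch
import torch.nn.functional as F


def sdpa_vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Assumed layout: (batch, num_heads, seq_len, head_dim). Upstream
    # reshape/permute can leave these as non-contiguous views, so make them
    # contiguous before the SDPA call.
    q, k, v = (t.contiguous() for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v)
```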