
Conversation

Contributor

@JartX JartX commented Oct 20, 2025

The refactor introduced in the following PR:
#26104
improved the flash-attn selection, but broke the loading of models like Qwen/Qwen3-VL-30B-A3B-Instruct on ROCm RDNA3, which doesn't support flash-attn for VL. In this PR, I use the same backend selection for ROCm that is used in: https://github.com/vllm-project/vllm/blob/main/vllm/platforms/rocm.py

in the method:

get_vit_attn_backend
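
For context, a minimal sketch of the ROCm selection this method implements, condensed from the code snippets quoted later in this thread (the function and constant names below are illustrative, not vLLM's exact source):

    def select_vit_backend_rocm(use_aiter_mha: bool, gfx9: bool, flash_attn_installed: bool) -> str:
        # AITER multi-head attention takes priority, but only on gfx9 (CDNA).
        if use_aiter_mha and gfx9:
            return "ROCM_AITER_FA"
        # Flash-attn for the ViT is only viable on gfx9, and only if it is installed.
        if gfx9 and flash_attn_installed:
            return "FLASH_ATTN"
        # RDNA3 (gfx11xx) and everything else falls back to torch SDPA.
        return "TORCH_SDPA"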

I'm tagging @tjtanaa to help me check that I haven't broken anything upstream for models other than RDNA3 :)

Thanks a lot!

UPDATE 2025-10-25
After verifying that there is an inference bug in the qwen3vl, qwen2.5vl, and qwen2vl models on ROCm with the TORCH_SDPA attention backend, we have expanded the PR with the code provided by @tjtanaa in PR #27106.
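
For reference, a minimal sketch of the kind of contiguity workaround the TORCH_SDPA path needs here; the helper below is hypothetical and not the actual patch taken from #27106:

    import torch
    import torch.nn.functional as F

    def sdpa_vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: the permuted/reshaped q/k/v tensors coming out of
        # the ViT projections can be non-contiguous, so force contiguity before
        # handing them to torch SDPA on ROCm.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        return F.scaled_dot_product_attention(q, k, v)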

Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX requested a review from LucasWilkinson as a code owner October 20, 2025 09:19
@mergify mergify bot added the rocm Related to AMD ROCm label Oct 20, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug where ViT FlashAttention was incorrectly enabled on ROCm RDNA3 devices, which lack support for it. The fix introduces platform-specific logic to select the appropriate attention backend for ROCm devices, defaulting to Torch SDPA on non-GFX9 hardware. While the change is correct, my review identifies an opportunity to improve maintainability by removing duplicated logic, making the codebase cleaner and less prone to future inconsistencies.

Signed-off-by: JartX <sagformas@epdcenter.es>
Collaborator

tjtanaa commented Oct 20, 2025

Thanks. Let me help verify :D

Contributor Author

JartX commented Oct 20, 2025

> Thanks. Let me help verify :D

@tjtanaa Could you please check this too?
#27187

Many thanks!

if current_platform.is_rocm():
    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        attn_backend = _Backend.ROCM_AITER_FA
    elif on_gfx9():
Collaborator

This condition was crafted for the ROCm platform (on_gfx9()).

        if (
            attn_backend != _Backend.FLASH_ATTN
            and attn_backend != _Backend.ROCM_AITER_FA
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):

On on_gfx9(), we will always attempt to use flash_attn, but if it is not supported, we fall back to _Backend.TORCH_SDPA.

Contributor Author

@tjtanaa The first step was simply to check whether it was ROCm and whether it was on_gfx1x, and if so, return _Backend.TORCH_SDPA, None directly. Would you like to leave it like this with a comment indicating why? I've only seen this affect RDNA3 with the upstream flow. I think this way we could mark the requested changes as resolved. If not, please tell me how the code should look :)

        else:
            return _Backend.TORCH_SDPA, None
    else:
        if (
Collaborator

As mentioned above, this condition and its content are also applicable to ROCm on gfx9.

Collaborator

tjtanaa commented Oct 21, 2025

Suggestions

  1. Add torch sdpa fallback
    https://github.com/JartX/vllm/blob/c8735e798fc090f1ec00576be3041a19c8c05695/vllm/attention/layer.py#L103-L117
    if current_platform.is_rocm():
        if (
            attn_backend not in [_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA]
        ):
            if check_upstream_fa_availability(torch.get_default_dtype()):
                attn_backend = _Backend.FLASH_ATTN
                use_upstream_fa = True
            else:
                return _Backend.TORCH_SDPA, None

    elif current_platform.is_cuda():
        if (
            attn_backend != _Backend.FLASH_ATTN
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):
            attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True
    else:
        # for other platforms, use torch_sdpa
        return _Backend.TORCH_SDPA, None
  2. Update the default

vllm/vllm/platforms/rocm.py

Lines 204 to 211 in 80e9452

def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
    from vllm.attention.backends.registry import _Backend

    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        return _Backend.ROCM_AITER_FA
    if on_gfx9():
        return _Backend.FLASH_ATTN
    return _Backend.TORCH_SDPA

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        from vllm.attention.backends.registry import _Backend

        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            return _Backend.ROCM_AITER_FA
            
        from importlib.util import find_spec
        if on_gfx9() and find_spec("flash_attn") is not None:
            return _Backend.FLASH_ATTN
            
        return _Backend.TORCH_SDPA
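
As a side note, find_spec only probes whether a module can be imported without actually importing it; a quick standalone check (illustrative, not part of the suggestion above) looks like:

    from importlib.util import find_spec

    if find_spec("flash_attn") is not None:
        print("flash_attn is installed; FLASH_ATTN can be offered on gfx9")
    else:
        print("flash_attn is not installed; fall back to TORCH_SDPA")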

CC @DarkLight1337 (as I am not familiar with the latest abstraction that decouples ViT attention from LLM attention)

Contributor Author

JartX commented Oct 22, 2025

@tjtanaa we don't need to edit rocm.py; it works fine. The problem is in attention/layer.py.

Your recommendation fails because FA support is "detected" on on_gfx11 and the attempt to load FA then fails: check_upstream_fa_availability() returns True there. The selection has to be gated by the GPU architecture instead, which is why on ROCm it is filtered by GPU model; flash-attn is only supported on on_gfx9.

if current_platform.is_rocm():
    if (
        attn_backend not in [_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA]
    ):
        if check_upstream_fa_availability(torch.get_default_dtype()):
            attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True
        else:
            return _Backend.TORCH_SDPA, None

elif current_platform.is_cuda():
    if (
        attn_backend != _Backend.FLASH_ATTN
        and check_upstream_fa_availability(torch.get_default_dtype())
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
else:
    # for other platforms, use torch_sdpa
    return _Backend.TORCH_SDPA, None

I have updated your suggestion to work with RDNA3 and fall back to SDPA :)
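
For anyone following along, a standalone sketch of this kind of architecture gating, assuming a ROCm build of PyTorch that exposes gcnArchName on the device properties (the gfx9 list below is an example, not vLLM's authoritative on_gfx9 implementation):

    import torch

    _GFX9_ARCHS = {"gfx90a", "gfx942"}  # example CDNA architectures

    def is_gfx9(device_index: int = 0) -> bool:
        if not torch.cuda.is_available():
            return False
        # On ROCm builds, gcnArchName looks like "gfx1100:..." on RDNA3 and
        # "gfx90a:..." on MI200-class GPUs.
        arch = torch.cuda.get_device_properties(device_index).gcnArchName
        return arch.split(":")[0] in _GFX9_ARCHS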

Collaborator

tjtanaa commented Oct 22, 2025

@JartX
Your suggestions LGTM.

But can you update rocm.py as well?
If you search for get_vit_attn_backend, you will see that every model.py calls this function and then only passes the attn_backend into maybe_get_vit_flash_attn_backend to override it.
This is the first place where the attention backend for ViT is determined (everything starts from platform.py), so it is better to also update the source of the attention selection.

CC @DarkLight1337

@DarkLight1337
Member

cc @ywang96 @Isotr0py

Contributor Author

JartX commented Oct 22, 2025

> @JartX Your suggestions LGTM.
>
> But can you update rocm.py as well? If you search for get_vit_attn_backend, you will see that every model.py calls this function and then only passes the attn_backend into maybe_get_vit_flash_attn_backend to override it. This is the first place where the attention backend for ViT is determined (everything starts from platform.py), so it is better to also update the source of the attention selection.
>
> CC @DarkLight1337

@tjtanaa mmm okay, I'll go ahead and add the filter to fine-tune it hehe 😃

Signed-off-by: JartX <sagformas@epdcenter.es>
Contributor Author

JartX commented Oct 22, 2025

@tjtanaa added :)

Contributor Author

JartX commented Oct 22, 2025

/gemini summary

@gemini-code-assist
Contributor

Summary of Changes

This pull request resolves a critical bug affecting ROCm RDNA3 devices where FlashAttention was erroneously enabled, causing model loading failures. It refactors the ViT attention backend selection logic to ensure proper detection and fallback mechanisms for ROCm platforms, preventing unsupported FlashAttention usage and defaulting to _Backend.TORCH_SDPA when necessary. This change improves the robustness and compatibility of the attention backend for various ROCm architectures.

Highlights

  • Bug Fix for ROCm RDNA3: Addresses an issue where FlashAttention was being incorrectly selected for ROCm RDNA3 devices, leading to model loading failures by ensuring FlashAttention is only enabled on supported GFX9 architectures.
  • Unified Backend Selection Logic: The logic for selecting the ViT attention backend on ROCm has been consolidated within vllm/attention/layer.py to align with the platform-specific checks in vllm/platforms/rocm.py, preventing unsupported FlashAttention usage.
  • Torch SDPA Fallback: Introduces a fallback mechanism to _Backend.TORCH_SDPA when FlashAttention or ROCm AITer FlashAttention is not available or supported on ROCm platforms, particularly for non-GFX9 devices.
  • Explicit FlashAttention Module Check: The get_vit_attn_backend method in vllm/platforms/rocm.py now explicitly checks for the presence of the flash_attn module using importlib.util.find_spec before attempting to use _Backend.FLASH_ATTN.
Changelog
  • vllm/attention/layer.py
    • Imported on_gfx9 from vllm.platforms.rocm to enable ROCm-specific platform checks.
    • Modified maybe_get_vit_flash_attn_backend to include new ROCm-specific logic for selecting ROCM_AITER_FA, FLASH_ATTN, or falling back to TORCH_SDPA based on on_gfx9() and check_upstream_fa_availability().
    • Added an assertion assert self._flash_attn_varlen_func is not None within the forward method when is_flash_attn_backend is true, ensuring the function is available before use.
  • vllm/platforms/rocm.py
    • Modified get_vit_attn_backend to explicitly check for the presence of the flash_attn module using find_spec("flash_attn") before returning _Backend.FLASH_ATTN for GFX9 devices.
Activity
  • tjtanaa offered to verify the changes.
  • JartX requested tjtanaa to also check another related pull request.
  • tjtanaa provided suggestions for modifying both vllm/attention/layer.py (to add torch_sdpa fallback) and vllm/platforms/rocm.py (to update default backend selection).
  • JartX initially disagreed with modifying rocm.py, stating the issue was in attention/layer.py, and explained why tjtanaa's initial suggestion for layer.py was problematic for on_gfx11 devices. JartX then updated the suggestion to work with RDNA3.
  • tjtanaa approved JartX's updated suggestions but reiterated the importance of updating rocm.py as it is the primary source for attention backend selection.
  • DarkLight1337 CC'd additional reviewers (ywang96, Isotr0py).
  • JartX agreed to incorporate the filter into rocm.py.
  • JartX confirmed that the updates to rocm.py were added.
  • gemini-code-assist[bot] provided an initial review comment, highlighting code duplication and suggesting a cleaner fix by preventing the upgrade logic on ROCm and relying on get_vit_attn_backend.
  • tjtanaa commented on the condition for _Backend.ROCM_AITER_FA, noting it's determined in model.py files.
  • tjtanaa clarified that the original condition was for ROCm on gfx9 and would attempt flash_attn with TORCH_SDPA as a fallback.
  • JartX asked tjtanaa for confirmation on whether to keep a direct _Backend.TORCH_SDPA, None return for on_gfx1x with a comment.
  • tjtanaa reiterated that the condition is applicable to ROCm on gfx9.

…solve conflict on layer.py attn_backend_override

Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX force-pushed the fixbug/torch_sdpa_rdna3_for_vl_models branch from 4d21e1e to 46ed9c6 on October 24, 2025 07:32
@JartX JartX requested a review from sighingnow as a code owner October 25, 2025 16:06
@mergify mergify bot added the qwen Related to Qwen models label Oct 25, 2025
Member

@DarkLight1337 DarkLight1337 left a comment


Thanks for fixing!

@DarkLight1337
Member

Can you fix DCO?

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX force-pushed the fixbug/torch_sdpa_rdna3_for_vl_models branch from 64f4906 to 0dabbf1 on October 25, 2025 16:12
Contributor Author

JartX commented Oct 25, 2025

> Can you fix DCO?

Done, many thanks @DarkLight1337 :D

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 25, 2025 16:21
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 25, 2025
Signed-off-by: JartX <sagformas@epdcenter.es>

precommit

Signed-off-by: JartX <sagformas@epdcenter.es>
auto-merge was automatically disabled October 25, 2025 16:25

Head branch was pushed to by a user without write access

@JartX JartX force-pushed the fixbug/torch_sdpa_rdna3_for_vl_models branch from 46284b7 to c42f2d1 on October 25, 2025 16:25
Contributor Author

JartX commented Oct 25, 2025

Sorry @DarkLight1337, the pre-commit was missing; now everything is okay. Can you re-enable auto-merge, please? Many thanks (L)

@JartX JartX changed the title [BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) [BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA Oct 25, 2025
Contributor Author

JartX commented Oct 25, 2025

@DarkLight1337 @tjtanaa What can we do about the failing tests?

@DarkLight1337
Member

The test passes after retrying, so it should be good to merge

@DarkLight1337 DarkLight1337 merged commit 65d2cf9 into vllm-project:main Oct 26, 2025
55 checks passed
xuhaolei pushed a commit to ZJU-REAL/EasySteer-vllm-v1 that referenced this pull request Oct 27, 2025
… qwen3vl ROCm TORCH_SDPA (vllm-project#27190)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Collaborator

zhewenl commented Oct 30, 2025

Unfortunately I think this PR breaks AMD CI (and also qwen2.5-vl):
tests/v1/entrypoints/openai/responses/test_image.py: with _Backend.FLASH_ATTN it did NOT set use_upstream_fa = True (code), so we got ImportError: cannot import name 'flash_attn_varlen_func' from 'vllm.vllm_flash_attn' (unknown location) (failure).

cc @DarkLight1337, @JartX, @tjtanaa could you help take a look at a proper fix? I think we should update the hacky logic that sets use_upstream_fa in qwen2.5-vl; I'm attempting a fix in #27790.
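
For illustration, a minimal sketch of keeping the flag and the import in sync (not vLLM's actual resolution code):

    def resolve_flash_attn_varlen_func(use_upstream_fa: bool):
        # If only the upstream flash_attn package is available (as on these ROCm
        # runners), vllm.vllm_flash_attn is not importable, so use_upstream_fa
        # must be set before this resolution runs.
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func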

ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
… qwen3vl ROCm TORCH_SDPA (vllm-project#27190)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
… qwen3vl ROCm TORCH_SDPA (vllm-project#27190)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
… qwen3vl ROCm TORCH_SDPA (vllm-project#27190)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
… qwen3vl ROCm TORCH_SDPA (vllm-project#27190)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>

Labels

qwen (Related to Qwen models) · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm)
