Commit 2bcf680

honor --mm_encoder_attn_backend when used (#27124)
Summary:

Pull Request resolved: #27124

In #26104, changes were made in layer.py that resulted in vLLM always trying to switch to the FA backend for ViT, even when `VLLM_ATTENTION_BACKEND` is set. This broke Meta's internal AMD pipelines, as it is neither desired nor expected behavior. With this change, the models touched by the offending PR can explicitly opt in to this behavior.

Reviewed By: Prowindy

Differential Revision: D84946967
1 parent a0003b5 commit 2bcf680
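
The behavioral change is easiest to see in isolation. Below is a minimal sketch of the gating rule described above; `Backend` and `pick_vit_backend` are illustrative stand-ins, not vLLM API, and upstream FA is assumed to be available:

from enum import Enum, auto
from typing import Optional


class Backend(Enum):
    # Stand-in for vLLM's internal _Backend enum.
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()


def pick_vit_backend(current: Backend, override: Optional[Backend]) -> Backend:
    """Mirror of the patched condition: auto-switch the ViT backend to FA
    only when no explicit override was supplied."""
    if (
        current not in (Backend.FLASH_ATTN, Backend.ROCM_AITER_FA)
        and override is None
    ):
        return Backend.FLASH_ATTN
    return current


# Before this commit: a non-FA backend was silently upgraded to FA.
assert pick_vit_backend(Backend.TORCH_SDPA, None) is Backend.FLASH_ATTN
# After this commit: an explicit override keeps the requested backend.
assert pick_vit_backend(Backend.TORCH_SDPA, Backend.TORCH_SDPA) is Backend.TORCH_SDPA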

File tree

6 files changed (+9, −2 lines)

vllm/attention/layer.py

Lines changed: 4 additions & 2 deletions
@@ -3,7 +3,7 @@
 """Attention layer."""
 
 from collections.abc import Callable
-from typing import cast
+from typing import cast, Optional
 
 import torch
 import torch.nn as nn
@@ -93,12 +93,13 @@ def check_upstream_fa_availability(dtype: torch.dtype):
 
 
 def maybe_get_vit_flash_attn_backend(
-    attn_backend: _Backend, use_upstream_fa: bool
+    attn_backend: _Backend, use_upstream_fa: bool, attn_backend_override: Optional[_Backend] = None
 ) -> tuple[_Backend, Callable]:
     if (
         attn_backend != _Backend.FLASH_ATTN
         and attn_backend != _Backend.ROCM_AITER_FA
         and check_upstream_fa_availability(torch.get_default_dtype())
+        and attn_backend_override is None
     ):
         attn_backend = _Backend.FLASH_ATTN
         use_upstream_fa = True
@@ -499,6 +500,7 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )

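Assembled from the hunks above, the patched helper reads roughly as follows. This is a hedged reconstruction for readability: `_Backend` is a local stand-in for vLLM's internal enum, `check_upstream_fa_availability` is stubbed, and the rest of the function body (selecting the actual attention callable) is not part of this commit, so it is reduced to a placeholder:

from collections.abc import Callable
from enum import Enum, auto
from typing import Optional

import torch


class _Backend(Enum):
    # Local stand-in for vLLM's internal backend enum.
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()


def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
    # Stub: the real check probes whether upstream flash-attn is importable
    # and usable for the given dtype.
    return True


def maybe_get_vit_flash_attn_backend(
    attn_backend: _Backend,
    use_upstream_fa: bool,
    attn_backend_override: Optional[_Backend] = None,
) -> tuple[_Backend, Callable]:
    if (
        attn_backend != _Backend.FLASH_ATTN
        and attn_backend != _Backend.ROCM_AITER_FA
        and check_upstream_fa_availability(torch.get_default_dtype())
        # New in this commit: an explicit override suppresses the auto-switch.
        and attn_backend_override is None
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
    # Choosing the varlen attention callable is unchanged by this commit
    # and elided here.
    flash_attn_varlen_func: Callable = lambda *args, **kwargs: None  # placeholder
    return attn_backend, flash_attn_varlen_func
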
vllm/model_executor/models/dots_ocr.py

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
         if self.attn_backend not in {

vllm/model_executor/models/ernie45_vl.py

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )

vllm/model_executor/models/glm4_1v.py

Lines changed: 1 addition & 0 deletions
@@ -296,6 +296,7 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )

vllm/model_executor/models/qwen2_vl.py

Lines changed: 1 addition & 0 deletions
@@ -364,6 +364,7 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )

vllm/model_executor/models/siglip2navit.py

Lines changed: 1 addition & 0 deletions
@@ -259,6 +259,7 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )

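All five model files receive the same one-line change: each vision attention module's `__init__` forwards the `attn_backend_override` it receives into the helper. The enclosing assignment is outside the diff hunks, so the sketch below fills it in with assumed names (the class, the default backend, and the `use_upstream_fa` default are illustrative); it reuses the reconstructed helper and `_Backend` stand-in from the layer.py sketch above:

from typing import Optional

import torch


class VisionAttention(torch.nn.Module):
    # Illustrative module, not an actual vLLM class.
    def __init__(self, attn_backend_override: Optional[_Backend] = None) -> None:
        super().__init__()
        self.attn_backend = _Backend.TORCH_SDPA  # assumed platform default
        self.use_upstream_fa = False
        # Forwarding the override means an explicitly requested backend is
        # honored rather than auto-upgraded to FlashAttention.
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )


attn = VisionAttention(attn_backend_override=_Backend.TORCH_SDPA)
assert attn.attn_backend is _Backend.TORCH_SDPA  # override honored
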