Commit 203c634

wwl2755 authored and gc-fu committed
[Multi Modal] Add FA3 in VIT (vllm-project#24347)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
1 parent 29f17c0 commit 203c634

File tree

13 files changed: +247 −66 lines changed


tests/entrypoints/openai/test_vision.py

Lines changed: 2 additions & 2 deletions
@@ -34,11 +34,11 @@
     ],
     [
         "The image shows a Venn diagram with three over",
-        "The image shows a Venn diagram with three intersect",
+        "This image shows a Venn diagram with three over",
     ],
     [
         "This image displays a gradient of colors ranging from",
-        "The image displays a gradient of colors ranging from",
+        "This image displays a gradient of colors forming a spectrum",
     ],
 ]

tests/kernels/attention/test_mha_attn.py

Lines changed: 35 additions & 14 deletions
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)

     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   CpuPlatform()), \
-             patch("vllm.platforms.current_platform", CpuPlatform()):
+        with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
+            assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   RocmPlatform()), \
-             patch("vllm.platforms.current_platform", RocmPlatform()), \
-             patch("vllm.attention.layer.current_platform", RocmPlatform()):
+        with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM's FlashAttention
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
+            assert attn.attn_backend == _Backend.FLASH_ATTN

-        with patch("vllm.attention.selector.current_platform",
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA not available
+        # - should use xformers
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
                    CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS

+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA available
+        # - should use upstream FA
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=True), \
+             patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
+                 {
+                     'flash_attn_varlen_func': lambda *args, **kwargs: None
+                 })()}):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+

 def ref_attention(
     query: torch.Tensor,

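Aside: the new head_size=72 case above fakes the presence of the upstream flash_attn package by injecting a stand-in module into sys.modules. A minimal, self-contained sketch of that mocking pattern (standard library only; the mock mirrors the test and is not a vLLM API):

# Sketch of the sys.modules mocking pattern used in the test above: inject a
# fake "flash_attn" module so that "from flash_attn import
# flash_attn_varlen_func" resolves even when the real package is absent.
from unittest.mock import patch

fake_flash_attn = type("MockFlashAttn", (), {
    "flash_attn_varlen_func": staticmethod(lambda *args, **kwargs: None),
})()

with patch.dict("sys.modules", {"flash_attn": fake_flash_attn}):
    from flash_attn import flash_attn_varlen_func  # resolves to the stand-in
    assert flash_attn_varlen_func() is None
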
vllm/attention/layer.py

Lines changed: 67 additions & 8 deletions
@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op

@@ -64,6 +65,14 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS


+def check_upstream_fa_availability(dtype: torch.dtype):
+    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+    ) and current_platform.has_device_capability(80):
+        from transformers.utils import is_flash_attn_2_available
+        return is_flash_attn_2_available()
+    return False
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.

@@ -358,29 +367,55 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+
+        # Determine the attention backend
+        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+
+        # Some auto-selected backends can be upgraded
+        # to upstream flash attention if available.
+        # If vllm native fa is selected, we use it directly.
+        use_upstream_fa = False
+        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+                dtype):
+            backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
+
             self.attn_backend = backend if backend in {
                 _Backend.TORCH_SDPA,
                 _Backend.TORCH_SDPA_VLLM_V1,
                 _Backend.XFORMERS,
                 _Backend.PALLAS_VLLM_V1,
                 _Backend.ROCM_AITER_FA,
-            } else current_platform.get_vit_attn_backend()
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+            } else _Backend.TORCH_SDPA

         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA

+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,

@@ -401,7 +436,31 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)

-        if self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = self._flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops

             out = xops.memory_efficient_attention_forward(query,

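The net effect of the layer.py changes is that MultiHeadAttention (used by the ViT encoders) can now run FlashAttention, preferring vLLM's bundled kernels and upgrading to the upstream flash-attn package when the auto-selected backend is not vLLM FA but upstream FA is usable. A rough usage sketch, not part of the commit (assumes a CUDA-enabled vLLM install on an Ampere-or-newer GPU with the flash-attn wheel installed; shapes are illustrative):

# Rough usage sketch. With head_size=72 (not a multiple of 32) vLLM's own
# FlashAttention kernels are skipped, so the layer upgrades to upstream
# flash-attn when check_upstream_fa_availability() reports it as importable;
# otherwise it falls back to xformers / SDPA.
import torch
from vllm.attention.layer import MultiHeadAttention

torch.set_default_dtype(torch.float16)

attn = MultiHeadAttention(16, 72, scale=72 ** -0.5)
print(attn.attn_backend)  # e.g. _Backend.FLASH_ATTN

q = torch.randn(1, 128, 16 * 72, device="cuda")
k = torch.randn(1, 128, 16 * 72, device="cuda")
v = torch.randn(1, 128, 16 * 72, device="cuda")
out = attn(q, k, v)  # attention output with the same leading (batch, seq) shape
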
vllm/model_executor/models/ernie45_vl.py

Lines changed: 20 additions & 3 deletions
@@ -34,6 +34,7 @@
 from einops import rearrange, repeat
 from transformers import BatchFeature

+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils

@@ -170,7 +171,16 @@ def __init__(
                            prefix=f"{prefix}.proj")

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA

@@ -233,7 +243,10 @@ def forward(
         if self.attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func

         q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -457,7 +470,11 @@ def __init__(
         ), "vit's config.hidden must be equal to config.embed_dim"
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/glm4_1v.py

Lines changed: 19 additions & 3 deletions
@@ -44,6 +44,7 @@
                           Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata

+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)

@@ -260,7 +261,15 @@ def __init__(
         )

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,

@@ -310,7 +319,10 @@ def forward(
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func

         q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -715,7 +727,11 @@ def __init__(
         self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)

-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/keye.py

Lines changed: 15 additions & 2 deletions
@@ -17,6 +17,7 @@
                                    BaseModelOutputWithPooling)
 from transformers.utils import torch_int

+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger

@@ -374,7 +375,16 @@ def __init__(
         )

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")

@@ -428,7 +438,10 @@ def forward(
         )

         if self.attn_backend == _Backend.FLASH_ATTN:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func

         q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 21 additions & 3 deletions
@@ -38,6 +38,7 @@
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils

@@ -298,7 +299,16 @@ def __init__(
                                        disable_tp=use_data_parallel)

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA, _Backend.IPEX

@@ -359,7 +369,10 @@ def forward(
         if self.attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func

         q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -660,7 +673,12 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:

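Each of the model files above applies the same three-step pattern: resolve the ViT attention backend from head size and dtype, upgrade to upstream flash-attn when vLLM's FlashAttention was not selected but flash-attn is importable, and pick the matching flash_attn_varlen_func at forward time. A condensed sketch of that shared logic (the helper name resolve_vit_attn_backend is hypothetical, not something this commit adds; the imports come from the diff):

# Hypothetical helper distilling the backend-selection pattern repeated in the
# ViT model files above.
import torch
from vllm.attention.layer import check_upstream_fa_availability
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend


def resolve_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]:
    dtype = torch.get_default_dtype()
    backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
    use_upstream_fa = False
    # Upgrade to upstream flash-attn only when vLLM's own FA was not chosen.
    if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(dtype):
        backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
    return backend, use_upstream_fa


backend, use_upstream_fa = resolve_vit_attn_backend(head_size=80)
if backend == _Backend.FLASH_ATTN:
    if use_upstream_fa:
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func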