 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -64,6 +65,14 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
+def check_upstream_fa_availability(dtype: torch.dtype):
+    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+    ) and current_platform.has_device_capability(80):
+        from transformers.utils import is_flash_attn_2_available
+        return is_flash_attn_2_available()
+    return False
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
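For reference, the new check_upstream_fa_availability helper only reports upstream flash-attn as usable for fp16/bf16 inputs on a CUDA device with compute capability 8.0 or newer; everything else returns False. A rough standalone sketch of the same gate, using plain torch.cuda queries in place of vllm's current_platform helpers (a substitution made here purely for illustration), would be:

import torch
from transformers.utils import is_flash_attn_2_available


def upstream_fa_usable(dtype: torch.dtype) -> bool:
    # Upstream flash-attn kernels only cover half-precision dtypes.
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    # Ampere (SM 8.0) or newer is required, mirroring has_device_capability(80).
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        return False
    # Finally, ask transformers whether the flash_attn package is importable.
    return is_flash_attn_2_available()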
@@ -358,29 +367,55 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+
+        # Determine the attention backend
+        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+
+        # Some auto-selected backends can be upgraded
+        # to upstream flash attention if available.
+        # If vllm native fa is selected, we use it directly.
+        use_upstream_fa = False
+        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+                dtype):
+            backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
+
             self.attn_backend = backend if backend in {
                 _Backend.TORCH_SDPA,
                 _Backend.TORCH_SDPA_VLLM_V1,
                 _Backend.XFORMERS,
                 _Backend.PALLAS_VLLM_V1,
                 _Backend.ROCM_AITER_FA,
-            } else current_platform.get_vit_attn_backend()
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+            } else _Backend.TORCH_SDPA
 
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA
 
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,
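Taken together, the backend selection added to __init__ above amounts to: start from the ViT suggestion returned by get_vit_attn_backend, upgrade anything that is not vllm's native FLASH_ATTN to upstream flash-attn when that package is usable, force TORCH_SDPA on ROCm, and fall back to TORCH_SDPA for any backend outside the supported set. A condensed, hedged sketch of that decision (the upstream-availability result is passed in as a plain bool so the snippet stays self-contained, and the separate xformers availability fallback is omitted for brevity):

from vllm.platforms import _Backend, current_platform

# Backends MultiHeadAttention can serve directly, per the patch above.
_SUPPORTED = {
    _Backend.TORCH_SDPA, _Backend.TORCH_SDPA_VLLM_V1, _Backend.XFORMERS,
    _Backend.PALLAS_VLLM_V1, _Backend.ROCM_AITER_FA,
    _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
}


def resolve_backend(suggested: _Backend, upstream_fa_ok: bool):
    """Return (backend, use_upstream_fa) following the priority in the patch."""
    backend, use_upstream_fa = suggested, False
    # Anything other than vllm's native FLASH_ATTN is upgraded to upstream
    # flash-attn when that package is importable on supported hardware.
    if backend != _Backend.FLASH_ATTN and upstream_fa_ok:
        backend, use_upstream_fa = _Backend.FLASH_ATTN, True
    # Currently, only torch SDPA is supported on ROCm for this layer.
    if current_platform.is_rocm():
        return _Backend.TORCH_SDPA, False
    # Any unsupported suggestion falls back to torch SDPA.
    if backend not in _SUPPORTED:
        backend, use_upstream_fa = _Backend.TORCH_SDPA, False
    return backend, use_upstream_fa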
@@ -401,7 +436,31 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = self._flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
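The flash-attn path above packs the whole batch into one ragged sequence: flatten(0, 1) turns the (batch, seq, heads, head_dim) tensors into (batch * seq, heads, head_dim), and the cu_seqlens_* tensors record where each sample starts and ends in that packed layout. A small CPU-only illustration of just that bookkeeping (no flash-attn import; the sizes are made up for the example):

import torch

bsz, q_len, kv_len, num_heads, head_dim = 2, 3, 5, 4, 32
query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, kv_len, num_heads, head_dim)

# Cumulative sequence lengths: sample i occupies rows
# [cu_seqlens[i], cu_seqlens[i + 1]) of the packed tensor.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)
print(cu_seqlens_q.tolist())  # [0, 3, 6]
print(cu_seqlens_k.tolist())  # [0, 5, 10]

# flatten(0, 1) merges the batch dimension into the token dimension,
# which is the packed layout flash_attn_varlen_func expects.
packed_q = query.flatten(0, 1)
packed_k = key.flatten(0, 1)
assert packed_q.shape == (bsz * q_len, num_heads, head_dim)
assert packed_k.shape == (bsz * kv_len, num_heads, head_dim)

Because every sample in this layer shares the same q_len and kv_len, max_seqlen_q and max_seqlen_k are simply q_len and kv_len, and the packed output can be reshaped back to (bsz, q_len, num_heads, head_dim) afterwards.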