@@ -65,7 +65,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS


-def check_upstream_fa_availability(dtype: torch.dtype):
+def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
     if (
         dtype in (torch.float16, torch.bfloat16)
         and current_platform.is_cuda()
@@ -80,26 +80,40 @@ def check_upstream_fa_availability(dtype: torch.dtype):
         return find_spec("flash_attn") is not None
     return False

+def is_fa_backend(backend: _Backend) -> bool:
+    return backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}

 def maybe_get_vit_flash_attn_backend(
-    attn_backend: _Backend, use_upstream_fa: bool
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-    ):
+    attn_backend: _Backend,
+    try_switch_to_fa: bool = False,
+    force_upstream_fa: bool = False) -> tuple[_Backend, Callable]:
+
+    upstream_fa_available = check_upstream_fa_availability(torch.get_default_dtype())
+    if force_upstream_fa:
+        assert upstream_fa_available, \
+            "Upstream FlashAttn is not available."
+
+    use_upstream_fa = force_upstream_fa
+    if try_switch_to_fa and not is_fa_backend(attn_backend) and upstream_fa_available:
         attn_backend = _Backend.FLASH_ATTN
+        logger.info_once("maybe_get_vit_flash_attn_backend: "
+                         "auto-switching to upstream FlashAttn.")
         use_upstream_fa = True
-
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
+
+    if current_platform.is_rocm() and \
+            attn_backend == _Backend.FLASH_ATTN:
+        # Always use upstream FA on ROCm.
+        logger.info_once("maybe_get_vit_flash_attn_backend: "
+                         "ROCm backend is FLASH_ATTN, forcing upstream FA.")
         use_upstream_fa = True
-
-    if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+
+    if is_fa_backend(attn_backend):
         if attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
             if use_upstream_fa:
+                assert upstream_fa_available, \
+                    "Upstream FlashAttn is not available."
                 from flash_attn import flash_attn_varlen_func
             else:
                 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -108,7 +122,6 @@ def maybe_get_vit_flash_attn_backend(

     return attn_backend, flash_attn_varlen_func

-
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.

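For reference, here is a minimal sketch of how a caller might use the reworked helper. The import paths and the starting backend value are assumptions for illustration and are not taken from this commit:

    # Hypothetical call site; module paths below are assumed, not part of this diff.
    from vllm.attention.layer import maybe_get_vit_flash_attn_backend  # assumed location
    from vllm.platforms import _Backend  # assumed location of the backend enum

    # Let the helper upgrade a non-FA backend to upstream flash_attn when the
    # dtype/platform checks pass and the flash_attn package is importable.
    backend, fa_varlen_func = maybe_get_vit_flash_attn_backend(
        _Backend.TORCH_SDPA,      # whatever backend was auto-selected for the ViT
        try_switch_to_fa=True,
    )

    # With try_switch_to_fa=False (the MultiHeadAttention call site below keeps
    # that default), the backend passed in is returned unchanged.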
@@ -428,11 +441,6 @@ def __init__(
         # Determine the attention backend
         backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         if current_platform.is_xpu():
             # currently, only torch_sdpa is supported on xpu
             self.attn_backend = _Backend.TORCH_SDPA
@@ -450,30 +458,19 @@ def __init__(
                 else _Backend.TORCH_SDPA
             )

-        self.attn_backend, self._flash_attn_varlen_func = (
-            maybe_get_vit_flash_attn_backend(
+        self.attn_backend, self._flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
+                try_switch_to_fa=False,
             )
-        )

         if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
             self.attn_backend = _Backend.TORCH_SDPA

-        self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN,
-            _Backend.ROCM_AITER_FA,
-        }
-
-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
-            use_upstream_fa = True
+        self.is_flash_attn_backend = is_fa_backend(self.attn_backend)

         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
-        )
+            f"MultiHeadAttention attn_backend: {self.attn_backend}")

     def forward(
         self,
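Taken together, the last two hunks reduce the backend setup in MultiHeadAttention.__init__ to roughly the following condensed sketch; the platform-specific selection that assigns self.attn_backend beforehand is unchanged and omitted here:

    # Condensed sketch of the post-change flow (not the full constructor).
    self.attn_backend, self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        self.attn_backend,
        try_switch_to_fa=False,  # keep the backend that was already selected
    )

    if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
        self.attn_backend = _Backend.TORCH_SDPA

    # One helper call replaces the old set-membership check and the ROCm-only
    # use_upstream_fa bookkeeping that existed only to keep the log message accurate.
    self.is_flash_attn_backend = is_fa_backend(self.attn_backend)
    logger.info_once(f"MultiHeadAttention attn_backend: {self.attn_backend}")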