@@ -210,22 +210,19 @@ def __init__(
210210 self .scale = scale
211211 self .num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
212212
213- assert self .num_heads % self .num_kv_heads == 0
214- self .num_queries_per_kv = self .num_heads // self .num_kv_heads
215-
216213 dtype = torch .get_default_dtype ()
217214 attn_backend = get_attn_backend (head_size ,
218215 dtype ,
219216 kv_cache_dtype = None ,
220217 block_size = 16 ,
221218 is_attention_free = False )
222219 backend = backend_name_to_enum (attn_backend .get_name ())
220+ if backend in {_Backend .FLASH_ATTN , _Backend .FLASH_ATTN_VLLM_V1 }:
221+ backend = _Backend .XFORMERS
223222
224223 self .attn_backend = backend if backend in {
225224 _Backend .TORCH_SDPA ,
226225 _Backend .XFORMERS ,
227- _Backend .FLASH_ATTN ,
228- _Backend .FLASH_ATTN_VLLM_V1 ,
229226 } else _Backend .TORCH_SDPA
230227
231228 def forward (
@@ -235,45 +232,15 @@ def forward(
235232 value : torch .Tensor ,
236233 ) -> torch .Tensor :
237234 """Input shape: batch_size x seq_len x hidden_size"""
235+ # TODO(Isotr0py): Use existing backend implementations and support FA3
238236 bsz , q_len , _ = query .size ()
239237 kv_len = key .size (1 )
240238
241239 query = query .view (bsz , q_len , self .num_heads , self .head_size )
242240 key = key .view (bsz , kv_len , self .num_kv_heads , self .head_size )
243241 value = value .view (bsz , kv_len , self .num_kv_heads , self .head_size )
244242
245- if (num_repeat := self .num_queries_per_kv ) > 1 :
246- # Handle MQA and GQA
247- key = torch .repeat_interleave (key , num_repeat , dim = 2 )
248- value = torch .repeat_interleave (value , num_repeat , dim = 2 )
249-
250- if self .attn_backend in {
251- _Backend .FLASH_ATTN ,
252- _Backend .FLASH_ATTN_VLLM_V1 ,
253- }:
254- from vllm .vllm_flash_attn import flash_attn_varlen_func
255-
256- cu_seqlens_q = torch .arange (0 , (bsz + 1 ) * q_len ,
257- step = q_len ,
258- dtype = torch .int32 ,
259- device = query .device )
260- cu_seqlens_k = torch .arange (0 , (bsz + 1 ) * kv_len ,
261- step = kv_len ,
262- dtype = torch .int32 ,
263- device = key .device )
264-
265- out = flash_attn_varlen_func (
266- query .flatten (0 , 1 ),
267- key .flatten (0 , 1 ),
268- value .flatten (0 , 1 ),
269- cu_seqlens_q = cu_seqlens_q ,
270- cu_seqlens_k = cu_seqlens_k ,
271- max_seqlen_q = q_len ,
272- max_seqlen_k = kv_len ,
273- softmax_scale = self .scale ,
274- )
275- out = out .reshape (bsz , q_len , - 1 )
276- elif self .attn_backend == _Backend .XFORMERS :
243+ if self .attn_backend == _Backend .XFORMERS :
277244 from xformers import ops as xops
278245
279246 out = xops .memory_efficient_attention_forward (query ,
0 commit comments