 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Tests for v1 attention backends without GPUModelRunner dependency."""
+from functools import partial
+from typing import Optional, Union

 import pytest
 import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention

 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.config import ModelConfig
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               set_kv_cache_layout)
@@ -183,13 +188,19 @@ def __init__(self, device: torch.device):
         self._v_scale_float = 1.0


-def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          layer_names: list[str], vllm_config,
-                          device: torch.device,
-                          common_attn_metadata: CommonAttentionMetadata,
-                          query: torch.Tensor, key: torch.Tensor,
-                          value: torch.Tensor,
-                          kv_cache: torch.Tensor) -> torch.Tensor:
+def run_attention_backend(
+    backend: _Backend,
+    kv_cache_spec: FullAttentionSpec,
+    layer_names: list[str],
+    vllm_config,
+    device: torch.device,
+    common_attn_metadata: CommonAttentionMetadata,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    sliding_window: Optional[int] = None,
+) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""

     # Handle special case for FLEX_ATTENTION_SLOW
@@ -253,7 +264,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
         scale=scale,
         num_kv_heads=num_kv_heads,
         alibi_slopes=None,
-        sliding_window=None,
+        sliding_window=sliding_window,
         kv_cache_dtype="auto",
     )

@@ -275,13 +286,16 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
     return output


-@pytest.mark.parametrize("batch_spec_name", [
-    "small_decode", "small_prefill", "mixed_small", "medium_decode",
-    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
-    "single_decode", "single_prefill"
-])
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_backend_correctness(batch_spec_name: str, model: str):
+def _test_backend_correctness(
+    batch_spec: BatchSpec,
+    model: str,
+    backend_to_test: list[Union[_Backend, str]],
+    mask_mod,
+    *,
+    block_size: int = 16,
+    atol: float = 1e-2,
+    rtol: float = 1e-2,
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    batch_spec = BATCH_SPECS[batch_spec_name]
+    current_platform.seed_everything(42)
     vllm_config = create_vllm_config(model_name=model,
                                      max_model_len=max(batch_spec.seq_lens),
+                                     block_size=block_size,
                                      num_gpu_blocks=8192)
     device = torch.device("cuda:0")

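Aside on the new reference path: the patch replaces the dense SDPA causal mask with torch's FlexAttention, where the masking rule is a mask_mod(b, h, q_idx, kv_idx) predicate shifted by the number of already-cached tokens. A minimal standalone sketch of that pattern, assuming a CUDA device; the sizes, head counts, and the causal_mask_mod below are illustrative stand-ins for the helpers added later in this diff:

    from functools import partial

    import torch
    from torch.nn.attention.flex_attention import (create_block_mask,
                                                    flex_attention)


    def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len: int):
        # Query token q_idx sits at absolute position context_len + q_idx,
        # so it may attend to kv positions 0 .. context_len + q_idx.
        return (q_idx + context_len) >= kv_idx


    # Illustrative sizes: 8 new query tokens on top of 16 cached tokens,
    # 4 heads, head_dim 64, one sequence, tensors in (B, H, L, D) layout.
    context_len, q_len = 16, 8
    kv_len = context_len + q_len
    q = torch.randn(1, 4, q_len, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 4, kv_len, 64, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 4, kv_len, 64, device="cuda", dtype=torch.bfloat16)

    block_mask = create_block_mask(partial(causal_mask_mod,
                                           context_len=context_len),
                                   B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len,
                                   device="cuda")
    out = flex_attention(q, k, v, block_mask=block_mask)  # (1, 4, q_len, 64)

Each per-sequence reference in the loop below is computed this way: the full (cached + new) keys and values are used, with only the new tokens as queries.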
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
         vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
+    sliding_window = vllm_config.model_config.get_sliding_window()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
     scale = 1.0 / (head_size**0.5)
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # Create causal mask: query token i attends to positions 0 to
         #  (context_len + i)
         kv_len = s_len
-        offset = context_len
-        attn_mask = torch.full((q_len, kv_len),
-                               float('-inf'),
-                               device=device,
-                               dtype=dtype)
-        for i in range(q_len):
-            attn_mask[i, :offset + i + 1] = 0.0
-
-        sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
-            q_sdpa_in,
-            k_sdpa_in,
-            v_sdpa_in,
-            attn_mask=attn_mask,
-            scale=scale,
-            enable_gqa=True)
-        # Convert back to (L, H, D)
+
+        final_mask_mod = partial(mask_mod, context_len=context_len)
+        block_mask = create_block_mask(final_mask_mod,
+                                       B=None,
+                                       H=None,
+                                       Q_LEN=q_len,
+                                       KV_LEN=kv_len,
+                                       device=device)
+        sdpa_out_i = flex_attention(q_sdpa_in,
+                                    k_sdpa_in,
+                                    v_sdpa_in,
+                                    block_mask=block_mask,
+                                    scale=scale,
+                                    enable_gqa=True)
+
         all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

         # Inputs for vLLM backends are just the new tokens
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     # 4. Run vLLM backends and compare
     # Note: flex_attention has known Triton kernel compatibility issues
     # with test infrastructures
-    for backend_name in BACKENDS_TO_TEST:
+    for backend_name in backend_to_test:
         # FlashAttentionm + FlexAttention:
         #   [2, num_blocks, block_size, num_kv_heads, head_size]
         # FlashInfer:
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
                 2, 3).contiguous().transpose(2, 3)
             set_kv_cache_layout("HND")

-        backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               ["placeholder"], vllm_config,
-                                               device, common_attn_metadata,
-                                               query_vllm, key_vllm,
-                                               value_vllm,
-                                               kv_cache_for_backend)
+        backend_output = run_attention_backend(
+            backend_name,
+            kv_cache_spec,
+            ["placeholder"],
+            vllm_config,
+            device,
+            common_attn_metadata,
+            query_vllm,
+            key_vllm,
+            value_vllm,
+            kv_cache_for_backend,
+            sliding_window=sliding_window,
+        )

         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             f"[{backend_name}] produced non-finite values")

         # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-3
-
-        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - sdpa_output) /
-            torch.abs(sdpa_output)).item()
-        all_close = torch.allclose(backend_output,
+        def error_msg(msg: str, backend_name: str):
+            return (f"[{backend_name}] output differs from SDPA baseline. "
+                    f"{msg}")
+
+        torch.testing.assert_close(backend_output,
                                    sdpa_output,
                                    rtol=rtol,
-                                   atol=atol)
+                                   atol=atol,
+                                   msg=partial(error_msg,
+                                               backend_name=backend_name))

-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f}")
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_small", "medium_decode",
+    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
+    "single_decode", "single_prefill"
+])
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_causal_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with causal attention."""
+
+    def causal_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+    ):
+        return (q_idx + context_len) >= kv_idx
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              causal_mask_mod)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  causal_mask_mod,
+                                  block_size=128)
+
+
+SLIDING_WINDOW_BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
+    _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+]
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_medium", "large_decode",
+    "large_prefill"
+])
+@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
+def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with sliding window attention."""
+
+    def sliding_window_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+        sliding_window: int,
+    ):
+        causal_mask = q_idx + context_len >= kv_idx
+        window_mask = q_idx + context_len - kv_idx < sliding_window
+        return causal_mask & window_mask
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    model_config = ModelConfig(model=model,
+                               max_model_len=max(batch_spec.seq_lens))
+    sliding_window = model_config.get_sliding_window()
+    sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
+                                         sliding_window=sliding_window)
+
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
+        if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              sliding_window_mask_mod_fn)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  sliding_window_mask_mod_fn,
+                                  block_size=128)
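For intuition on sliding_window_mask_mod above: query token q_idx lives at absolute position context_len + q_idx and may attend to kv_idx only if it is not in the future and lies within the last sliding_window positions. A small standalone check of that rule, with purely illustrative sizes and no vLLM dependencies:

    import torch


    def sliding_window_mask(q_len, kv_len, context_len, sliding_window):
        # Same predicate as sliding_window_mask_mod, evaluated densely.
        q_idx = torch.arange(q_len).unsqueeze(1)    # (q_len, 1)
        kv_idx = torch.arange(kv_len).unsqueeze(0)  # (1, kv_len)
        causal = q_idx + context_len >= kv_idx
        window = q_idx + context_len - kv_idx < sliding_window
        return causal & window


    # One new query token after 6 cached tokens, window of 4: it sees kv
    # positions 3..6 (itself plus the 3 most recent cached tokens).
    mask = sliding_window_mask(q_len=1, kv_len=7, context_len=6,
                               sliding_window=4)
    assert mask.tolist() == [[False, False, False, True, True, True, True]]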