    AscendMLAImpl, AscendMLAMetadata,
    AscendMLAMetadataBuilder,
    AscendMLAPrefillMetadata)
-from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


class TestAscendMLABackend(TestBase):
@@ -188,8 +187,6 @@ def test_ascend_mla_metadata_builder_default(self):
        mock_device = 'cpu'

        ascend_config = MagicMock()
-        ascend_config.torchair_graph_config = MagicMock()
-        ascend_config.torchair_graph_config.enabled = True
        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
@@ -199,44 +196,9 @@ def test_ascend_mla_metadata_builder_default(self):
            self.assertEqual(
                builder.chunked_prefill_enabled,
                mock_vllm_config.scheduler_config.chunked_prefill_enabled)
-            self.assertEqual(builder.torchair_graph_enabled, True)

-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_reorder_batch_with_torchair_graph(self, ascend_config):
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.max_num_seqs = 4
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-        ascend_config.torchair_graph_config = MagicMock()
-        ascend_config.torchair_graph_config.enabled = True
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-
-        input_batch = MagicMock()
-        input_batch.req_ids = [0, 1, 2, 3]
-
-        scheduler_output = MagicMock()
-        scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
-        scheduler_output.scheduled_spec_decode_tokens = {
-            0: [1],
-            1: [],
-            2: [1, 1],
-            3: []
-        }
-
-        input_batch.swap_states = MagicMock()
-
-        modified = builder.reorder_batch(input_batch, scheduler_output)
-
-        self.assertFalse(modified)
-        input_batch.swap_states.assert_not_called()
-
-    def test_reorder_batch_without_torchair_graph(self):
+    def test_reorder_batch(self):
        ascend_config = MagicMock()
-        ascend_config.torchair_graph_config = MagicMock()
-        ascend_config.torchair_graph_config.enabled = False

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
@@ -268,128 +230,6 @@ def test_reorder_batch_without_torchair_graph(self):
        self.assertTrue(modified)
        input_batch.swap_states.assert_called_once_with(1, 2)

-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
-
-        result = builder._get_graph_runner_block_tables(3, block_tables)
-        self.assertEqual(result.shape[0], 3)
-        self.assertEqual(result.shape[1], 64)
-        self.assertTrue(torch.equal(result[:, :10], block_tables))
-
-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 64
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
-
-        result = builder._get_graph_runner_block_tables(3, block_tables)
-        self.assertEqual(result.shape[0], 3)
-        self.assertEqual(result.shape[1], 4)
-        self.assertTrue(torch.equal(result, block_tables[:, :4]))
-
-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_get_graph_runner_block_tables_from_numpy(self,
-                                                      mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-
-        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
-
-        result = builder._get_graph_runner_block_tables(3, block_tables)
-
-        self.assertEqual(result.shape[0], 3)
-        self.assertEqual(result.shape[1], 64)
-        self.assertTrue(torch.equal(result[:, :10], block_tables))
-
-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_build_dummy(self, mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_vllm_config.get_head_size.return_value = 64
-        mock_vllm_config.model_config.dtype = torch.float16
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config,
-                                           mock_device,
-                                           metadata_cls=AscendMLAMetadata)
-        builder.rope_dim = 64
-
-        with patch.object(builder,
-                          "_get_graph_runner_block_tables",
-                          side_effect=lambda x, y: y):
-            common_attn_metadata = TorchairCommonAttentionMetadata(
-                num_reqs=3,
-                num_actual_tokens=3,
-                decode_token_per_req=1,
-                actual_seq_lengths_q=[0, 1, 2],
-                attn_mask=torch.zeros((1, 1), dtype=torch.bool),
-                spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
-            )
-            metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
-
-        sin_golden = torch.ones(3,
-                                1,
-                                1,
-                                64,
-                                dtype=torch.float16,
-                                device=mock_device)
-        cos_golden = torch.ones(3,
-                                1,
-                                1,
-                                64,
-                                dtype=torch.float16,
-                                device=mock_device)
-
-        self.assertIsInstance(metadata, AscendMLAMetadata)
-        self.assertEqual(metadata.num_input_tokens, 3)
-        self.assertEqual(metadata.num_actual_tokens, 3)
-        self.assertEqual(metadata.num_decodes, 1)
-        self.assertEqual(metadata.num_decode_tokens, 1)
-        self.assertEqual(metadata.num_prefills, 0)
-        self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
-        self.assertIsNone(metadata.prefill)
-        self.assertIsInstance(metadata.decode, AscendMLADecodeMetadata)
-        self.assertEqual(metadata.block_tables.shape[0], 3)
-        self.assertEqual(metadata.block_tables.shape[1], 64)
-        self.assertEqual(metadata.seq_lens.shape[0], 3)
-        self.assertEqual(metadata.slot_mapping.shape[0], 3)
-        self.assertEqual(metadata.query_start_loc.shape[0], 3)
-        assert torch.equal(sin_golden, metadata.decode.sin)
-        assert torch.equal(cos_golden, metadata.decode.cos)
-

class TestAscendMLAImpl(TestBase):

@@ -401,8 +241,6 @@ class TestAscendMLAImpl(TestBase):
    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
        mock_tp.world_size = 2
-        ascend_config.torchair_graph_config.enabled = True
-        ascend_config.torchair_graph_config.enable_kv_nz = False
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
@@ -464,7 +302,6 @@ def test_init(self):
        self.assertIsNotNone(self.impl.kv_a_layernorm)
        self.assertEqual(self.impl.num_queries_per_kv, 32)
        self.assertEqual(self.impl.tp_size, 2)
-        self.assertTrue(self.impl.torchair_graph_enabled)

    def test_v_up_proj_and_o_proj(self):
        batch_size = 4
@@ -580,102 +417,10 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
        self.assertEqual(out.shape, prefix_out.shape)
        self.assertEqual(lse.shape, prefix_lse.shape)

-    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
-    def test_exec_kv(self, mock_kv_cache):
-        batch_size = 2
-        hidden = torch.randn(batch_size, 128)
-        cos = torch.randn(batch_size, 32)
-        sin = torch.randn(batch_size, 32)
-        kv_cache = (torch.randn(
-            4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
-                    torch.randn(
-                        4, 8,
-                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
-        slots = torch.arange(batch_size, dtype=torch.long)
-
-        proj_out = torch.randn(
-            batch_size, self.impl.num_kv_heads, 1,
-            self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
-        self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
-
-        mock_kv_cache.return_value = (torch.randn(batch_size,
-                                                  self.impl.num_kv_heads, 1,
-                                                  self.impl.qk_rope_head_dim),
-                                      torch.randn(batch_size,
-                                                  self.impl.num_kv_heads, 1,
-                                                  self.impl.kv_lora_rank),
-                                      None, None)
-
-        k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)
-
-        self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
-        mock_kv_cache.assert_called_once()
-        self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
-                                      self.impl.qk_rope_head_dim))
-        self.assertEqual(
-            k_nope.shape,
-            (batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
-        self.assertEqual(kv.shape,
-                         (batch_size, self.impl.num_kv_heads, 1,
-                          self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
-
-    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
-    def test_exec_kv_prefill(self, mock_kv):
-        B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
-        hidden_states = torch.randn(B, N, S, H)
-        cos = torch.randn(B, S, 32)
-        sin = torch.randn(B, S, 32)
-        kv_cache = (
-            torch.randn(100, 8,
-                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
-            torch.randn(100, 8,
-                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
-        )
-
-        slots = torch.arange(B * S, dtype=torch.long)
-
-        proj_out = torch.randn(
-            B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
-        self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
-
-        mock_kv.return_value = (None, None,
-                                torch.randn(B, self.impl.num_kv_heads, S,
-                                            self.impl.qk_rope_head_dim),
-                                torch.randn(B, self.impl.num_kv_heads, S,
-                                            self.impl.kv_lora_rank))
-
-        k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
-                                                 kv_cache, slots)
-
-        self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
-        mock_kv.assert_called_once()
-
-        self.assertEqual(
-            k_pe.shape,
-            (B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
-        self.assertEqual(
-            k_nope.shape,
-            (B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))
-
-    @patch("torch_npu.npu_interleave_rope")
-    def test_rope_single(self, mock_rope):
-        B, N, D = 2, 16, 1024
-        x = torch.randn(B, N, D)
-        cos = torch.randn(B, N, 1, D)
-        sin = torch.randn(B, N, 1, D)
-        mock_rope.return_value = x.view(B, N, 1, D)
-        result = self.impl.rope_single(x, cos, sin)
-        self.assertEqual(result.shape[0], B)
-        self.assertEqual(result.shape[1], N)
-        self.assertEqual(result.shape[2], D)
-        mock_rope.assert_called_once()
-
    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj")
    @patch("torch_npu._npu_paged_attention_mla")
    def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                          mock_up_proj):
-        self.impl.running_in_graph = False
-        self.impl.running_chunkprefilll_with_torchair = False
        num_tokens = 100
        num_blocks = 256
        block_size = 4
@@ -706,9 +451,6 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
    @patch("torch_npu._npu_reshape_and_cache")
    def test_forward_without_graph(self, _, mock_forward_prefill):
-        self.impl.running_in_graph = False
-        self.impl.torchair_graph_enabled = False
-
        num_tokens = 100
        num_blocks = 256
        block_size = 4