
Commit 2d4e437

[2/N][refactor] torchair deepseek mla backend refactor
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent: 3fb80ee · commit: 2d4e437

File tree: 7 files changed (+2195, -750 lines)


tests/ut/attention/test_mla_v1.py

Lines changed: 1 addition & 259 deletions
@@ -11,7 +11,6 @@
                                           AscendMLAImpl, AscendMLAMetadata,
                                           AscendMLAMetadataBuilder,
                                           AscendMLAPrefillMetadata)
-from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


 class TestAscendMLABackend(TestBase):
@@ -188,8 +187,6 @@ def test_ascend_mla_metadata_builder_default(self):
         mock_device = 'cpu'

         ascend_config = MagicMock()
-        ascend_config.torchair_graph_config = MagicMock()
-        ascend_config.torchair_graph_config.enabled = True
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                    return_value=ascend_config):
             builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
@@ -199,44 +196,9 @@ def test_ascend_mla_metadata_builder_default(self):
         self.assertEqual(
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)
-        self.assertEqual(builder.torchair_graph_enabled, True)

-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_reorder_batch_with_torchair_graph(self, ascend_config):
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.max_num_seqs = 4
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-        ascend_config.torchair_graph_config = MagicMock()
-        ascend_config.torchair_graph_config.enabled = True
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-
-        input_batch = MagicMock()
-        input_batch.req_ids = [0, 1, 2, 3]
-
-        scheduler_output = MagicMock()
-        scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
-        scheduler_output.scheduled_spec_decode_tokens = {
-            0: [1],
-            1: [],
-            2: [1, 1],
-            3: []
-        }
-
-        input_batch.swap_states = MagicMock()
-
-        modified = builder.reorder_batch(input_batch, scheduler_output)
-
-        self.assertFalse(modified)
-        input_batch.swap_states.assert_not_called()
-
-    def test_reorder_batch_without_torchair_graph(self):
+    def test_reorder_batch(self):
         ascend_config = MagicMock()
-        ascend_config.torchair_graph_config = MagicMock()
-        ascend_config.torchair_graph_config.enabled = False

         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
@@ -268,128 +230,6 @@ def test_reorder_batch_without_torchair_graph(self):
         self.assertTrue(modified)
         input_batch.swap_states.assert_called_once_with(1, 2)

-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
-
-        result = builder._get_graph_runner_block_tables(3, block_tables)
-        self.assertEqual(result.shape[0], 3)
-        self.assertEqual(result.shape[1], 64)
-        self.assertTrue(torch.equal(result[:, :10], block_tables))
-
-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 64
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
-
-        result = builder._get_graph_runner_block_tables(3, block_tables)
-        self.assertEqual(result.shape[0], 3)
-        self.assertEqual(result.shape[1], 4)
-        self.assertTrue(torch.equal(result, block_tables[:, :4]))
-
-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_get_graph_runner_block_tables_from_numpy(self,
-                                                      mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
-
-        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
-
-        result = builder._get_graph_runner_block_tables(3, block_tables)
-
-        self.assertEqual(result.shape[0], 3)
-        self.assertEqual(result.shape[1], 64)
-        self.assertTrue(torch.equal(result[:, :10], block_tables))
-
-    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_build_dummy(self, mock_ascend_config):
-        ascend_config = MagicMock()
-        mock_ascend_config.return_value = ascend_config
-        ascend_config.torchair_graph_config.enabled = False
-
-        mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config.max_model_len = 1024
-        mock_vllm_config.cache_config.block_size = 16
-        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
-        mock_vllm_config.get_head_size.return_value = 64
-        mock_vllm_config.model_config.dtype = torch.float16
-        mock_device = 'cpu'
-
-        builder = AscendMLAMetadataBuilder(mock_vllm_config,
-                                           mock_device,
-                                           metadata_cls=AscendMLAMetadata)
-        builder.rope_dim = 64
-
-        with patch.object(builder,
-                          "_get_graph_runner_block_tables",
-                          side_effect=lambda x, y: y):
-            common_attn_metadata = TorchairCommonAttentionMetadata(
-                num_reqs=3,
-                num_actual_tokens=3,
-                decode_token_per_req=1,
-                actual_seq_lengths_q=[0, 1, 2],
-                attn_mask=torch.zeros((1, 1), dtype=torch.bool),
-                spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
-            )
-            metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
-
-        sin_golden = torch.ones(3,
-                                1,
-                                1,
-                                64,
-                                dtype=torch.float16,
-                                device=mock_device)
-        cos_golden = torch.ones(3,
-                                1,
-                                1,
-                                64,
-                                dtype=torch.float16,
-                                device=mock_device)
-
-        self.assertIsInstance(metadata, AscendMLAMetadata)
-        self.assertEqual(metadata.num_input_tokens, 3)
-        self.assertEqual(metadata.num_actual_tokens, 3)
-        self.assertEqual(metadata.num_decodes, 1)
-        self.assertEqual(metadata.num_decode_tokens, 1)
-        self.assertEqual(metadata.num_prefills, 0)
-        self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
-        self.assertIsNone(metadata.prefill)
-        self.assertIsInstance(metadata.decode, AscendMLADecodeMetadata)
-        self.assertEqual(metadata.block_tables.shape[0], 3)
-        self.assertEqual(metadata.block_tables.shape[1], 64)
-        self.assertEqual(metadata.seq_lens.shape[0], 3)
-        self.assertEqual(metadata.slot_mapping.shape[0], 3)
-        self.assertEqual(metadata.query_start_loc.shape[0], 3)
-        assert torch.equal(sin_golden, metadata.decode.sin)
-        assert torch.equal(cos_golden, metadata.decode.cos)
-

 class TestAscendMLAImpl(TestBase):

@@ -401,8 +241,6 @@ class TestAscendMLAImpl(TestBase):
     @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
     def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
         mock_tp.world_size = 2
-        ascend_config.torchair_graph_config.enabled = True
-        ascend_config.torchair_graph_config.enable_kv_nz = False
         speculative_config = MagicMock()
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
@@ -464,7 +302,6 @@ def test_init(self):
         self.assertIsNotNone(self.impl.kv_a_layernorm)
         self.assertEqual(self.impl.num_queries_per_kv, 32)
         self.assertEqual(self.impl.tp_size, 2)
-        self.assertTrue(self.impl.torchair_graph_enabled)

     def test_v_up_proj_and_o_proj(self):
         batch_size = 4
@@ -580,102 +417,10 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
         self.assertEqual(out.shape, prefix_out.shape)
         self.assertEqual(lse.shape, prefix_lse.shape)

-    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
-    def test_exec_kv(self, mock_kv_cache):
-        batch_size = 2
-        hidden = torch.randn(batch_size, 128)
-        cos = torch.randn(batch_size, 32)
-        sin = torch.randn(batch_size, 32)
-        kv_cache = (torch.randn(
-            4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
-                    torch.randn(
-                        4, 8,
-                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
-        slots = torch.arange(batch_size, dtype=torch.long)
-
-        proj_out = torch.randn(
-            batch_size, self.impl.num_kv_heads, 1,
-            self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
-        self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
-
-        mock_kv_cache.return_value = (torch.randn(batch_size,
-                                                  self.impl.num_kv_heads, 1,
-                                                  self.impl.qk_rope_head_dim),
-                                      torch.randn(batch_size,
-                                                  self.impl.num_kv_heads, 1,
-                                                  self.impl.kv_lora_rank),
-                                      None, None)
-
-        k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)
-
-        self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
-        mock_kv_cache.assert_called_once()
-        self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
-                                      self.impl.qk_rope_head_dim))
-        self.assertEqual(
-            k_nope.shape,
-            (batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
-        self.assertEqual(kv.shape,
-                         (batch_size, self.impl.num_kv_heads, 1,
-                          self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
-
-    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
-    def test_exec_kv_prefill(self, mock_kv):
-        B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
-        hidden_states = torch.randn(B, N, S, H)
-        cos = torch.randn(B, S, 32)
-        sin = torch.randn(B, S, 32)
-        kv_cache = (
-            torch.randn(100, 8,
-                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
-            torch.randn(100, 8,
-                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
-        )
-
-        slots = torch.arange(B * S, dtype=torch.long)
-
-        proj_out = torch.randn(
-            B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
-        self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
-
-        mock_kv.return_value = (None, None,
-                                torch.randn(B, self.impl.num_kv_heads, S,
-                                            self.impl.qk_rope_head_dim),
-                                torch.randn(B, self.impl.num_kv_heads, S,
-                                            self.impl.kv_lora_rank))
-
-        k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
-                                                 kv_cache, slots)
-
-        self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
-        mock_kv.assert_called_once()
-
-        self.assertEqual(
-            k_pe.shape,
-            (B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
-        self.assertEqual(
-            k_nope.shape,
-            (B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))
-
-    @patch("torch_npu.npu_interleave_rope")
-    def test_rope_single(self, mock_rope):
-        B, N, D = 2, 16, 1024
-        x = torch.randn(B, N, D)
-        cos = torch.randn(B, N, 1, D)
-        sin = torch.randn(B, N, 1, D)
-        mock_rope.return_value = x.view(B, N, 1, D)
-        result = self.impl.rope_single(x, cos, sin)
-        self.assertEqual(result.shape[0], B)
-        self.assertEqual(result.shape[1], N)
-        self.assertEqual(result.shape[2], D)
-        mock_rope.assert_called_once()
-
     @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj")
     @patch("torch_npu._npu_paged_attention_mla")
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                           mock_up_proj):
-        self.impl.running_in_graph = False
-        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4
@@ -706,9 +451,6 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
     @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
     @patch("torch_npu._npu_reshape_and_cache")
     def test_forward_without_graph(self, _, mock_forward_prefill):
-        self.impl.running_in_graph = False
-        self.impl.torchair_graph_enabled = False
-
         num_tokens = 100
         num_blocks = 256
         block_size = 4
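
Note: the torchair-specific cases deleted above (the graph-dummy metadata build, exec_kv, exec_kv_prefill, rope_single) rely on TorchairCommonAttentionMetadata and build_torchair_graph_dummy, both of which appear only in the removed code; where they are re-homed is not shown in this part of the diff. As a reference, the sketch below restates, as a standalone helper, how the removed test_build_dummy drove the dummy-metadata build; any placement or builder type for it is an assumption, not something this commit establishes.

# Sketch only: re-states the dummy-metadata setup from the removed
# test_build_dummy. The module that now hosts these tests is not shown here.
import torch

from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


def build_decode_only_dummy(builder):
    """Build the 3-request, decode-only dummy metadata the removed test used."""
    common_attn_metadata = TorchairCommonAttentionMetadata(
        num_reqs=3,
        num_actual_tokens=3,
        decode_token_per_req=1,
        actual_seq_lengths_q=[0, 1, 2],
        attn_mask=torch.zeros((1, 1), dtype=torch.bool),
        spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
    )
    # `builder` is assumed to expose build_torchair_graph_dummy, exactly as the
    # removed test exercised it on AscendMLAMetadataBuilder.
    return builder.build_torchair_graph_dummy(common_attn_metadata)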

tests/ut/test_platform.py

Lines changed: 21 additions & 0 deletions
@@ -425,6 +425,27 @@ def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
         self.assertEqual(result,
                          "vllm_ascend.attention.mla_v1.AscendMLABackend")

+    @patch('vllm_ascend.platform.get_ascend_config')
+    def test_get_attn_backend_cls_use_v1_mla_and_torchair(
+            self, mock_get_ascend_config):
+        mock_config = MagicMock()
+        mock_config.torchair_graph_config.enabled = True
+
+        mock_get_ascend_config.return_value = mock_config
+
+        result = self.platform.get_attn_backend_cls(
+            selected_backend="ascend",
+            head_size=64,
+            dtype="float16",
+            kv_cache_dtype="float16",
+            block_size=64,
+            use_v1=True,
+            use_mla=True,
+        )
+        self.assertEqual(
+            result,
+            "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")
+
     @patch('vllm_ascend.platform.get_ascend_config')
     def test_get_attn_backend_cls_use_v1_and_torchair(self,
                                                       mock_get_ascend_config):
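
The added test pins down the dispatch the platform is expected to perform: with use_v1 and use_mla set and the torchair graph enabled, get_attn_backend_cls should return the torchair MLA backend path rather than the default mla_v1 one. Below is a minimal sketch of that selection logic, inferred only from the two expected return strings in the tests; the actual platform implementation is not part of this diff.

# Assumed shape of the backend selection these tests exercise; the real
# vllm_ascend.platform code is not shown in this commit.
def select_mla_backend_cls(use_v1: bool, use_mla: bool,
                           torchair_graph_enabled: bool) -> str:
    if use_v1 and use_mla:
        if torchair_graph_enabled:
            # Path asserted by test_get_attn_backend_cls_use_v1_mla_and_torchair
            return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
        # Path asserted by the existing test_get_attn_backend_cls_use_v1_and_mla
        return "vllm_ascend.attention.mla_v1.AscendMLABackend"
    raise NotImplementedError("non-MLA branches omitted in this sketch")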
