import numpy as np
import torch.distributed as dist
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
-import flashinfer


class Deepseek2InferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()
        self.kv_starts = None
+        self.prefill_wrapper = None
+        self.decode_wrapper = None
        self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
        self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
            "ON",
            "TRUE",
@@ -20,12 +25,24 @@ def __init__(self):

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        super().init_some_extra_state(model, input_ids)
-        # This management variable only exists when the decode phase uses the ppl optimized kernel
+
        if not self.is_prefill:
            self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
            self.total_token_num_tensor = torch.sum(self.b_seq_len)
            if self.enable_flashinfer_decode_mla:
-                self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                import flashinfer
+                from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
+                )
+                self.kv_lora_rank = model.kv_lora_rank
+                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
+                self.softmax_scale = model.softmax_scale
+                self.q_data_type = model.data_type
+                self.kv_data_type = model.data_type
+
                self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                self.kv_indices = torch.empty(self.batch_size * model.max_seq_length, dtype=torch.int32).to(
                    input_ids.device
@@ -38,38 +55,63 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                    self.max_len_in_batch,
                    self.kv_indices,
                )
-                self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                    self.workspace_buffer,
-                    backend="fa2",
-                    use_cuda_graph=True,
-                    qo_indptr=self.q_indptr,
-                    kv_indices=self.kv_indices,
-                    kv_indptr=self.kv_starts,
-                    kv_len_arr=self.b_seq_len,
+                if self.decode_wrapper is None:
+                    self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                        model.workspace_buffer,
+                        use_cuda_graph=True,
+                        qo_indptr=self.q_indptr,
+                        kv_indices=self.kv_indices,
+                        kv_indptr=self.kv_starts,
+                        kv_len_arr=self.b_seq_len,
+                    )
+                self.decode_wrapper.plan(
+                    self.q_indptr,
+                    self.kv_starts,
+                    self.kv_indices,
+                    self.b_seq_len,
+                    self.tp_q_head_num,
+                    self.kv_lora_rank,
+                    self.qk_rope_head_dim,
+                    1,
+                    False,  # causal
+                    self.softmax_scale,
+                    self.q_data_type,
+                    self.kv_data_type,
+                )
+        else:
+            self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+            if self.enable_flashinfer_prefilled:
+                import flashinfer
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
                )
-                self.head_num = model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
-                self.kv_lora_rank = model.kv_lora_rank
                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
                self.softmax_scale = model.softmax_scale
                self.q_data_type = model.data_type
-                self.kv_data_type = model.data_type
-                self.wrapper.plan(
-                    self.q_indptr,
-                    self.kv_starts,
-                    self.kv_indices,
-                    self.b_seq_len,
-                    self.head_num,
-                    self.kv_lora_rank,
-                    self.qk_rope_head_dim,
-                    1,
-                    False,  # causal
-                    self.softmax_scale,
-                    self.q_data_type,
-                    self.kv_data_type,
-                )

-        if self.is_prefill:
-            self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+                q_starts = torch.cat(
+                    [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
+                ).int()
+                kv_starts = torch.cat(
+                    [self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0
+                ).int()
+                if self.prefill_wrapper is None:
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+                        model.workspace_buffer, "NHD"
+                    )
+                self.prefill_wrapper.plan(
+                    qo_indptr=q_starts,
+                    kv_indptr=kv_starts,
+                    num_qo_heads=self.tp_q_head_num,
+                    num_kv_heads=self.tp_q_head_num,
+                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                    head_dim_vo=self.qk_nope_head_dim,
+                    q_data_type=self.q_data_type,
+                    causal=True,
+                    sm_scale=self.softmax_scale,
+                )

        if self.enable_dp:
            rank = dist.get_rank()
@@ -89,19 +131,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):

    def copy_for_cuda_graph(self, new_infer_state):
        super().copy_for_cuda_graph(new_infer_state)
-        if self.enable_flashinfer_decode_mla:
-            self.wrapper.plan(
-                self.q_indptr,
-                self.kv_starts,
-                self.kv_indices,
-                self.b_seq_len,
-                self.head_num,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
+        if self.enable_flashinfer_decode_mla and not self.is_prefill:
+            self.decode_wrapper.plan(
+                new_infer_state.q_indptr,
+                new_infer_state.kv_starts,
+                new_infer_state.kv_indices,
+                new_infer_state.b_seq_len,
+                new_infer_state.tp_q_head_num,
+                new_infer_state.kv_lora_rank,
+                new_infer_state.qk_rope_head_dim,
                1,
                False,  # causal
-                self.softmax_scale,
-                self.q_data_type,
-                self.kv_data_type,
+                new_infer_state.softmax_scale,
+                new_infer_state.q_data_type,
+                new_infer_state.kv_data_type,
            )
        return
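
Note on the decode path above: instead of allocating a fresh 128 MB workspace and a throwaway wrapper on every step, the MLA wrapper is now built once against model.workspace_buffer with use_cuda_graph=True and only re-planned per batch (and re-planned again from copy_for_cuda_graph when a captured graph is replayed). A minimal sketch of that lifecycle follows; the head count, MLA dimensions, softmax scale, and batch contents are illustrative stand-ins, not values taken from this PR.

import torch
import flashinfer

# One shared workspace, allocated once and reused (the old code re-allocated this per step).
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")

# Illustrative decode batch of 2 requests with kv lengths 3 and 5, page_size = 1.
b_seq_len = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
q_indptr = torch.arange(3, dtype=torch.int32, device="cuda")           # one new token per request
kv_indptr = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # cumulative kv lengths
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda")         # token slots in the paged cache

# Built once with use_cuda_graph=True so graph replays can keep reusing the same buffers.
decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
    workspace_buffer,
    use_cuda_graph=True,
    qo_indptr=q_indptr,
    kv_indices=kv_indices,
    kv_indptr=kv_indptr,
    kv_len_arr=b_seq_len,
)

# Re-planned for every batch; DeepSeek-V2-style MLA dimensions below are assumed.
decode_wrapper.plan(
    q_indptr,
    kv_indptr,
    kv_indices,
    b_seq_len,
    16,              # query heads on this tp rank (assumed)
    512,             # kv_lora_rank (assumed)
    64,              # qk_rope_head_dim (assumed)
    1,               # page size
    False,           # causal
    0.1352,          # softmax_scale (assumed)
    torch.bfloat16,  # q data type
    torch.bfloat16,  # kv data type
)

The intent is that the constructor pins the persistent index buffers, so CUDA graph replay keeps addressing the same device memory, while plan() only refreshes the per-batch metadata.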
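Similarly, for the new prefill path gated by ENABLE_FLASHINFER_PREFILLED: the q_starts / kv_starts arrays handed to prefill_wrapper.plan() are plain cumulative offsets over the batch. The sketch below shows the derivation with made-up lengths, assuming b_start_loc holds packed offsets of the new (query) tokens and b_ready_cache_len the already-cached prefix length per request; those assumptions are read off how the arrays are combined in the diff, not confirmed elsewhere.

import torch

# Hypothetical batch of 3 requests: total kv length per request and how many
# of those tokens are already cached (prefix-cache hits); values are illustrative.
b_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)
b_ready_cache_len = torch.tensor([2, 0, 4], dtype=torch.int32)

# New (query) tokens per request and their packed start offsets; this mirrors
# what b_start_loc is assumed to contain for a prefill batch.
b_q_len = b_seq_len - b_ready_cache_len
b_start_loc = torch.cumsum(b_q_len, dim=0) - b_q_len

# Packed kv start offsets per request, as computed in the else: branch of the diff.
b_kv_start_loc = torch.cumsum(b_seq_len, dim=0) - b_seq_len

# batch_size + 1 indptr arrays handed to the ragged prefill plan().
q_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_q_len[-1:]]).int()
kv_starts = torch.cat([b_kv_start_loc, b_kv_start_loc[-1:] + b_seq_len[-1:]]).int()

print(q_starts.tolist())   # [0, 3, 6, 9]   -> 9 query tokens packed together
print(kv_starts.tolist())  # [0, 5, 8, 15]  -> 15 kv tokens packed together

Both arrays have length batch_size + 1, the ragged indptr layout the prefill wrapper expects.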