@@ -13,20 +13,14 @@ def __init__(self):
         super().__init__()
         self.prefill_wrapper = None
         self.decode_wrapper = None
+        self.flashinfer_extra_state = None

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
+        self.flashinfer_extra_state = model.flashinfer_extra_state

         if not self.is_prefill:
             if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
-                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
-                self.kv_lora_rank = model.flashinfer_state.kv_lora_rank
-                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
-                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
-                self.softmax_scale = model.flashinfer_state.softmax_scale
-                self.q_data_type = model.flashinfer_state.data_type
-                self.kv_data_type = model.flashinfer_state.data_type
-
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                 self.kv_indices = torch.empty(
                     self.batch_size * model.flashinfer_state.max_seq_length, dtype=torch.int32
@@ -41,7 +35,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 )
                 if self.decode_wrapper is None:
                     self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                        model.flashinfer_state.workspace_buffer,
+                        self.flashinfer_extra_state.workspace_buffer,
                         use_cuda_graph=True,
                         qo_indptr=self.q_indptr,
                         kv_indices=self.kv_indices,
@@ -53,23 +47,17 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.kv_starts,
                     self.kv_indices,
                     self.b_seq_len,
-                    self.tp_q_head_num,
-                    self.kv_lora_rank,
-                    self.qk_rope_head_dim,
+                    self.flashinfer_extra_state.tp_q_head_num,
+                    self.flashinfer_extra_state.kv_lora_rank,
+                    self.flashinfer_extra_state.qk_rope_head_dim,
                     1,
                     False,  # causal
-                    self.softmax_scale,
-                    self.q_data_type,
-                    self.kv_data_type,
+                    self.flashinfer_extra_state.softmax_scale,
+                    self.flashinfer_extra_state.q_data_type,
+                    self.flashinfer_extra_state.kv_data_type,
                 )
         else:
             if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
-                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
-                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
-                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
-                self.softmax_scale = model.flashinfer_state.softmax_scale
-                self.q_data_type = model.flashinfer_state.data_type
-
                 q_starts = torch.cat(
                     [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
                 ).int()
@@ -78,18 +66,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 ).int()
                 if self.prefill_wrapper is None:
                     self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-                        model.flashinfer_state.workspace_buffer, "NHD"
+                        self.flashinfer_extra_state.workspace_buffer, "NHD"
                     )
                 self.prefill_wrapper.plan(
                     qo_indptr=q_starts,
                     kv_indptr=kv_starts,
-                    num_qo_heads=self.tp_q_head_num,
-                    num_kv_heads=self.tp_q_head_num,
-                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-                    head_dim_vo=self.qk_nope_head_dim,
-                    q_data_type=self.q_data_type,
+                    num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+                    + self.flashinfer_extra_state.qk_rope_head_dim,
+                    head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
+                    q_data_type=self.flashinfer_extra_state.q_data_type,
                     causal=True,
-                    sm_scale=self.softmax_scale,
+                    sm_scale=self.flashinfer_extra_state.softmax_scale,
                 )
         return

@@ -101,13 +90,13 @@ def copy_for_cuda_graph(self, new_infer_state):
                 new_infer_state.kv_starts,
                 new_infer_state.kv_indices,
                 new_infer_state.b_seq_len,
-                new_infer_state.tp_q_head_num,
-                new_infer_state.kv_lora_rank,
-                new_infer_state.qk_rope_head_dim,
+                new_infer_state.flashinfer_extra_state.tp_q_head_num,
+                new_infer_state.flashinfer_extra_state.kv_lora_rank,
+                new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
                 1,
                 False,  # causal
-                new_infer_state.softmax_scale,
-                new_infer_state.q_data_type,
-                new_infer_state.kv_data_type,
+                new_infer_state.flashinfer_extra_state.softmax_scale,
+                new_infer_state.flashinfer_extra_state.q_data_type,
+                new_infer_state.flashinfer_extra_state.kv_data_type,
             )
         return
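
For context: the change replaces the per-field copies from `model.flashinfer_state` with a single shared `flashinfer_extra_state` object that the model now exposes. A minimal sketch of what such a container could look like, with field names inferred from the attribute accesses in the diff (the actual class in the repository may be shaped differently):

```python
# Hypothetical sketch only: field names are inferred from the attribute
# accesses in the diff; the real model-side class may differ.
from dataclasses import dataclass

import torch


@dataclass
class FlashInferExtraState:
    tp_q_head_num: int              # query heads on this tensor-parallel rank
    kv_lora_rank: int               # MLA compressed-KV (latent) dimension
    qk_rope_head_dim: int           # rotary slice of the per-head QK dim
    qk_nope_head_dim: int           # non-rotary slice of the per-head QK dim
    softmax_scale: float            # attention softmax scaling factor
    q_data_type: torch.dtype        # dtype flashinfer assumes for queries
    kv_data_type: torch.dtype       # dtype flashinfer assumes for the KV cache
    max_seq_length: int             # bound used to size kv_indices per batch
    workspace_buffer: torch.Tensor  # scratch buffer shared by all wrappers
```

Keeping these on one shared object, instead of copying seven scalars onto every infer state, also means `copy_for_cuda_graph` and the wrapper `plan()` calls read from the same source, so the captured and replayed states cannot drift apart.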