Commit 6f5545a

hiworldwzj authored and niushengxiao committed
fix
1 parent 9af2b0c commit 6f5545a

File tree

lightllm/common/basemodel/basemodel.py
lightllm/models/deepseek2/flashinfer_struct.py
lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
lightllm/models/deepseek2/model.py

4 files changed: +31 −43 lines changed
lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 4 deletions
@@ -36,7 +36,6 @@ class TpPartBaseModel:
     infer_state_class = InferStateInfo
 
     def __init__(self, kvargs):
-        self.infer_state = self.infer_state_class()
         self.run_mode = kvargs["run_mode"]
         self.tp_rank_ = kvargs["tp_rank"]
         self.world_size_ = kvargs["world_size"]
@@ -331,9 +330,7 @@ def _decode(
         b_seq_len,
         multimodal_params,
     ):
-        infer_state = self.infer_state
-        if self.graph is None or self.graph.need_capture(batch_size) or infer_state.is_prefill:
-            infer_state = self.infer_state_class()
+        infer_state = self.infer_state_class()
         infer_state.is_prefill = False
         infer_state.batch_size = batch_size
         infer_state.total_token_num = total_token_num
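
What this hunk changes: the cached `self.infer_state` is dropped from `__init__`, and `_decode` stops reusing it. Previously a new state was allocated only when no CUDA graph existed, the graph needed capture, or the cached state was still marked as prefill; now a fresh `infer_state_class()` is built on every call. A minimal sketch of the hazard the old reuse pattern invites (hypothetical `InferState`, not the repo's class):

# Minimal sketch (hypothetical class, not the repo's) of why reusing one
# mutable state object across steps is fragile: fields written on an
# earlier step survive into the next one unless every field is reset.
class InferState:
    def __init__(self):
        self.is_prefill = False
        self.batch_size = None

shared = InferState()
shared.is_prefill = True      # set during a prefill step

# Reusing `shared` for the next decode step keeps the stale flag:
assert shared.is_prefill      # leftover value from the previous step

# Allocating a fresh state per call, as this commit does, starts clean:
fresh = InferState()
assert not fresh.is_prefill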

lightllm/models/deepseek2/flashinfer_struct.py

Lines changed: 23 additions & 34 deletions
@@ -13,20 +13,14 @@ def __init__(self):
         super().__init__()
         self.prefill_wrapper = None
         self.decode_wrapper = None
+        self.flashinfer_extra_state = None
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
+        self.flashinfer_extra_state = model.flashinfer_extra_state
 
         if not self.is_prefill:
             if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
-                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
-                self.kv_lora_rank = model.flashinfer_state.kv_lora_rank
-                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
-                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
-                self.softmax_scale = model.flashinfer_state.softmax_scale
-                self.q_data_type = model.flashinfer_state.data_type
-                self.kv_data_type = model.flashinfer_state.data_type
-
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                 self.kv_indices = torch.empty(
                     self.batch_size * model.flashinfer_state.max_seq_length, dtype=torch.int32
@@ -41,7 +35,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 )
                 if self.decode_wrapper is None:
                     self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                        model.flashinfer_state.workspace_buffer,
+                        self.flashinfer_extra_state.workspace_buffer,
                         use_cuda_graph=True,
                         qo_indptr=self.q_indptr,
                         kv_indices=self.kv_indices,
@@ -53,23 +47,17 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.kv_starts,
                     self.kv_indices,
                     self.b_seq_len,
-                    self.tp_q_head_num,
-                    self.kv_lora_rank,
-                    self.qk_rope_head_dim,
+                    self.flashinfer_extra_state.tp_q_head_num,
+                    self.flashinfer_extra_state.kv_lora_rank,
+                    self.flashinfer_extra_state.qk_rope_head_dim,
                     1,
                     False,  # causal
-                    self.softmax_scale,
-                    self.q_data_type,
-                    self.kv_data_type,
+                    self.flashinfer_extra_state.softmax_scale,
+                    self.flashinfer_extra_state.q_data_type,
+                    self.flashinfer_extra_state.kv_data_type,
                 )
         else:
             if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
-                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
-                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
-                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
-                self.softmax_scale = model.flashinfer_state.softmax_scale
-                self.q_data_type = model.flashinfer_state.data_type
-
                 q_starts = torch.cat(
                     [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
                 ).int()
@@ -78,18 +66,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 ).int()
                 if self.prefill_wrapper is None:
                     self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-                        model.flashinfer_state.workspace_buffer, "NHD"
+                        self.flashinfer_extra_state.workspace_buffer, "NHD"
                     )
                 self.prefill_wrapper.plan(
                     qo_indptr=q_starts,
                     kv_indptr=kv_starts,
-                    num_qo_heads=self.tp_q_head_num,
-                    num_kv_heads=self.tp_q_head_num,
-                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-                    head_dim_vo=self.qk_nope_head_dim,
-                    q_data_type=self.q_data_type,
+                    num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+                    + self.flashinfer_extra_state.qk_rope_head_dim,
+                    head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
+                    q_data_type=self.flashinfer_extra_state.q_data_type,
                     causal=True,
-                    sm_scale=self.softmax_scale,
+                    sm_scale=self.flashinfer_extra_state.softmax_scale,
                 )
         return
 
@@ -101,13 +90,13 @@ def copy_for_cuda_graph(self, new_infer_state):
                 new_infer_state.kv_starts,
                 new_infer_state.kv_indices,
                 new_infer_state.b_seq_len,
-                new_infer_state.tp_q_head_num,
-                new_infer_state.kv_lora_rank,
-                new_infer_state.qk_rope_head_dim,
+                new_infer_state.flashinfer_extra_state.tp_q_head_num,
+                new_infer_state.flashinfer_extra_state.kv_lora_rank,
+                new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
                 1,
                 False,  # causal
-                new_infer_state.softmax_scale,
-                new_infer_state.q_data_type,
-                new_infer_state.kv_data_type,
+                new_infer_state.flashinfer_extra_state.softmax_scale,
+                new_infer_state.flashinfer_extra_state.q_data_type,
+                new_infer_state.flashinfer_extra_state.kv_data_type,
             )
         return
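
The theme of this file's changes: per-batch infer states stop copying each FlashInfer constant onto themselves and instead hold one reference to the model's shared `flashinfer_extra_state`, reading `tp_q_head_num`, `softmax_scale`, and the rest through it at plan() time. A minimal sketch of the pattern (hypothetical classes and values, not the real ones):

# Sketch of the refactor pattern (hypothetical field set): keep one
# reference to a shared read-only holder instead of copying each
# constant onto every per-batch state object.
class ExtraState:
    def __init__(self, tp_q_head_num, kv_lora_rank, softmax_scale):
        self.tp_q_head_num = tp_q_head_num
        self.kv_lora_rank = kv_lora_rank
        self.softmax_scale = softmax_scale

class InferState:
    def __init__(self, extra_state):
        # one shared reference replaces N copied attributes
        self.flashinfer_extra_state = extra_state

extra = ExtraState(tp_q_head_num=16, kv_lora_rank=512, softmax_scale=192 ** -0.5)
state = InferState(extra)
print(state.flashinfer_extra_state.tp_q_head_num)  # 16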

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 3 deletions
@@ -21,6 +21,7 @@
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
@@ -224,7 +225,7 @@ def _context_attention_flashinfer_kernel_with_CC(
         self,
         q: torch.Tensor,
         kv,
-        infer_state: Deepseek2InferStateInfo,
+        infer_state: Deepseek2FlashInferStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
@@ -240,7 +241,7 @@ def _context_attention_flashinfer_kernel_with_CC_fp8(
         self,
         q: torch.Tensor,
         kv,
-        infer_state: Deepseek2InferStateInfo,
+        infer_state: Deepseek2FlashInferStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
@@ -393,7 +394,7 @@ def _context_attention_kernel_origin_fp8(
         return o_tensor
 
     def _token_gqa_decode_attention_flashinfer(
-        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
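
These annotation changes narrow the declared type from `Deepseek2InferStateInfo` to its FlashInfer-specific subclass, the class that actually carries `prefill_wrapper`, `decode_wrapper`, and `flashinfer_extra_state`. A minimal sketch of why the narrower annotation helps (simplified stand-in classes, not the real ones):

# Minimal sketch: annotating with the subclass that actually carries the
# FlashInfer wrappers lets type checkers and readers see those attributes.
class Deepseek2InferStateInfo:
    pass

class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo):
    def __init__(self):
        self.decode_wrapper = None  # FlashInfer-only attribute

def decode_attention(q, infer_state: Deepseek2FlashInferStateInfo):
    # the annotation now matches what the body actually touches
    return infer_state.decode_wrapper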

lightllm/models/deepseek2/model.py

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,8 @@ def __init__(self, model):
         self.qk_nope_head_dim = model.qk_nope_head_dim
         self.qk_rope_head_dim = model.qk_rope_head_dim
         self.kv_lora_rank = model.kv_lora_rank
-        self.data_type = model.data_type
+        self.q_data_type = model.data_type
+        self.kv_data_type = model.data_type
         self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(model.tp_rank_)
         self.max_seq_length = model.max_seq_length
         self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
@@ -65,7 +66,7 @@ def _init_some_value(self):
         self.kv_lora_rank = self.config["kv_lora_rank"]
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
         if self.enable_flashinfer:
-            self.flashinfer_state = FlashInferStateExtraInfo(self)
+            self.flashinfer_extra_state = FlashInferStateExtraInfo(self)
 
     def _init_custom(self):
         self._init_to_get_yarn_rotary()
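
Here `data_type` is split into `q_data_type` and `kv_data_type`, matching the attribute names the MLA plan() call reads, and `flashinfer_state` is renamed to `flashinfer_extra_state` to agree with the new field on the infer state. As a side note on the unchanged `softmax_scale` line above, a worked example (head dims are assumed DeepSeek-V2-style values from the model config, not taken from this diff):

# Worked example of the softmax scale, assuming qk_nope_head_dim=128 and
# qk_rope_head_dim=64 (hypothetical here; they come from the model config):
qk_nope_head_dim, qk_rope_head_dim = 128, 64
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5)
print(round(softmax_scale, 4))  # 0.0722 == 1 / sqrt(192)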
