Commit edae8c1

Author: niushengxiao (committed)

feat: add flashinfer prefilled operator in the attention module

Parent: 3c51248

File tree (5 files changed, +136 -76 lines)

  lightllm/common/basemodel/basemodel.py
  lightllm/models/deepseek2/infer_struct.py
  lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
  lightllm/models/deepseek2/model.py
  lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py
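In short, the commit adds an optional flashinfer-backed prefill (context) attention path for the DeepSeek-V2 MLA module, gated by the ENABLE_FLASHINFER_PREFILLED environment variable, alongside the existing ENABLE_FLASHINFER_DECODE_MLA decode path. A minimal sketch of the flag convention the diff uses throughout (the helper name _env_flag is illustrative, not part of the commit):

import os

def _env_flag(name: str, default: str = "False") -> bool:
    # Same truthy-string convention as the flags added in this diff.
    return os.getenv(name, default).upper() in ["ON", "TRUE", "1"]

enable_flashinfer_prefilled = _env_flag("ENABLE_FLASHINFER_PREFILLED")
enable_flashinfer_decode_mla = _env_flag("ENABLE_FLASHINFER_DECODE_MLA")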

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 1 deletion

@@ -36,6 +36,7 @@ class TpPartBaseModel:
     infer_state_class = InferStateInfo

     def __init__(self, kvargs):
+        self.infer_state = self.infer_state_class()
         self.run_mode = kvargs["run_mode"]
         self.tp_rank_ = kvargs["tp_rank"]
         self.world_size_ = kvargs["world_size"]
@@ -330,7 +331,9 @@ def _decode(
         b_seq_len,
         multimodal_params,
     ):
-        infer_state = self.infer_state_class()
+        infer_state = self.infer_state
+        if self.graph is None or self.graph.need_capture(batch_size):
+            infer_state = self.infer_state_class()
         infer_state.is_prefill = False
         infer_state.batch_size = batch_size
         infer_state.total_token_num = total_token_num
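Note on the hunk above: _decode now reuses a long-lived self.infer_state created in __init__ and only builds a fresh InferStateInfo when there is no CUDA graph or the graph still needs to be captured for this batch size, so state such as the flashinfer wrappers can persist across replayed decode steps. A minimal sketch of the selection logic, assuming only the graph object and its need_capture() method from the diff (everything else is illustrative):

def pick_decode_infer_state(model, batch_size):
    # Reuse the persistent state by default; its planned flashinfer wrappers stay valid.
    infer_state = model.infer_state
    if model.graph is None or model.graph.need_capture(batch_size):
        # No CUDA graph, or still capturing: start from a fresh state object.
        infer_state = model.infer_state_class()
    return infer_state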

lightllm/models/deepseek2/infer_struct.py

Lines changed: 85 additions & 43 deletions

@@ -3,15 +3,20 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
-import flashinfer


 class Deepseek2InferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
         self.kv_starts = None
+        self.prefill_wrapper = None
+        self.decode_wrapper = None
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
             "ON",
             "TRUE",
@@ -20,12 +25,24 @@ def __init__(self):

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
-        # This management variable only exists when the decode stage uses the ppl optimized operator
+
         if not self.is_prefill:
             self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
             self.total_token_num_tensor = torch.sum(self.b_seq_len)
             if self.enable_flashinfer_decode_mla:
-                self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                import flashinfer
+                from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
+                )
+                self.kv_lora_rank = model.kv_lora_rank
+                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
+                self.softmax_scale = model.softmax_scale
+                self.q_data_type = model.data_type
+                self.kv_data_type = model.data_type
+
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                 self.kv_indices = torch.empty(self.batch_size * model.max_seq_length, dtype=torch.int32).to(
                     input_ids.device
@@ -38,38 +55,63 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.max_len_in_batch,
                     self.kv_indices,
                 )
-                self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                    self.workspace_buffer,
-                    backend="fa2",
-                    use_cuda_graph=True,
-                    qo_indptr=self.q_indptr,
-                    kv_indices=self.kv_indices,
-                    kv_indptr=self.kv_starts,
-                    kv_len_arr=self.b_seq_len,
+                if self.decode_wrapper is None:
+                    self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                        model.workspace_buffer,
+                        use_cuda_graph=True,
+                        qo_indptr=self.q_indptr,
+                        kv_indices=self.kv_indices,
+                        kv_indptr=self.kv_starts,
+                        kv_len_arr=self.b_seq_len,
+                    )
+                self.decode_wrapper.plan(
+                    self.q_indptr,
+                    self.kv_starts,
+                    self.kv_indices,
+                    self.b_seq_len,
+                    self.tp_q_head_num,
+                    self.kv_lora_rank,
+                    self.qk_rope_head_dim,
+                    1,
+                    False,  # causal
+                    self.softmax_scale,
+                    self.q_data_type,
+                    self.kv_data_type,
+                )
+        else:
+            self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+            if self.enable_flashinfer_prefilled:
+                import flashinfer
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
                 )
-                self.head_num = model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
-                self.kv_lora_rank = model.kv_lora_rank
                 self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
                 self.softmax_scale = model.softmax_scale
                 self.q_data_type = model.data_type
-                self.kv_data_type = model.data_type
-                self.wrapper.plan(
-                    self.q_indptr,
-                    self.kv_starts,
-                    self.kv_indices,
-                    self.b_seq_len,
-                    self.head_num,
-                    self.kv_lora_rank,
-                    self.qk_rope_head_dim,
-                    1,
-                    False,  # causal
-                    self.softmax_scale,
-                    self.q_data_type,
-                    self.kv_data_type,
-                )

-        if self.is_prefill:
-            self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+                q_starts = torch.cat(
+                    [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
+                ).int()
+                kv_starts = torch.cat(
+                    [self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0
+                ).int()
+                if self.prefill_wrapper is None:
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+                        model.workspace_buffer, "NHD"
+                    )
+                self.prefill_wrapper.plan(
+                    qo_indptr=q_starts,
+                    kv_indptr=kv_starts,
+                    num_qo_heads=self.tp_q_head_num,
+                    num_kv_heads=self.tp_q_head_num,
+                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                    head_dim_vo=self.qk_nope_head_dim,
+                    q_data_type=self.q_data_type,
+                    causal=True,
+                    sm_scale=self.softmax_scale,
+                )

         if self.enable_dp:
             rank = dist.get_rank()
@@ -89,19 +131,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):

     def copy_for_cuda_graph(self, new_infer_state):
         super().copy_for_cuda_graph(new_infer_state)
-        if self.enable_flashinfer_decode_mla:
-            self.wrapper.plan(
-                self.q_indptr,
-                self.kv_starts,
-                self.kv_indices,
-                self.b_seq_len,
-                self.head_num,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
+        if self.enable_flashinfer_decode_mla and not self.is_prefill:
+            self.decode_wrapper.plan(
+                new_infer_state.q_indptr,
+                new_infer_state.kv_starts,
+                new_infer_state.kv_indices,
+                new_infer_state.b_seq_len,
+                new_infer_state.tp_q_head_num,
+                new_infer_state.kv_lora_rank,
+                new_infer_state.qk_rope_head_dim,
                 1,
                 False,  # causal
-                self.softmax_scale,
-                self.q_data_type,
-                self.kv_data_type,
+                new_infer_state.softmax_scale,
+                new_infer_state.q_data_type,
+                new_infer_state.kv_data_type,
             )
         return
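Note on the prefill branch above: flashinfer's ragged-KV prefill wrapper is planned with per-request start offsets rather than padded tensors. q_starts counts only the new (uncached) query tokens per request (b_seq_len - b_ready_cache_len), kv_starts covers the full decompressed KV length, and the head dims reflect MLA: queries/keys carry qk_nope_head_dim + qk_rope_head_dim while values (and the output) carry only qk_nope_head_dim. A self-contained sketch of how such indptr tensors are laid out (the per-request lengths below are made up for illustration):

import torch

def to_indptr(lens: torch.Tensor) -> torch.Tensor:
    # Exclusive prefix sum with a trailing total: [0, l0, l0+l1, ..., sum(lens)]
    return torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(lens, dim=0).int()])

new_q_lens = torch.tensor([3, 1, 5], dtype=torch.int32)  # b_seq_len - b_ready_cache_len
kv_lens = torch.tensor([7, 4, 9], dtype=torch.int32)     # b_seq_len after KV decompression

qo_indptr = to_indptr(new_q_lens)  # tensor([0, 3, 4, 9])
kv_indptr = to_indptr(kv_lens)     # tensor([0, 7, 11, 20])

The diff builds the same kind of offsets with torch.cat([starts, starts[-1:] + last_len]) and hands them to prefill_wrapper.plan(...) together with head counts, head dims, dtype, causal=True, and the MLA softmax scale.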

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 45 additions & 30 deletions

@@ -69,6 +69,11 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
             "ON",
             "TRUE",
@@ -220,22 +225,28 @@ def _context_attention_kernel_with_CC(
         out=None,
     ) -> torch.Tensor:
         k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
-        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
-        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        context_attention_fwd_with_v(
-            q_nope,
-            q_rope,
-            k_nope,
-            k_rope,
-            v,
-            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-            infer_state.b_start_loc,
-            infer_state.b_kv_start_loc,
-            infer_state.b_seq_len,
-            infer_state.b_ready_cache_len,
-            infer_state.max_len_in_batch,
-            self.softmax_scale,
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
         )
+        if self.enable_flashinfer_prefilled:
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        else:
+            q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                k_nope,
+                k_rope,
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+                infer_state.b_start_loc,
+                infer_state.b_kv_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
         return o_tensor

     def _context_attention_kernel_with_CC_fp8(
@@ -249,20 +260,24 @@ def _context_attention_kernel_with_CC_fp8(
         k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        context_attention_fwd_with_v(
-            q_nope,
-            q_rope,
-            k_nope,
-            k_rope,
-            v,
-            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-            infer_state.b_start_loc,
-            infer_state.b_kv_start_loc,
-            infer_state.b_seq_len,
-            infer_state.b_ready_cache_len,
-            infer_state.max_len_in_batch,
-            self.softmax_scale,
-        )
+        if self.enable_flashinfer_prefilled:
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        else:
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                k_nope,
+                k_rope,
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+                infer_state.b_start_loc,
+                infer_state.b_kv_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
         return o_tensor

     def _context_attention_kernel_origin(
@@ -378,7 +393,7 @@ def _token_gqa_decode_attention_flashdecoding(
             )
             return o_tensor
         elif self.enable_flashinfer_decode_mla:
-            infer_state.wrapper.run(
+            infer_state.decode_wrapper.run(
                 q_nope,
                 q_rope,
                 kv[:, :, : -self.qk_rope_head_dim],
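Note on the flashinfer prefill branch above: the decompressed k_nope is per query head, while the rope key is a single shared head, so it is repeated across the tp_q_head_num_ heads with torch.repeat_interleave before being concatenated; the resulting K carries qk_nope_head_dim + qk_rope_head_dim per head, while V (and therefore o_tensor) keeps only qk_nope_head_dim. A shape-only sketch with made-up sizes:

import torch

tokens, q_heads = 16, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64

k_nope = torch.randn(tokens, q_heads, qk_nope_head_dim)
k_rope = torch.randn(tokens, 1, qk_rope_head_dim)  # single shared rope head

# Broadcast the shared rope part to every query head, then append it to k_nope.
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_heads, dim=-2)], dim=-1)
assert k.shape == (tokens, q_heads, qk_nope_head_dim + qk_rope_head_dim)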

lightllm/models/deepseek2/model.py

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ def _init_some_value(self):
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
         self.tp_q_head_num_ = self.config["num_attention_heads"] // self.world_size_
         self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
+        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(self.tp_rank_)
         if self.config["rope_scaling"] is not None:
             rope_scaling = self.config["rope_scaling"]
             mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
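Note on the line added above: the 128 MB int8 flashinfer workspace is now allocated once per model instance and placed on the GPU whose index equals tp_rank_ (passing an int to .to() selects a CUDA device index); both the prefill and decode wrappers then share it through model.workspace_buffer instead of re-allocating a buffer per decode batch. A more explicit form that produces the same kind of buffer (device index 0 assumed for illustration):

import torch

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=torch.device("cuda", 0))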

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 1 addition & 2 deletions

@@ -179,8 +179,7 @@ def _fwd_kernel_calcu_index_and_block_seq(
     req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX)
     b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
     b_start_loc = torch.arange(Z).cuda().int() * N_CTX
-    b_start_loc[0] = 0
-    b_req_idx = torch.arange(Z).cuda().int()
+    b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda()
     kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)

     o = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda")
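Note on the test-harness hunk above: b_start_loc[0] = 0 was redundant because torch.arange(Z) * N_CTX already starts at 0, and b_req_idx is now a random permutation, so the index lookup is exercised with non-sequential request ids instead of the identity mapping. A CPU-only sketch of the same setup (mirroring the test, without the .cuda() calls):

import torch

Z, N_CTX = 3, 4
b_start_loc = torch.arange(Z).int() * N_CTX       # tensor([0, 4, 8]) -- already starts at 0
b_req_idx = torch.randperm(Z, dtype=torch.int32)  # e.g. tensor([2, 0, 1]), non-sequential ids
kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + torch.tensor([N_CTX], dtype=torch.int32)])
# kv_starts == tensor([0, 4, 8, 12])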
