
Commit 425edb2

run like deepseek v3
1 parent ee7100c commit 425edb2

File tree

6 files changed: +225 -90 lines changed

lightllm/models/deepseek2/model.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def __init__(self, model):
         self.softmax_scale = self.softmax_scale * mscale * mscale


-@ModelRegistry(["deepseek_v2", "deepseek_v3"])
+@ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"])
 class Deepseek2TpPartModel(LlamaTpPartModel):
     # weight class
     transformer_weight_class = Deepseek2TransformerLayerWeight
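The effect of this one-line change is that checkpoints whose config.json reports model_type == "deepseek_v32" are now dispatched to Deepseek2TpPartModel. A minimal sketch of how such a string-keyed registry decorator typically works (illustrative only; lightllm's actual ModelRegistry may differ):

_MODEL_REGISTRY = {}

def ModelRegistry(model_types):
    # Register one model class under several "model_type" strings.
    def wrapper(cls):
        for name in model_types:
            _MODEL_REGISTRY[name] = cls
        return cls
    return wrapper

def get_model_class(model_type: str):
    return _MODEL_REGISTRY[model_type]

Under that sketch, get_model_class("deepseek_v32") now resolves to Deepseek2TpPartModel, which is what the commit title ("run like deepseek v3") suggests.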

lightllm/models/deepseek3_2/infer_struct.py

Lines changed: 40 additions & 3 deletions
@@ -1,5 +1,6 @@
 import torch
 from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
+from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
 
 class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo):
 

@@ -15,6 +16,9 @@ def __init__(self):
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
+        assert isinstance(self.mem_manager, Deepseek3_2MemoryManager)
+        self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager
+
         # Ensure b_ready_cache_len is set for both prefill and decode modes
         if self.is_prefill:
            # b_ready_cache_len is already set in basemodel.py for prefill

@@ -24,9 +28,42 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
            # since b_q_seq_len represents the new tokens being processed
            if self.b_ready_cache_len is None:
                self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len
-
-        self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk)
+
+        self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk)
         assert self.nsa_cache_seqlens.dtype == torch.int32
         self.nsa_cu_seqlens_k = torch.nn.functional.pad(
             torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
-        )
+        )
+
+        # Pre-compute NSA indexer indexing structures
+        self._init_nsa_indexing_structures()
+
+    def _init_nsa_indexing_structures(self):
+        """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer"""
+        mem_index_list = []
+        ks_list = []
+        ke_list = []
+        lengths_list = []
+        offset = 0
+        num_seq_len = self.b_req_idx.shape[0]
+        self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda')
+
+        for i in range(num_seq_len):
+            seq_len = self.b_seq_len[i]
+            q_seq_len = self.b_q_seq_len[i]
+            mem_index = self.req_manager.req_to_token_indexs[i, :seq_len]
+            mem_index_list.append(mem_index)
+            self.page_table_size_1[i, :seq_len] = mem_index
+            ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset
+            ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1
+            ks_list.append(ks)
+            ke_list.append(ke)
+            lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda'))
+            offset += seq_len
+
+        self.mem_index = torch.cat(mem_index_list, dim=0)
+        # ks : [seq_len_q] marks where each query token's kv range starts
+        # ke : [seq_len_q] marks where each query token's kv range ends
+        self.ks = torch.cat(ks_list, dim=0)
+        self.ke = torch.cat(ke_list, dim=0)
+        self.lengths = torch.cat(lengths_list, dim=0)
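To see what these structures contain, here is a standalone CPU sketch of the same bookkeeping for an invented two-request prefill batch (q_seq_len == seq_len for prefill, so ke and lengths agree):

import torch

b_seq_len = torch.tensor([3, 2])     # total tokens per request
b_q_seq_len = torch.tensor([3, 2])   # new tokens per request (prefill: all of them)

ks_list, ke_list, lengths_list = [], [], []
offset = 0
for seq_len, q_seq_len in zip(b_seq_len.tolist(), b_q_seq_len.tolist()):
    ks_list.append(torch.full((q_seq_len,), offset, dtype=torch.int32))
    ke_list.append(torch.arange(q_seq_len, dtype=torch.int32) + offset + 1)
    lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int32))
    offset += seq_len

print(torch.cat(ks_list))       # tensor([0, 0, 0, 3, 3]) -> flat kv start per query token
print(torch.cat(ke_list))       # tensor([1, 2, 3, 4, 5]) -> causal kv end per query token
print(torch.cat(lengths_list))  # tensor([1, 2, 3, 1, 2]) -> visible kv count per query token

# nsa_cu_seqlens_k above follows the same cumsum-and-pad pattern:
cache_seqlens = torch.tensor([3, 2], dtype=torch.int32)
print(torch.nn.functional.pad(torch.cumsum(cache_seqlens, 0, dtype=torch.int32), (1, 0)))
# tensor([0, 3, 5])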

lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py

Lines changed: 30 additions & 74 deletions
@@ -10,7 +10,9 @@
 from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
 from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
 from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
-# from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
 
 class NSAIndexerInfer(BaseLayerInfer):
     def __init__(self, layer_idx, network_config, mode=[]):
@@ -66,70 +68,37 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
         q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt)
         k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt)
 
-        self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager)
+        destindex_copy_indexer_ks(
+            k_fp8.unsqueeze(1),
+            k_scale.unsqueeze(1),
+            infer_state.mem_index,
+            infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
+        )
 
         weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
         weights = weights.unsqueeze(-1) * q_scale
 
-        ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
-
-        k_fp8_list = []
-        k_scale_list = []
-        ks_list = []
-        ke_list = []
-        offset = 0
-        for i in range(infer_state.batch_size):
-            q_len = infer_state.b_q_seq_len[i]
-            cache_len = infer_state.b_ready_cache_len[i]
-            mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len]
-            k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous()
-            k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous()
-            ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda")
-            ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1
-            k_fp8_list.append(k_fp8)
-            k_scale_list.append(k_scale)
-            ks_list.append(ks)
-            ke_list.append(ke)
-            offset += q_len
-
-        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
-        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
-        kv_fp8 = (k_fp8, k_scale)
-        ks = torch.cat(ks_list, dim=0)
-        ke = torch.cat(ke_list, dim=0)
-
-        logits = deep_gemm.fp8_mqa_logits(
-            q_fp8,
-            kv_fp8,
-            weights.squeeze(-1),
-            ks,
-            ke,
-            clean_logits=False,
-        )
-
-        return self.get_topk(logits, infer_state)
-
-    def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo):
-        topk_indices_list = []
-        offset = 0
-
-        for i in range(infer_state.batch_size):
-            q_len = infer_state.b_q_seq_len[i]
-            cache_len = infer_state.b_ready_cache_len[i]
-            end_pos = q_len + cache_len
-            # Slice logits for this batch (both query and sequence dimensions)
-            batch_logits = logits[offset:offset + q_len, :end_pos]
-            topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1]
-            mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len]
-            indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda")
-            for j in range(q_len):
-                indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]]
-            topk_indices_list.append(indices)
-            offset += q_len
+        # Use pre-computed indexing structures from infer_state
+        mem_index = infer_state.mem_index
+        ks = infer_state.ks
+        ke = infer_state.ke
+        lengths = infer_state.lengths
+        page_table_1 = infer_state.page_table_size_1
 
-        topk_indices_ = torch.cat(topk_indices_list, dim=0)
+        # TODO
+        k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous()
+        k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous()
 
-        return topk_indices_
+        logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)
+
+        # Returns: [seq_q_len, topk]; invalid positions are filled with -1
+        return fast_topk_transform_fused(
+            score=logits,  # [seq_len_q, seq_len_kv]
+            lengths=lengths,  # [seq_len_q]
+            page_table_size_1=page_table_1,  # [seq_len_q, max(lengths)], invalid entries padded with 0
+            cu_seqlens_q=infer_state.cu_seqlens_q,  # [seq_len_q + 1]
+            topk=self.index_topk,
+        )
 
 
     def get_k_float32_from_buffer(self, buffer: torch.Tensor):
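The two kv_buffer reads above assume the indexer ks cache packs, per token, 128 bytes of fp8 key plus 4 bytes of fp32 scale into one uint8 row of width 132 (matching the removed _copy_ks_to_mem_cache comments further down). A CPU sketch of that byte layout, with invented sizes; the real scatter is done on GPU by the destindex_copy_indexer_ks Triton kernel:

import torch

# One cache row per token: 128 fp8 key bytes + 4 bytes of fp32 scale = 132 uint8.
# This mirrors the [mem_index, :, :128] / [mem_index, :, 128:] reads above.
buffer = torch.zeros(16, 1, 132, dtype=torch.uint8)

k_fp8 = torch.randn(4, 128).to(torch.float8_e4m3fn)   # 4 tokens' quantized keys
k_scale = torch.rand(4, 1, dtype=torch.float32)       # one scale per token
mem_index = torch.tensor([3, 7, 8, 12])               # destination rows

buffer[mem_index, 0, :128] = k_fp8.view(torch.uint8)
buffer[mem_index, 0, 128:] = k_scale.view(torch.uint8)

# Reading back reinterprets the same bytes:
k_fp8_back = buffer[mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1)
k_scale_back = buffer[mem_index, :, 128:].view(torch.float32)[:, 0, 0]
assert torch.equal(k_scale_back, k_scale[:, 0])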
@@ -152,8 +121,9 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
     def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
                       infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight):
         q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim)
-
         k = layer_weight.wk_proj_.mm(hidden_states)
+
+        # TODO
         k = F.layer_norm(
             k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps
         ).type_as(k)
@@ -168,17 +138,3 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
         q = self._rotate_activation(q)
         k = self._rotate_activation(k)
         return q, k
-
-    def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager):
-        # k_fp8 : [seq_len, 128] torch.fp8_e4m3
-        # k_scale : [seq_len, 1] torch.float32
-        # mem_index : [seq_len] torch.int32
-        # buffer : [10000000, 1, 132] torch.uint8
-        buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
-        destindex_copy_indexer_ks(
-            k_fp8.unsqueeze(1),  # Add head dimension: [seq_len, 1, 128]
-            k_scale.unsqueeze(1),  # Add head dimension: [seq_len, 1, 1]
-            mem_index,
-            buffer
-        )
-        return
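The fused top-k selection replaces the removed per-request get_topk loop. Below is a rough, unfused PyTorch reference of the semantics implied by the comments above; the fused kernel's exact contract is not shown in this commit (notably, the sketch passes the flat kv start offsets ks explicitly, which the kernel presumably derives internally), so this only illustrates the expected output shape and padding:

import torch

def topk_transform_reference(score, ks, lengths, page_table_size_1, cu_seqlens_q, topk):
    # score: [seq_len_q, seq_len_kv] indexer logits over the batch's flattened kv
    # ks / lengths: [seq_len_q] flat kv start offset and visible kv count per query token
    # page_table_size_1: [num_req, max_seq_len] token memory indices, invalid entries 0
    # Returns [seq_len_q, topk] memory indices, invalid positions filled with -1.
    out = torch.full((score.shape[0], topk), -1, dtype=torch.int32)
    for req in range(cu_seqlens_q.numel() - 1):
        for q in range(int(cu_seqlens_q[req]), int(cu_seqlens_q[req + 1])):
            n = int(lengths[q])
            window = score[q, int(ks[q]) : int(ks[q]) + n]  # this token's visible kv
            sel = window.topk(min(topk, n)).indices         # positions local to the request
            out[q, : sel.numel()] = page_table_size_1[req, sel].to(torch.int32)
    return out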

lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py

Lines changed: 11 additions & 11 deletions
@@ -82,10 +82,9 @@ def _nsa_context_attention_kernel(
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
         q_all = torch.cat([q_nope, q_rope], dim=-1)
-
         mla_out, _, _ = flash_mla_sparse_fwd(
-            q=q_all,
-            kv=infer_state.mem_manager.kv_buffer[self.layer_num_],
+            q=q_all,  # [seq_len_q, q_num_head, qk_dim]
+            kv=infer_state.mem_manager.kv_buffer[self.layer_num_],  # [size, 1, qk_dim]
             indices=self.topk_indices.unsqueeze(1),
             sm_scale=self.softmax_scale,
             d_v=self.kv_lora_rank,
@@ -100,15 +99,16 @@ def _nsa_token_attention_kernel(
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
         k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
         kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
+
         o_tensor = flash_attn_with_kvcache(
-            q=q_rope,
-            k_cache=k_rope,
-            v_cache=kv_nope,
-            qv=q_nope,
-            page_table=self.topk_indices,
-            cache_seqlens=infer_state.nsa_cache_seqlens,
-            cu_seqlens_q=infer_state.cu_seqlens_q,
-            cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k,
+            q=q_rope,  # (q_seqlen, nheads, qk_headdim)
+            k_cache=k_rope,  # (kv_size, 1, 1, qk_head_dim)
+            v_cache=kv_nope,  # (kv_size, 1, 1, kv_lora_rank)
+            qv=q_nope,  # (q_seqlen, nheads, kv_lora_rank)
+            page_table=self.topk_indices,  # (q_seqlen, max_seq_len)
+            cache_seqlens=infer_state.nsa_cache_seqlens,  # (q_seqlen) current kv length, used when reading page_table
+            cu_seqlens_q=infer_state.cu_seqlens_q,  # (batch_size+1), e.g. [0, 1]
+            cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k,  # (batch_size+1), e.g. [0, 9]
             max_seqlen_q=infer_state.max_q_seq_len,
             softmax_scale=self.softmax_scale,
             causal=True,
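Both the prefill kernel (flash_mla_sparse_fwd) and the decode path above implement the same sparse-attention contract: each query attends only to the kv rows selected by its top-k indices, and in MLA the value vector is simply the leading kv_lora_rank dims of the cached kv row. A plain PyTorch reference of that contract (illustration only, ignoring the rope/nope split and fp8 details):

import torch

def sparse_mla_attention_reference(q, kv_buffer, indices, sm_scale, d_v):
    # q: [seq_len_q, n_heads, qk_dim]; kv_buffer: [size, 1, qk_dim]
    # indices: [seq_len_q, topk] kv-buffer rows per query token, -1 = padding
    out = torch.zeros(q.shape[0], q.shape[1], d_v)
    for i in range(q.shape[0]):
        rows = indices[i][indices[i] >= 0].long()
        k = kv_buffer[rows, 0].float()   # [n_sel, qk_dim] gathered kv rows
        v = k[:, :d_v]                   # MLA: value = leading kv_lora_rank dims
        attn = torch.softmax(q[i].float() @ k.T * sm_scale, dim=-1)
        out[i] = attn @ v
    return out

The decode kernel reads the same selected rows through page_table instead of materializing a gathered kv tensor, which is what makes the top-k indices double as a per-query page table of size 1.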

lightllm/models/deepseek3_2/model.py

Lines changed: 4 additions & 1 deletion
@@ -5,7 +5,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo
 from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager
-@ModelRegistry(["deepseek_v32"])
+# @ModelRegistry(["deepseek_v32"])
 class Deepseek3_2TpPartModel(Deepseek2TpPartModel):
     # weight class
     transformer_weight_class = Deepseek3_2TransformerLayerWeight
@@ -21,6 +21,9 @@ def __init__(self, kvargs):
         self.index_topk = self.config["index_topk"]
         return
 
+    def _init_inferstate_cls(self):
+        self.infer_state_class = Deepseek3_2FlashAttentionStateInfo
+
     def _init_mem_manager(self):
         manager_class = Deepseek3_2MemoryManager
         if "triton_fp8kv" in self.mode:
