from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
- # from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits
+ from lightllm.utils.log_utils import init_logger
+
+ logger = init_logger(__name__)

class NSAIndexerInfer(BaseLayerInfer):
    def __init__(self, layer_idx, network_config, mode=[]):
@@ -66,70 +68,37 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
        q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt)
        k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt)

-         self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager)
+         destindex_copy_indexer_ks(
+             k_fp8.unsqueeze(1),
+             k_scale.unsqueeze(1),
+             infer_state.mem_index,
+             infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
+         )

        weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
        weights = weights.unsqueeze(-1) * q_scale

-         ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
-
-         k_fp8_list = []
-         k_scale_list = []
-         ks_list = []
-         ke_list = []
-         offset = 0
-         for i in range(infer_state.batch_size):
-             q_len = infer_state.b_q_seq_len[i]
-             cache_len = infer_state.b_ready_cache_len[i]
-             mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len + q_len]
-             k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous()
-             k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous()
-             ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda")
-             ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1
-             k_fp8_list.append(k_fp8)
-             k_scale_list.append(k_scale)
-             ks_list.append(ks)
-             ke_list.append(ke)
-             offset += q_len
-
-         k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
-         k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
-         kv_fp8 = (k_fp8, k_scale)
-         ks = torch.cat(ks_list, dim=0)
-         ke = torch.cat(ke_list, dim=0)
-
-         logits = deep_gemm.fp8_mqa_logits(
-             q_fp8,
-             kv_fp8,
-             weights.squeeze(-1),
-             ks,
-             ke,
-             clean_logits=False,
-         )
-
-         return self.get_topk(logits, infer_state)
-
-     def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo):
-         topk_indices_list = []
-         offset = 0
-
-         for i in range(infer_state.batch_size):
-             q_len = infer_state.b_q_seq_len[i]
-             cache_len = infer_state.b_ready_cache_len[i]
-             end_pos = q_len + cache_len
-             # Slice logits for this batch (both query and sequence dimensions)
-             batch_logits = logits[offset:offset + q_len, :end_pos]
-             topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1]
-             mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len + q_len]
-             indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda")
-             for j in range(q_len):
-                 indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]]
-             topk_indices_list.append(indices)
-             offset += q_len
+         # Use pre-computed indexing structures from infer_state
+         mem_index = infer_state.mem_index
+         ks = infer_state.ks
+         ke = infer_state.ke
+         lengths = infer_state.lengths
+         page_table_1 = infer_state.page_table_size_1

-         topk_indices_ = torch.cat(topk_indices_list, dim=0)
+         # TODO
+         k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous()
+         k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous()

-         return topk_indices_
+         logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)
+
+         # Returns: [seq_q_len, topk]; invalid positions are padded with -1
+         return fast_topk_transform_fused(
+             score=logits,  # [seq_len_q, seq_len_kv]
+             lengths=lengths,  # [seq_len_q]
+             page_table_size_1=page_table_1,  # [seq_len_q, max(lengths)], invalid entries padded with 0
+             cu_seqlens_q=infer_state.cu_seqlens_q,  # [seq_len_q + 1]
+             topk=self.index_topk,
+         )


    def get_k_float32_from_buffer(self, buffer: torch.Tensor):
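Note on the new path: the flat indexing tensors (ks, ke, lengths, page_table_size_1, cu_seqlens_q) are consumed as pre-computed attributes of infer_state rather than being rebuilt inside the layer with a Python loop over the batch. A minimal sketch of how such per-query-token start/end offsets can be derived from per-request lengths (illustrative only; the actual precomputation lives in the state object and may differ in detail):

    import torch

    def build_flat_kv_indices(b_q_seq_len: torch.Tensor, b_ready_cache_len: torch.Tensor):
        # b_q_seq_len[i]: new query tokens of request i
        # b_ready_cache_len[i]: tokens of request i already in the KV cache
        ks_list, ke_list = [], []
        kv_offset = 0
        for q_len, cache_len in zip(b_q_seq_len.tolist(), b_ready_cache_len.tolist()):
            # every query token of this request starts reading the packed KV at the same offset
            ks = torch.full((q_len,), kv_offset, dtype=torch.int32)
            # causal end: query token j may score its cached prefix plus the first j + 1 new keys
            ke = ks + cache_len + torch.arange(q_len, dtype=torch.int32) + 1
            ks_list.append(ks)
            ke_list.append(ke)
            kv_offset += cache_len + q_len
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(b_q_seq_len, dim=0), (1, 0))
        return torch.cat(ks_list), torch.cat(ke_list), cu_seqlens_q

This mirrors the spirit of the deleted per-batch loop; building the tensors once on the shared state avoids repeating that work in every layer.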
@@ -152,8 +121,9 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
    def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
                      infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight):
        q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim)
-
        k = layer_weight.wk_proj_.mm(hidden_states)
+
+         # TODO
        k = F.layer_norm(
            k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps
        ).type_as(k)
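The k normalization above upcasts to float32 before F.layer_norm and casts back with type_as, so the mean/variance reduction runs in full precision while activations stay bf16. A self-contained illustration of the pattern (hypothetical shapes; affine parameters kept in float32 for this sketch):

    import torch
    import torch.nn.functional as F

    def layer_norm_fp32(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # normalize in float32 for stable statistics, then return to the activation dtype
        return F.layer_norm(x.float(), (x.shape[-1],), weight, bias, eps).type_as(x)

    x = torch.randn(4, 128, dtype=torch.bfloat16)
    w, b = torch.ones(128), torch.zeros(128)
    out = layer_norm_fp32(x, w, b)
    assert out.dtype == torch.bfloat16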
@@ -168,17 +138,3 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
        q = self._rotate_activation(q)
        k = self._rotate_activation(k)
        return q, k
-
-     def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager):
-         # k_fp8 : [seq_len, 128] torch.fp8_e4m3
-         # k_scale : [seq_len, 1] torch.float32
-         # mem_index : [seq_len] torch.int32
-         # buffer : [10000000, 1, 132] torch.uint8
-         buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
-         destindex_copy_indexer_ks(
-             k_fp8.unsqueeze(1),  # Add head dimension: [seq_len, 1, 128]
-             k_scale.unsqueeze(1),  # Add head dimension: [seq_len, 1, 1]
-             mem_index,
-             buffer
-         )
-         return
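The removed helper documented the packed indexer-KS layout that both the inlined destindex_copy_indexer_ks call and the reads in get_indices rely on: one head slot of 132 uint8 bytes per token, 128 bytes of fp8 key followed by 4 bytes of float32 scale. A plain-torch sketch of that packing and unpacking (illustrative only, not the Triton kernel):

    import torch

    def pack_indexer_ks(k_fp8: torch.Tensor, k_scale: torch.Tensor) -> torch.Tensor:
        # k_fp8: [seq_len, 128] float8_e4m3fn, k_scale: [seq_len, 1] float32
        packed = torch.empty((k_fp8.shape[0], 1, 132), dtype=torch.uint8, device=k_fp8.device)
        packed[:, 0, :128] = k_fp8.view(torch.uint8)                 # raw fp8 key bytes
        packed[:, 0, 128:] = k_scale.contiguous().view(torch.uint8)  # 4 bytes of float32 scale
        return packed

    def unpack_indexer_ks(buffer: torch.Tensor, mem_index: torch.Tensor):
        # mirrors the reads in get_indices above
        k_fp8 = buffer[mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous()
        k_scale = buffer[mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous()
        return k_fp8, k_scale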