Skip to content

Commit d1a773d

Browse files
committed
fix
1 parent 425edb2 commit d1a773d

File tree

2 files changed

+160
-8
lines changed

2 files changed

+160
-8
lines changed

lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
1111
from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
1212
from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
13+
from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks
14+
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
1315
from lightllm.utils.log_utils import init_logger
1416

1517
logger = init_logger(__name__)
@@ -78,16 +80,13 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
7880
weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
7981
weights = weights.unsqueeze(-1) * q_scale
8082

81-
# Use pre-computed indexing structures from infer_state
8283
mem_index = infer_state.mem_index
8384
ks = infer_state.ks
8485
ke = infer_state.ke
8586
lengths = infer_state.lengths
8687
page_table_1 = infer_state.page_table_size_1
8788

88-
# TODO
89-
k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous()
90-
k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous()
89+
k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index)
9190

9291
logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)
9392

@@ -123,10 +122,7 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
123122
q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim)
124123
k = layer_weight.wk_proj_.mm(hidden_states)
125124

126-
# TODO
127-
k = F.layer_norm(
128-
k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps
129-
).type_as(k)
125+
k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps)
130126

131127
rotary_emb_fwd(
132128
q[:, :, : self.qk_rope_head_dim],
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
import numpy
5+
6+
7+
@triton.jit
8+
def _fwd_kernel_extract_indexer_ks(
9+
buffer_fp8,
10+
buffer_scale,
11+
mem_index,
12+
k_fp8_out,
13+
k_scale_out,
14+
stride_buffer_fp8_bs,
15+
stride_buffer_fp8_h,
16+
stride_buffer_fp8_d,
17+
stride_buffer_scale_bs,
18+
stride_buffer_scale_h,
19+
stride_buffer_scale_d,
20+
stride_k_fp8_out_bs,
21+
stride_k_fp8_out_d,
22+
stride_k_scale_out_bs,
23+
BLOCK_DMODEL: tl.constexpr,
24+
):
25+
cur_index = tl.program_id(0)
26+
27+
# Load the memory index
28+
mem_idx = tl.load(mem_index + cur_index).to(tl.int64)
29+
30+
# Load k_fp8 data from buffer_fp8[mem_idx, 0, :]
31+
offs_d = tl.arange(0, BLOCK_DMODEL)
32+
k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d
33+
k_fp8_data = tl.load(k_fp8_ptrs)
34+
35+
# Load k_scale data from buffer_scale[mem_idx, 0, 0]
36+
k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d
37+
k_scale_data = tl.load(k_scale_ptr)
38+
39+
# Store k_fp8 output
40+
k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d
41+
tl.store(k_fp8_out_ptrs, k_fp8_data)
42+
43+
# Store k_scale output
44+
k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs
45+
tl.store(k_scale_out_ptr, k_scale_data)
46+
47+
48+
@torch.no_grad()
49+
def extract_indexer_ks(buffer, mem_index):
50+
"""
51+
Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel.
52+
53+
Args:
54+
buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8
55+
mem_index: Indices tensor of shape [seq_len] with dtype int32/int64
56+
57+
Returns:
58+
k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn
59+
k_scale: Tensor of shape [seq_len] with dtype float32
60+
"""
61+
seq_len = mem_index.shape[0]
62+
assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}"
63+
64+
# Reinterpret buffer as the appropriate types for Triton
65+
buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn)
66+
buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1]
67+
68+
# Prepare output tensors
69+
k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device)
70+
k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device)
71+
72+
BLOCK_DMODEL = 128
73+
grid = (seq_len,)
74+
num_warps = 1
75+
76+
_fwd_kernel_extract_indexer_ks[grid](
77+
buffer_fp8,
78+
buffer_scale,
79+
mem_index,
80+
k_fp8_out,
81+
k_scale_out,
82+
buffer_fp8.stride(0),
83+
buffer_fp8.stride(1),
84+
buffer_fp8.stride(2),
85+
buffer_scale.stride(0),
86+
buffer_scale.stride(1),
87+
buffer_scale.stride(2),
88+
k_fp8_out.stride(0),
89+
k_fp8_out.stride(1),
90+
k_scale_out.stride(0),
91+
BLOCK_DMODEL=BLOCK_DMODEL,
92+
num_warps=num_warps,
93+
num_stages=1,
94+
)
95+
96+
return k_fp8_out, k_scale_out
97+
98+
99+
def test():
100+
# Test parameters similar to the usage in nsa_indexer_layer_inder.py
101+
B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this)
102+
seq_len = 50 # number of tokens to extract
103+
dtype_fp8 = torch.float8_e4m3fn
104+
dtype_scale = torch.float32
105+
106+
# Create test buffer [total_tokens, heads, 132] as uint8
107+
buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda()
108+
109+
# Fill with test data - simulate what destindex_copy_indexer_ks does
110+
test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda()
111+
# Generate fp8 data by converting from float32
112+
test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda()
113+
test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8)
114+
test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda()
115+
116+
# Manually populate buffer as destindex_copy_indexer_ks would
117+
for i in range(seq_len):
118+
dest_idx = test_indices[i].item()
119+
# Store fp8 data
120+
buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8)
121+
# Store scale data (4 bytes) - need to convert float32 to bytes
122+
scale_bytes = test_k_scale[i].cpu().numpy().tobytes()
123+
scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8)
124+
buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device)
125+
126+
# Call our extraction function
127+
extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices)
128+
129+
# Verify results
130+
print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}")
131+
print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}")
132+
print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}")
133+
print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}")
134+
135+
# Check if extraction matches (convert fp8 to float32 for comparison)
136+
# Use higher tolerance for fp8 due to quantization precision
137+
fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1)
138+
scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6)
139+
140+
print(f"FP8 data matches: {fp8_match}")
141+
print(f"Scale data matches: {scale_match}")
142+
143+
if fp8_match and scale_match:
144+
print("All tests passed!")
145+
else:
146+
print("Test failed!")
147+
if not fp8_match:
148+
print("First few fp8 values:")
149+
print(f"Original: {test_k_fp8_fp32[0, :5]}")
150+
print(f"Extracted: {extracted_fp8.float()[0, :5]}")
151+
if not scale_match:
152+
print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}")
153+
154+
155+
if __name__ == "__main__":
156+
test()

0 commit comments

Comments
 (0)