|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | +import numpy |
| 5 | + |
| 6 | + |
| 7 | +@triton.jit |
| 8 | +def _fwd_kernel_extract_indexer_ks( |
| 9 | + buffer_fp8, |
| 10 | + buffer_scale, |
| 11 | + mem_index, |
| 12 | + k_fp8_out, |
| 13 | + k_scale_out, |
| 14 | + stride_buffer_fp8_bs, |
| 15 | + stride_buffer_fp8_h, |
| 16 | + stride_buffer_fp8_d, |
| 17 | + stride_buffer_scale_bs, |
| 18 | + stride_buffer_scale_h, |
| 19 | + stride_buffer_scale_d, |
| 20 | + stride_k_fp8_out_bs, |
| 21 | + stride_k_fp8_out_d, |
| 22 | + stride_k_scale_out_bs, |
| 23 | + BLOCK_DMODEL: tl.constexpr, |
| 24 | +): |
| 25 | + cur_index = tl.program_id(0) |
| 26 | + |
| 27 | + # Load the memory index |
| 28 | + mem_idx = tl.load(mem_index + cur_index).to(tl.int64) |
| 29 | + |
| 30 | + # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] |
| 31 | + offs_d = tl.arange(0, BLOCK_DMODEL) |
| 32 | + k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d |
| 33 | + k_fp8_data = tl.load(k_fp8_ptrs) |
| 34 | + |
| 35 | + # Load k_scale data from buffer_scale[mem_idx, 0, 0] |
| 36 | + k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d |
| 37 | + k_scale_data = tl.load(k_scale_ptr) |
| 38 | + |
| 39 | + # Store k_fp8 output |
| 40 | + k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d |
| 41 | + tl.store(k_fp8_out_ptrs, k_fp8_data) |
| 42 | + |
| 43 | + # Store k_scale output |
| 44 | + k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs |
| 45 | + tl.store(k_scale_out_ptr, k_scale_data) |
| 46 | + |
| 47 | + |
| 48 | +@torch.no_grad() |
| 49 | +def extract_indexer_ks(buffer, mem_index): |
| 50 | + """ |
| 51 | + Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. |
| 52 | +
|
| 53 | + Args: |
| 54 | + buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 |
| 55 | + mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 |
| 56 | +
|
| 57 | + Returns: |
| 58 | + k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn |
| 59 | + k_scale: Tensor of shape [seq_len] with dtype float32 |
| 60 | + """ |
| 61 | + seq_len = mem_index.shape[0] |
| 62 | + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" |
| 63 | + |
| 64 | + # Reinterpret buffer as the appropriate types for Triton |
| 65 | + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) |
| 66 | + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] |
| 67 | + |
| 68 | + # Prepare output tensors |
| 69 | + k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) |
| 70 | + k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) |
| 71 | + |
| 72 | + BLOCK_DMODEL = 128 |
| 73 | + grid = (seq_len,) |
| 74 | + num_warps = 1 |
| 75 | + |
| 76 | + _fwd_kernel_extract_indexer_ks[grid]( |
| 77 | + buffer_fp8, |
| 78 | + buffer_scale, |
| 79 | + mem_index, |
| 80 | + k_fp8_out, |
| 81 | + k_scale_out, |
| 82 | + buffer_fp8.stride(0), |
| 83 | + buffer_fp8.stride(1), |
| 84 | + buffer_fp8.stride(2), |
| 85 | + buffer_scale.stride(0), |
| 86 | + buffer_scale.stride(1), |
| 87 | + buffer_scale.stride(2), |
| 88 | + k_fp8_out.stride(0), |
| 89 | + k_fp8_out.stride(1), |
| 90 | + k_scale_out.stride(0), |
| 91 | + BLOCK_DMODEL=BLOCK_DMODEL, |
| 92 | + num_warps=num_warps, |
| 93 | + num_stages=1, |
| 94 | + ) |
| 95 | + |
| 96 | + return k_fp8_out, k_scale_out |
| 97 | + |
| 98 | + |
| 99 | +def test(): |
| 100 | + # Test parameters similar to the usage in nsa_indexer_layer_inder.py |
| 101 | + B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) |
| 102 | + seq_len = 50 # number of tokens to extract |
| 103 | + dtype_fp8 = torch.float8_e4m3fn |
| 104 | + dtype_scale = torch.float32 |
| 105 | + |
| 106 | + # Create test buffer [total_tokens, heads, 132] as uint8 |
| 107 | + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() |
| 108 | + |
| 109 | + # Fill with test data - simulate what destindex_copy_indexer_ks does |
| 110 | + test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() |
| 111 | + # Generate fp8 data by converting from float32 |
| 112 | + test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() |
| 113 | + test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) |
| 114 | + test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() |
| 115 | + |
| 116 | + # Manually populate buffer as destindex_copy_indexer_ks would |
| 117 | + for i in range(seq_len): |
| 118 | + dest_idx = test_indices[i].item() |
| 119 | + # Store fp8 data |
| 120 | + buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) |
| 121 | + # Store scale data (4 bytes) - need to convert float32 to bytes |
| 122 | + scale_bytes = test_k_scale[i].cpu().numpy().tobytes() |
| 123 | + scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) |
| 124 | + buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) |
| 125 | + |
| 126 | + # Call our extraction function |
| 127 | + extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) |
| 128 | + |
| 129 | + # Verify results |
| 130 | + print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") |
| 131 | + print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}") |
| 132 | + print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") |
| 133 | + print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") |
| 134 | + |
| 135 | + # Check if extraction matches (convert fp8 to float32 for comparison) |
| 136 | + # Use higher tolerance for fp8 due to quantization precision |
| 137 | + fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) |
| 138 | + scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) |
| 139 | + |
| 140 | + print(f"FP8 data matches: {fp8_match}") |
| 141 | + print(f"Scale data matches: {scale_match}") |
| 142 | + |
| 143 | + if fp8_match and scale_match: |
| 144 | + print("All tests passed!") |
| 145 | + else: |
| 146 | + print("Test failed!") |
| 147 | + if not fp8_match: |
| 148 | + print("First few fp8 values:") |
| 149 | + print(f"Original: {test_k_fp8_fp32[0, :5]}") |
| 150 | + print(f"Extracted: {extracted_fp8.float()[0, :5]}") |
| 151 | + if not scale_match: |
| 152 | + print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") |
| 153 | + |
| 154 | + |
| 155 | +if __name__ == "__main__": |
| 156 | + test() |
0 commit comments