
Commit 80e9452

[Deepseek v3.2] Optimize top_k_per_row (#26763)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>

1 parent: c3a2c6a

5 files changed: +13 -49 lines

csrc/ops.h

Lines changed: 1 addition & 2 deletions
@@ -99,8 +99,7 @@ void apply_repetition_penalties_(torch::Tensor& logits,
 
 void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
-                   torch::Tensor& values, int64_t numRows, int64_t stride0,
-                   int64_t stride1);
+                   int64_t numRows, int64_t stride0, int64_t stride1);
 
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,

csrc/sampler.cu

Lines changed: 5 additions & 30 deletions
@@ -57,7 +57,7 @@ static inline __device__ uint16_t extractBinIdx(float x) {
 template <int kNumThreadsPerBlock = 512>
 static __global__ void topKPerRow(const float* logits, const int* rowStarts,
                                   const int* rowEnds, int* outIndices,
-                                  float* outLogits, int stride0, int stride1) {
+                                  int stride0, int stride1) {
   // The number of bins in the histogram.
   static constexpr int kNumBins = 512;
 
@@ -103,8 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   __shared__ int smemHistogram[kNumBins];
   // Shared memory to store the selected indices.
   __shared__ int smemIndices[kTopK];
-  // Shared memory to store the selected logits.
-  __shared__ float smemLogits[kTopK];
   // Shared memory to store the threshold bin.
   __shared__ int smemThresholdBinIdx[1];
   // Shared memory counter to register the candidates for the final phase.
@@ -124,13 +122,10 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
        rowIt += kNumThreadsPerBlock) {
     int idx = rowStart + rowIt;
     outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
-    outLogits[rowIdx * kTopK + rowIt] =
-        logits[rowIdx * stride0 + idx * stride1];
   }
   for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
        rowIt += kNumThreadsPerBlock) {
     outIndices[rowIdx * kTopK + rowIt] = -1;
-    outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
   }
   return;
 }
@@ -201,7 +196,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
     uint16_t idx = extractBinIdx(logit);
     if (idx < thresholdBinIdx) {
       int dstIdx = atomicAdd(&smemHistogram[idx], 1);
-      smemLogits[dstIdx] = logit;
       smemIndices[dstIdx] = rowIt;
     } else if (idx == thresholdBinIdx) {
       int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
@@ -250,36 +244,19 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
     int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
     int dstIdx = baseIdx + srcIdx;
     if (dstIdx < kTopK) {
-      smemLogits[dstIdx] = finalLogits[ii];
       smemIndices[dstIdx] = finalIndices[ii];
     }
   }
 
   // Make sure the data is in shared memory.
   __syncthreads();
 
-  // The topK logits.
-  float topKLogits[kNumTopKItemsPerThread];
-  // The topK indices.
-  int topKIndices[kNumTopKItemsPerThread];
-
-  // Load from shared memory.
-#pragma unroll
-  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
-    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
-    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
-  }
-
-  // Sort the elements.
-  TopKSort(smemFinal.topKSort)
-      .SortDescendingBlockedToStriped(topKLogits, topKIndices);
-
   // Store to global memory.
 #pragma unroll
   for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
     int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
-    outIndices[offset] = topKIndices[ii] - rowStart;
-    outLogits[offset] = topKLogits[ii];
+    outIndices[offset] =
+        smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
   }
 }
 
@@ -328,8 +305,7 @@ void apply_repetition_penalties_(
 
 void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
-                   torch::Tensor& values, int64_t numRows, int64_t stride0,
-                   int64_t stride1) {
+                   int64_t numRows, int64_t stride0, int64_t stride1) {
   // Compute the results on the device.
   constexpr int kNumThreadsPerBlock = 512;
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -338,6 +314,5 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
       <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
           logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
           rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
-          values.data_ptr<float>(), static_cast<int>(stride0),
-          static_cast<int>(stride1));
+          static_cast<int>(stride0), static_cast<int>(stride1));
 }
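The substance of the optimization lives in this kernel: with the values output gone, topKPerRow no longer stages logits in smemLogits and, more importantly, drops the final SortDescendingBlockedToStriped block sort, so indices are written out in selection order rather than descending-value order. The histogram pass already determines which elements belong to the top-k set, and that set is all the caller consumes. Below is a toy Python model of that selection idea, a sketch only: the 512-bin count and the "smaller bin id means larger logit" ordering come from the kernel above, but the bit-twiddling in extract_bin_idx is an assumed stand-in for the real extractBinIdx, and the real kernel does all of this cooperatively in shared memory.

import struct

def extract_bin_idx(x: float) -> int:
    # Order-preserving transform of the IEEE-754 bit pattern, inverted so
    # larger floats land in smaller bin indices; the top 9 bits pick one of
    # 512 bins (an assumed stand-in for the kernel's extractBinIdx).
    b = struct.unpack("<I", struct.pack("<f", x))[0]
    key = (~b & 0xFFFFFFFF) if (b >> 31) else (b | 0x80000000)
    return (~key & 0xFFFFFFFF) >> 23

def top_k_indices_unsorted(row: list[float], k: int) -> list[int]:
    bins = [extract_bin_idx(v) for v in row]
    counts = [0] * 512
    for bin_idx in bins:
        counts[bin_idx] += 1
    # The threshold bin is where the running count first reaches k.
    total, threshold = 0, 511
    for bin_idx, count in enumerate(counts):
        total += count
        if total >= k:
            threshold = bin_idx
            break
    # Bins strictly below the threshold are entirely in the top-k; candidates
    # inside the threshold bin fill the remaining slots in arrival order.
    out = [i for i, bin_idx in enumerate(bins) if bin_idx < threshold]
    out += [i for i, bin_idx in enumerate(bins) if bin_idx == threshold][: k - len(out)]
    return out  # a top-k index set, never value-sorted

print(sorted(top_k_indices_unsorted([0.1, 5.0, -3.0, 2.0], k=2)))  # [1, 3]

Note that even before this change, membership for elements in the threshold bin was decided at bin granularity (the dstIdx < kTopK cutoff above); the removed sort only ordered the final output, which is why dropping it is safe for set-consuming callers.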

csrc/torch_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Optimized top-k per row operation
   ops.def(
       "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
-      "Tensor! indices, Tensor! values, int numRows, int stride0, "
+      "Tensor! indices, int numRows, int stride0, "
       "int stride1) -> ()");
   ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
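Since the registered schema no longer has a Tensor! values argument, any Python caller still passing one fails schema validation, so call sites must move in lockstep with this commit. A minimal invocation sketch against the new signature, assuming vLLM's _C extension is built and loaded, with the 2048-wide indices buffer the test below uses:

import torch

num_rows, num_cols = 4, 4096
logits = torch.randn(num_rows, num_cols, device="cuda", dtype=torch.float32)
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
row_ends = torch.full((num_rows,), num_cols, dtype=torch.int32, device="cuda")
# Only the indices output remains; no values buffer is allocated.
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")

torch.ops._C.top_k_per_row(
    logits, row_starts, row_ends, indices,
    num_rows, logits.stride(0), logits.stride(1),
)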

tests/kernels/test_top_k_per_row.py

Lines changed: 6 additions & 8 deletions
@@ -39,10 +39,9 @@ def create_row_boundaries(
 
 
 def compare_top_k_results(
+    logits: torch.Tensor,
     cuda_indices: torch.Tensor,
-    cuda_values: torch.Tensor,
     torch_indices: torch.Tensor,
-    torch_values: torch.Tensor,
     row_starts: torch.Tensor,
     row_ends: torch.Tensor,
     top_k: int,
@@ -70,8 +69,9 @@ def compare_top_k_results(
             continue
 
         # Any difference in elements, compare the values
-        cuda_row_values = cuda_values[row_idx][:num_valid].cpu()
-        torch_row_values = torch_values[row_idx][:num_valid].cpu()
+        logits_row = logits[row_idx]
+        cuda_row_values = [logits_row[i] for i in cuda_row_indices]
+        torch_row_values = [logits_row[i] for i in torch_row_indices]
 
         cuda_only_values, torch_only_values = [], []
         for idx in cuda_set - torch_set:
@@ -115,28 +115,26 @@ def test_top_k_per_row(
 
     # Create output tensors
     indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
-    values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
 
     # Run CUDA implementation
     torch.ops._C.top_k_per_row(
        logits,
        row_starts,
        row_ends,
        indices,
-        values,
        num_rows,
        logits.stride(0),
        logits.stride(1),
     )
 
     # Run reference implementation
-    torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)
+    torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
     mask_lo = torch_indices >= 0
     mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
     mask = mask_lo & mask_hi
     torch_indices = torch_indices.masked_fill(~mask, -1)
 
     # Compare results
     assert compare_top_k_results(
-        indices, values, torch_indices, torch_values, row_starts, row_ends, top_k
+        logits, indices, torch_indices, row_starts, row_ends, top_k
     ), "CUDA top_k_per_row results don't match torch.topk"

vllm/model_executor/models/deepseek_v2.py

Lines changed: 0 additions & 8 deletions
@@ -577,15 +577,11 @@ def sparse_attn_indexer(
         topk_indices = torch.empty(
             num_rows, topk_tokens, dtype=torch.int32, device=logits.device
         )
-        topk_values = torch.empty(
-            num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
-        )
         torch.ops._C.top_k_per_row(
             logits,
             chunk.cu_seqlen_ks,
             chunk.cu_seqlen_ke,
             topk_indices,
-            topk_values,
             num_rows,
             logits.stride(0),
             logits.stride(1),
@@ -642,15 +638,11 @@ def sparse_attn_indexer(
         topk_indices = torch.empty(
             num_rows, topk_tokens, dtype=torch.int32, device=logits.device
         )
-        topk_values = torch.empty(
-            num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
-        )
         torch.ops._C.top_k_per_row(
             logits,
             torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
             index_end_pos.to(dtype=torch.int32, device=logits.device),
             topk_indices,
-            topk_values,
             num_rows,
             logits.stride(0),
             logits.stride(1),
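At both DeepSeek-V3.2 sparse-indexer call sites the commit also removes a (num_rows, topk_tokens) floating-point allocation per invocation, on top of the kernel-side savings, since topk_values was never consumed. A back-of-envelope estimate with hypothetical sizes (illustration only, not taken from the model config):

# Hypothetical sizes for illustration only.
num_rows, topk_tokens = 8192, 2048
bytes_saved = num_rows * topk_tokens * 4  # fp32 topk_values no longer allocated
print(f"{bytes_saved / 2**20:.0f} MiB per call")  # -> 64 MiB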
