@@ -57,7 +57,7 @@ static inline __device__ uint16_t extractBinIdx(float x) {
 template <int kNumThreadsPerBlock = 512>
 static __global__ void topKPerRow(const float* logits, const int* rowStarts,
                                   const int* rowEnds, int* outIndices,
-                                  float* outLogits, int stride0, int stride1) {
+                                  int stride0, int stride1) {
   // The number of bins in the histogram.
   static constexpr int kNumBins = 512;

@@ -103,8 +103,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
   __shared__ int smemHistogram[kNumBins];
   // Shared memory to store the selected indices.
   __shared__ int smemIndices[kTopK];
-  // Shared memory to store the selected logits.
-  __shared__ float smemLogits[kTopK];
   // Shared memory to store the threshold bin.
   __shared__ int smemThresholdBinIdx[1];
   // Shared memory counter to register the candidates for the final phase.
@@ -124,13 +122,10 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
          rowIt += kNumThreadsPerBlock) {
       int idx = rowStart + rowIt;
       outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
-      outLogits[rowIdx * kTopK + rowIt] =
-          logits[rowIdx * stride0 + idx * stride1];
     }
     for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
          rowIt += kNumThreadsPerBlock) {
       outIndices[rowIdx * kTopK + rowIt] = -1;
-      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
     }
     return;
   }
@@ -201,7 +196,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
     uint16_t idx = extractBinIdx(logit);
     if (idx < thresholdBinIdx) {
       int dstIdx = atomicAdd(&smemHistogram[idx], 1);
-      smemLogits[dstIdx] = logit;
       smemIndices[dstIdx] = rowIt;
     } else if (idx == thresholdBinIdx) {
       int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
@@ -250,36 +244,19 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts,
     int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
     int dstIdx = baseIdx + srcIdx;
     if (dstIdx < kTopK) {
-      smemLogits[dstIdx] = finalLogits[ii];
       smemIndices[dstIdx] = finalIndices[ii];
     }
   }

   // Make sure the data is in shared memory.
   __syncthreads();

-  // The topK logits.
-  float topKLogits[kNumTopKItemsPerThread];
-  // The topK indices.
-  int topKIndices[kNumTopKItemsPerThread];
-
-  // Load from shared memory.
-#pragma unroll
-  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
-    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
-    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
-  }
-
-  // Sort the elements.
-  TopKSort(smemFinal.topKSort)
-      .SortDescendingBlockedToStriped(topKLogits, topKIndices);
-
   // Store to global memory.
 #pragma unroll
   for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
     int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
-    outIndices[offset] = topKIndices[ii] - rowStart;
-    outLogits[offset] = topKLogits[ii];
+    outIndices[offset] =
+        smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart;
   }
 }

@@ -328,8 +305,7 @@ void apply_repetition_penalties_(

 void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                    const torch::Tensor& rowEnds, torch::Tensor& indices,
-                   torch::Tensor& values, int64_t numRows, int64_t stride0,
-                   int64_t stride1) {
+                   int64_t numRows, int64_t stride0, int64_t stride1) {
   // Compute the results on the device.
   constexpr int kNumThreadsPerBlock = 512;
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -338,6 +314,5 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
       <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
           logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
           rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
-          values.data_ptr<float>(), static_cast<int>(stride0),
-          static_cast<int>(stride1));
+          static_cast<int>(stride0), static_cast<int>(stride1));
 }
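
Editor's note (not part of the commit): this change drops the top-k logit values from the kernel's outputs, leaving only the indices. Along with the outLogits/values buffers, the final CUB block sort (SortDescendingBlockedToStriped) is removed, so the indices written by the new store loop presumably follow the shared-memory fill order rather than descending logit order. Below is a minimal caller-side sketch of the updated entry point; the wrapper name topKIndicesPerRow is illustrative, the declaration of top_k_per_row is assumed visible, and K must match the kernel's compile-time kTopK (whose value is defined elsewhere in this file).

#include <torch/torch.h>

// Hypothetical convenience wrapper around the updated binding.
torch::Tensor topKIndicesPerRow(
    const torch::Tensor& logits,     // [numRows, vocab], float32, CUDA
    const torch::Tensor& rowStarts,  // [numRows], int32
    const torch::Tensor& rowEnds,    // [numRows], int32
    int64_t K /* assumed equal to kTopK */) {
  const int64_t numRows = rowStarts.size(0);
  // Only indices come back now; there is no float values output anymore.
  torch::Tensor indices = torch::empty(
      {numRows, K}, torch::dtype(torch::kInt32).device(logits.device()));
  top_k_per_row(logits, rowStarts, rowEnds, indices, numRows,
                logits.stride(0), logits.stride(1));
  return indices;
}

A caller that still needs the values can gather them from logits using the returned indices, which per the diff are relative to each row's rowStart and padded with -1 for rows shorter than kTopK, so the padding entries should be masked out before any gather.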