@@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel(
   }
 }

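+// Map a float logit to one of 512 histogram bins such that larger logits
+// land in lower-numbered bins: round to half precision, convert the
+// sign-magnitude bit pattern into a monotonically ordered integer (flip
+// all bits for negatives, set the sign bit otherwise), and keep the top
+// 9 bits.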
+static inline __device__ uint16_t extractBinIdx(float x) {
+  union {
+    __half h;
+    uint16_t u16;
+  } tmp;
+  tmp.h = __float2half_rn(x);
+  tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
+  return 511 - (tmp.u16 >> 7);
+}
+
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  float* outLogits, int stride0, int stride1) {
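+  // One thread block processes one row. The kernel builds a 512-bin
+  // histogram over the row's logits (bin 0 holds the largest values),
+  // scans it to find the threshold bin where the cumulative count reaches
+  // kTopK, compacts the logits from better bins into shared memory,
+  // radix-sorts the threshold bin's candidates to fill the remaining
+  // slots, and finally sorts all kTopK survivors before writing them out.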
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+  // The number of elements per thread for the final top-k sort.
+  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
+  // The class to sort the elements during the final top-k sort.
+  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                       kNumTopKItemsPerThread, int>;
+
+  // The number of slots for the final pass.
+  static constexpr int kNumFinalItems = 3072;
+  // The number of elements per thread for the final sort.
+  static constexpr int kNumFinalItemsPerThread =
+      kNumFinalItems / kNumThreadsPerBlock;
+  // The class to sort the elements during the final pass.
+  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                        kNumFinalItemsPerThread, int>;
+
+  // The class to compute the exclusive prefix-sum over the histogram.
+  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
+
+  // Shared memory to compute the block scan.
+  __shared__ typename Scan::TempStorage smemScan;
+
+  // The structure to store the final items (for the final pass).
+  struct FinalItems {
+    // Shared memory to store the indices for the final pass.
+    int indices[kNumFinalItems];
+    // Shared memory to store the logits for the final pass.
+    float logits[kNumFinalItems];
+  };
+
+  // Shared memory to compute the block sort.
+  __shared__ union {
+    FinalItems items;
+    typename FinalSort::TempStorage finalSort;
+    typename TopKSort::TempStorage topKSort;
+  } smemFinal;
+
+  // Shared memory to store the histogram.
+  __shared__ int smemHistogram[kNumBins];
+  // Shared memory to store the selected indices.
+  __shared__ int smemIndices[kTopK];
+  // Shared memory to store the selected logits.
+  __shared__ float smemLogits[kTopK];
+  // Shared memory to store the threshold bin.
+  __shared__ int smemThresholdBinIdx[1];
+  // Shared memory counter to register the candidates for the final phase.
+  __shared__ int smemFinalDstIdx[1];
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
+  // The length of the row.
+  int rowLen = rowEnd - rowStart;
+
+  // Shortcut if the length of the row is no greater than kTopK. Indices
+  // are not sorted by their corresponding logit.
+  if (rowLen <= kTopK) {
+    for (int rowIt = threadIdx.x; rowIt < rowLen;
+         rowIt += kNumThreadsPerBlock) {
+      int idx = rowStart + rowIt;
+      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
+      outLogits[rowIdx * kTopK + rowIt] =
+          logits[rowIdx * stride0 + idx * stride1];
+    }
+    // Pad the unused slots with sentinel values.
+    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
+         rowIt += kNumThreadsPerBlock) {
+      outIndices[rowIdx * kTopK + rowIt] = -1;
+      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
+    }
+    return;
+  }
+
+  // Clear the histogram.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = 0;
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Fetch elements one-by-one.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
+    atomicAdd(&smemHistogram[idx], 1);
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Read the values from SMEM.
+  int binCount{0};
+  if (threadIdx.x < kNumBins) {
+    binCount = smemHistogram[threadIdx.x];
+  }
+
+  // Make sure each thread has read its value.
+  __syncthreads();
+
+  // Compute the prefix sum.
+  int prefixSum{0}, totalSum{0};
+  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
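+  // After the exclusive scan, prefixSum holds the number of logits that
+  // fell into lower-numbered (larger-valued) bins, and totalSum is the
+  // total number of logits in the row.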
+
+  // Update the histogram with the prefix sums.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = prefixSum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // Find the threshold bin, i.e. the first bin where the cumulative count
+  // reaches kTopK. Exactly one thread sees the crossing.
+  if (threadIdx.x < kNumBins) {
+    int nextPrefixSum =
+        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
+    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
+      smemThresholdBinIdx[0] = threadIdx.x;
+    }
+  }
+
+  // Clear the counter to store the items for the final phase.
+  if (threadIdx.x == 0) {
+    smemFinalDstIdx[0] = 0;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The threshold bin. Logits in bins below it are guaranteed to belong to
+  // the top-k; the bin itself is resolved in the final pass.
+  int thresholdBinIdx = smemThresholdBinIdx[0];
+
+  // Fetch elements one-by-one and populate the shared memory buffers.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    float logit = logits[rowIdx * stride0 + rowIt * stride1];
+    uint16_t idx = extractBinIdx(logit);
+    if (idx < thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
+      smemLogits[dstIdx] = logit;
+      smemIndices[dstIdx] = rowIt;
+    } else if (idx == thresholdBinIdx) {
+      // Candidates beyond the first kNumFinalItems in the threshold bin
+      // are dropped.
+      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
+      if (dstIdx < kNumFinalItems) {
+        smemFinal.items.logits[dstIdx] = logit;
+        smemFinal.items.indices[dstIdx] = rowIt;
+      }
+    }
+  }
+
+  // Make sure the elements are in shared memory.
+  __syncthreads();
+
+  // The logits of the elements to be sorted in the final pass.
+  float finalLogits[kNumFinalItemsPerThread];
+  // The indices of the elements to be sorted in the final pass.
+  int finalIndices[kNumFinalItemsPerThread];
+
+  // Init.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    finalLogits[ii] = -FLT_MAX;
+  }
+
+  // Read the elements from SMEM.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    if (srcIdx < smemFinalDstIdx[0]) {
+      finalLogits[ii] = smemFinal.items.logits[srcIdx];
+      finalIndices[ii] = smemFinal.items.indices[srcIdx];
+    }
+  }
+
+  // Make sure the shared memory has been read.
+  __syncthreads();
+
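+  // Sort the threshold-bin candidates in descending order of logit so the
+  // best of them fill the slots left over by the better bins.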
+  FinalSort(smemFinal.finalSort)
+      .SortDescendingBlockedToStriped(finalLogits, finalIndices);
+
+  // Copy the data back to the shared memory storage. baseIdx is the number
+  // of logits already compacted from bins better than the threshold bin.
+  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    int dstIdx = baseIdx + srcIdx;
+    if (dstIdx < kTopK) {
+      smemLogits[dstIdx] = finalLogits[ii];
+      smemIndices[dstIdx] = finalIndices[ii];
+    }
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The top-k logits.
+  float topKLogits[kNumTopKItemsPerThread];
+  // The top-k indices.
+  int topKIndices[kNumTopKItemsPerThread];
+
+  // Load from shared memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
+    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
+  }
+
+  // Sort the elements.
+  TopKSort(smemFinal.topKSort)
+      .SortDescendingBlockedToStriped(topKLogits, topKIndices);
+
+  // Store to global memory, with indices made relative to the row start.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
+    outIndices[offset] = topKIndices[ii] - rowStart;
+    outLogits[offset] = topKLogits[ii];
+  }
+}
+
 }  // namespace vllm

 void apply_repetition_penalties_(
@@ -85,4 +324,20 @@ void apply_repetition_penalties_(
                 repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
                 tile_size);
       });
-}
+}
+
+void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
+                   const torch::Tensor& rowEnds, torch::Tensor& indices,
+                   torch::Tensor& values, int64_t numRows, int64_t stride0,
+                   int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
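+  // Launch one block per row; each block writes out that row's top 2048
+  // logits and their in-row indices.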
+  vllm::topKPerRow<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
+          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
+          values.data_ptr<float>(), static_cast<int>(stride0),
+          static_cast<int>(stride1));
+}