@@ -191,6 +191,86 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(
  }
 }
 
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k,
+                                         DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                                         IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                         uint32_t batch_size, uint32_t num_qo_heads,
+                                         uint32_t num_kv_heads, size_t q_stride_n,
+                                         size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                         float smooth_a, float smooth_b, float rope_rcp_scale,
+                                         float rope_rcp_theta) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
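+  // Each thread precomputes the rotary frequency for every element of its
+  // vec_size-wide slice of the head dimension.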
+  vec_t<float, vec_size> freq;
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    if constexpr (interleave) {
+      freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
+    } else {
+      freq[i] = __powf(rope_rcp_theta,
+                       float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
+    }
+
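+    // Llama-3.1-style smoothing: interpolate between the scaled and the original
+    // frequency with weight `smooth`; smooth_a = smooth_b = 0 yields plain scaling.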
+    float smooth = freq[i] * smooth_a + smooth_b;
+    smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
+    freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
+  }
+
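+  // Grid layout: the first batch_size * num_qo_heads blocks rotate q heads; the
+  // remaining batch_size * num_kv_heads blocks rotate k heads.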
+  if (bx < batch_size * num_qo_heads) {
+    // apply rotary to q
+    const uint32_t batch_idx = bx / num_qo_heads;
+    const uint32_t qo_head_idx = bx % num_qo_heads;
+    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
+    const uint32_t offset = offsets[batch_idx];
+#pragma unroll 2
+    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
+      vec_t<float, vec_size> q_vec;
+      if (i * bdy + ty < seq_len) {
+        DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
+                                                q_stride_n, q_stride_h);
+        DType* q_rope_ptr =
+            q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
+                                          /*q_stride_n=*/num_qo_heads * head_dim,
+                                          /*q_stride_h=*/head_dim);
+        if constexpr (interleave) {
+          q_vec =
+              vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
+        } else {
+          q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
+        }
+        q_vec.cast_store(q_rope_ptr + tx * vec_size);
+      }
+    }
+  } else {
+    // apply rotary to k
+    uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads;
+    uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads;
+    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
+    const uint32_t offset = offsets[batch_idx];
+#pragma unroll 2
+    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
+      vec_t<float, vec_size> k_vec;
+      if (i * bdy + ty < seq_len) {
+        DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
+                                                k_stride_n, k_stride_h);
+        DType* k_rope_ptr =
+            k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
+                                          /*kv_stride_n=*/num_kv_heads * head_dim,
+                                          /*kv_stride_h=*/head_dim);
+        if constexpr (interleave) {
+          k_vec =
+              vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
+        } else {
+          k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
+        }
+        k_vec.cast_store(k_rope_ptr + tx * vec_size);
+      }
+    }
+  }
+}
+
 #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
   if (interleave) {                                      \
     const bool INTERLEAVE = true;                        \
@@ -289,6 +369,100 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace(
   return cudaSuccess;
 }
 
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k,
+                               DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                               IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                               uint32_t head_dim, size_t q_stride_n, size_t q_stride_h,
+                               size_t k_stride_n, size_t k_stride_h, bool interleave,
+                               float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
+  float rope_rcp_scale = 1.0f / rope_scale;
+  float rope_rcp_theta = 1.0f / rope_theta;
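+  // Zero smoothing coefficients: the kernel's interpolation weight clamps to 0,
+  // so every frequency is simply multiplied by rope_rcp_scale (plain scaled RoPE).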
+  float smooth_a = 0.f;
+  float smooth_b = 0.f;
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
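+      // vec_size targets 16-byte vectorized loads/stores; bdx = HEAD_DIM / vec_size
+      // threads (at most 32) cooperate on one head.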
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
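+      // One thread block per (batch, head) pair, q and k heads combined.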
+      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&indptr,
+                      (void*)&offsets,
+                      (void*)&batch_size,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&smooth_a,
+                      (void*)&smooth_b,
+                      (void*)&rope_rcp_scale,
+                      (void*)&rope_rcp_theta};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k,
+                                      DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                      uint32_t batch_size, uint32_t num_qo_heads,
+                                      uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
+                                      size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                      bool interleave, float rope_scale, float rope_theta,
+                                      float low_freq_factor, float high_freq_factor,
+                                      float old_context_length, cudaStream_t stream = nullptr) {
+  float rope_rcp_scale = 1.0f / rope_scale;
+  float rope_rcp_theta = 1.0f / rope_theta;
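+  // Fold the Llama 3.1 wavelength interpolation into two coefficients: for an
+  // inverse frequency f the kernel computes smooth = f * smooth_a + smooth_b,
+  // which equals (old_context_length * f / (2 * pi) - low_freq_factor) /
+  // (high_freq_factor - low_freq_factor) before clamping to [0, 1].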
+  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
+  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&indptr,
+                      (void*)&offsets,
+                      (void*)&batch_size,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&smooth_a,
+                      (void*)&smooth_b,
+                      (void*)&rope_rcp_scale,
+                      (void*)&rope_rcp_theta};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace flashinfer
 
 #endif  // FLASHINFER_POS_ENC_CUH_