@@ -328,10 +328,13 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
 
 void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
     int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size,
-    int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim) {
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
   CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
+  cudaStream_t original_stream = batch_prefill_paged_kv_handlers[handler_idx].GetCUDAStream();
+  batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream(
+      static_cast<cudaStream_t>(copy_stream));
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].BeginForward(
         static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
@@ -340,6 +343,7 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
       LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status);
     }
   });
+  batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream(original_stream);
 }
 
 void _FlashInferAttentionPrefillWithPagedKVCacheEndForward(int64_t handler_idx) {
@@ -456,7 +460,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
 void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
     int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* page_table_indptr,
     DLTensor* last_page_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-    int64_t page_size, int64_t pos_encoding_mode) {
+    int64_t page_size, int64_t pos_encoding_mode, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
   CHECK_LT(handler_idx, max_num_handlers)
@@ -467,6 +471,8 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
   // leave a parameter for the input data type.
   using dtype_in = half;
   const uint32_t batch_size = page_table_indptr->shape[0] - 1;
+  cudaStream_t original_stream = batch_decode_handlers[handler_idx].GetCUDAStream();
+  batch_decode_handlers[handler_idx].SetCUDAStream(static_cast<cudaStream_t>(copy_stream));
   DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, {
     cudaError_t status =
         batch_decode_handlers[handler_idx]
@@ -479,6 +485,7 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
       LOG(FATAL) << "FlashInfer decode BeginForward error " << cudaGetErrorString(status);
     }
   });
+  batch_decode_handlers[handler_idx].SetCUDAStream(original_stream);
 }
 
 void _FlashInferAttentionDecodeWithPagedKVCacheEndForward(int64_t handler_id) {
@@ -606,9 +613,11 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
 
 void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
     DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size, int64_t num_qo_heads,
-    int64_t num_kv_heads, int64_t head_dim) {
+    int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+  cudaStream_t original_stream = batch_prefill_ragged_kv_handler.GetCUDAStream();
+  batch_prefill_ragged_kv_handler.SetCUDAStream(static_cast<cudaStream_t>(copy_stream));
 
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward(
@@ -619,6 +628,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
           << cudaGetErrorString(status);
     }
   });
+  batch_prefill_ragged_kv_handler.SetCUDAStream(original_stream);
 }
 
 void _FlashInferAttentionPrefillWithRaggedKVCacheEndForward() {
0 commit comments