 #define FLASHINFER_HANDLER_CUH_
 
 #include <algorithm>
+#include <cstddef>
 #include <memory>
 #include <unordered_map>
 #include <vector>
@@ -101,7 +102,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
 cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
     uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
-    const uint32_t page_size, cudaStream_t stream) {
+    const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
   constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
   constexpr uint32_t num_stages_smem = 2U;
   constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -126,8 +127,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
   max_grid_size = num_blocks_per_sm * num_sm;
-  if (batch_size * num_kv_heads >= max_grid_size) {
+  if (batch_size * num_kv_heads >= max_grid_size && !enable_cuda_graph) {
     // do not use partition-kv kernel
+    // TODO(Zihao): if enable_cuda_graph, we should always use the partition-kv kernel
+    // so that only one kernel will be captured in the graph.
     tmp_size = 0;
     new_batch_size = batch_size;
   } else {
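
Editor's note on the enable_cuda_graph flag above: a CUDA graph replays exactly the kernels that were captured, so the handler has to commit to one kernel variant (here, the partition-kv kernel) instead of switching on batch size. A minimal capture/replay sketch, assuming the decode kernels are launched on the given stream inside the capture region; the helper name and setup are illustrative, not part of this diff:

#include <cuda_runtime.h>

// Sketch: capture one decode iteration into a graph and replay it.
void CaptureAndReplayDecodeStep(cudaStream_t stream) {
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // ... launch the batch decode kernels on `stream` here ...
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);
  // Replays re-run exactly the captured kernels against the current contents of the
  // fixed device buffers, which is why the kernel choice must not depend on batch size.
  cudaGraphLaunch(graph_exec, stream);
  cudaStreamSynchronize(stream);
  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}
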
@@ -299,39 +302,42 @@ class BatchDecodeHandler {
                                                             DTypeOut, IdType>;
     FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
                                               new_batch_size, batch_size, indptr, num_qo_heads,
-                                              page_size, stream_));
+                                              page_size,
+                                              /*enable_cuda_graph=*/false, stream_));
     batch_size_after_partition_ = new_batch_size;
     if (tmp_size > 0) {
       AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
       float_buffer_ = allocator.aligned_alloc<void*>(tmp_size, 16);
       new_indptr_ =
           allocator.aligned_alloc<void*>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
-      void* new_indptr_h_ = host_buffer_;
+      void* new_indptr_h_ = page_locked_buffer_;
       new_last_page_len_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* new_last_page_len_h_ =
-          (char*)host_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
       chunk_indptr_ =
           allocator.aligned_alloc<void*>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
-      void* chunk_indptr_h_ = (char*)host_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+      void* chunk_indptr_h_ =
+          (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
       batch_idx_map_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
-      void* batch_idx_map_h_ = (char*)host_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+      void* batch_idx_map_h_ =
+          (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
       chunk_start_pos_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* chunk_start_pos_h_ =
-          (char*)host_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
       seq_lengths_before_partition_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* seq_lengths_before_partition_h_ =
-          (char*)host_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
       size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
       FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
           max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
           (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
           (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-          (IdType*)seq_lengths_before_partition_h_, new_indptr_, host_buffer_, num_bytes_to_copy,
-          stream_));
+          (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
+          num_bytes_to_copy, stream_));
     }
     forward_started_ = true;
     return cudaSuccess;
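
The hunk above places each host-side staging area at the same byte offset inside page_locked_buffer_ as its device counterpart occupies relative to new_indptr_, so every auxiliary array can be pushed to the device with one asynchronous copy. A minimal sketch of that offset-mirroring pattern, with illustrative names and fill values that are not taken from the diff:

#include <cuda_runtime.h>
#include <cstddef>

// Sketch: two device regions carved out of one workspace are staged through a pinned
// host buffer at matching byte offsets, then transferred with a single async copy.
cudaError_t StageAuxArrays(int* device_indptr, int* device_last_page_len,
                           size_t device_span_bytes, void* pinned_host_base,
                           size_t batch_size, cudaStream_t stream) {
  int* host_indptr = static_cast<int*>(pinned_host_base);
  int* host_last_page_len = reinterpret_cast<int*>(
      static_cast<char*>(pinned_host_base) +
      (reinterpret_cast<char*>(device_last_page_len) - reinterpret_cast<char*>(device_indptr)));
  // Fill the staging copies on the CPU ...
  for (size_t i = 0; i <= batch_size; ++i) host_indptr[i] = 0;
  for (size_t i = 0; i < batch_size; ++i) host_last_page_len[i] = 1;
  // ... then one host-to-device copy moves both arrays in a single transfer.
  return cudaMemcpyAsync(device_indptr, pinned_host_base, device_span_bytes,
                         cudaMemcpyHostToDevice, stream);
}
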
@@ -353,6 +359,11 @@ class BatchDecodeHandler {
 
   bool IsForwardStarted() const { return forward_started_; }
 
+  void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
+    cudaFreeHost(page_locked_buffer_);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+
   uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }
 
   uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }
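
A hedged usage note for the new UpdatePageLockedBufferSize: if the workspace requirement grows between forward passes, the pinned staging buffer should be resized before the next BeginForward call. The sizes below are illustrative, and the constructor argument is assumed from the max_workspace_size_in_bytes parameter used in the constructor body further down:

// Illustrative sketch, not part of the diff.
BatchDecodeHandler handler(/*max_workspace_size_in_bytes=*/16 * 1024 * 1024);
// ... later, a larger batch needs more staging space ...
size_t new_workspace_size_in_bytes = 64 * 1024 * 1024;
handler.UpdatePageLockedBufferSize(new_workspace_size_in_bytes);
// The next BeginForward call can now stage up to 64 MB through the pinned buffer.
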
@@ -372,17 +383,19 @@ class BatchDecodeHandler {
         seq_lengths_before_partition_(nullptr),
         forward_started_(false),
         stream_(nullptr) {
-    cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
   }
   ~BatchDecodeHandler() {
     EndForward();
-    cudaFreeHost(host_buffer_);
+    cudaFreeHost(page_locked_buffer_);
   }
 
- private:
+  virtual bool IsCUDAGraphMode() const { return false; }
+
+ protected:
   uint32_t batch_size_before_partition_;
   uint32_t batch_size_after_partition_;
-  void* host_buffer_;
+  void* page_locked_buffer_;
   void* float_buffer_;
   void* new_indptr_;
   void* new_last_page_len_;
@@ -394,6 +407,86 @@ class BatchDecodeHandler {
   cudaStream_t stream_;
 };
 
+class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler {
+ public:
+  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
+            PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
+  cudaError_t CUDAGraphBeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes,
+                                              IdType* indptr, IdType* last_page_len,
+                                              uint32_t batch_size, uint32_t num_qo_heads,
+                                              uint32_t page_size) {
+    batch_size_before_partition_ = batch_size;
+    uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
+    auto work_estimation_func =
+        BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                            kv_layout, POS_ENCODING_MODE, DTypeIn,
+                                                            DTypeOut, IdType>;
+    FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
+                                              new_batch_size, batch_size, indptr, num_qo_heads,
+                                              page_size,
+                                              /*enable_cuda_graph=*/true, stream_));
+    // NOTE(Zihao): max_batch_size_after_partition_ is determined at handler initialization
+    // and must not change during the lifetime of the handler, which keeps buffer sizes and
+    // pointers fixed as CUDA graph capture requires.
+    batch_size_after_partition_ = new_batch_size;
+    size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ *
+                          (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float));
+    AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
+    float_buffer_ = allocator.aligned_alloc<void*>(max_tmp_size, 16);
+    new_indptr_ =
+        allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+
+    void* new_indptr_h_ = page_locked_buffer_;
+    new_last_page_len_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* new_last_page_len_h_ =
+        (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+    chunk_indptr_ =
+        allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+    void* chunk_indptr_h_ =
+        (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+    batch_idx_map_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* batch_idx_map_h_ =
+        (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+    chunk_start_pos_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* chunk_start_pos_h_ =
+        (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+    seq_lengths_before_partition_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* seq_lengths_before_partition_h_ =
+        (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+
+    size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+    FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+        max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
+        (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
+        (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+        (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
+        num_bytes_to_copy, stream_));
+    forward_started_ = true;
+    return cudaSuccess;
+  }
+  CUDAGraphBatchDecodeHandler(size_t max_batch_size) {
+    int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0;
+    cudaGetDevice(&dev_id);
+    cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
+    cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor,
+                           dev_id);
+    max_batch_size_after_partition_ =
+        std::max<size_t>(max_thread_blocks_per_sm * num_sm, max_batch_size);
+    std::cout << max_thread_blocks_per_sm * num_sm << " " << max_batch_size << std::endl;
+    size_t max_workspace_size_in_bytes =
+        6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+  bool IsCUDAGraphMode() const override { return true; }
+
+ private:
+  uint32_t max_batch_size_after_partition_;
+};
+
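
A hedged usage sketch for the new handler, based only on the signatures visible in this diff. The concrete template arguments, the enumerator values (PageStorage::kIndices, QKVLayout::kNHD, PosEncodingMode::kNone), the data types, and the caller-provided buffers are illustrative assumptions rather than part of the change:

#include <cstdint>
#include <cuda_fp16.h>

// Plan one CUDA-graph-compatible decode step. The handler is constructed once, e.g.
// CUDAGraphBatchDecodeHandler handler(/*max_batch_size=*/256);
cudaError_t PlanCUDAGraphDecode(CUDAGraphBatchDecodeHandler& handler, void* workspace_buffer,
                                size_t workspace_size_in_bytes, int32_t* kv_indptr_host,
                                int32_t* kv_last_page_len_host, uint32_t batch_size,
                                uint32_t num_qo_heads, uint32_t page_size) {
  // Auxiliary buffers are sized for max_batch_size_after_partition_, so their device
  // addresses stay valid when the captured graph is replayed with smaller batches.
  return handler.CUDAGraphBeginForwardDispatched<
      /*GROUP_SIZE=*/1, /*HEAD_DIM=*/128, PageStorage::kIndices, QKVLayout::kNHD,
      PosEncodingMode::kNone, half, half, int32_t>(
      workspace_buffer, workspace_size_in_bytes, kv_indptr_host, kv_last_page_len_host,
      batch_size, num_qo_heads, page_size);
}
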
 class BatchPrefillHandler {
  public:
   template <typename IdType>
@@ -412,6 +505,11 @@ class BatchPrefillHandler {
 
   bool IsForwardStarted() const { return request_indices_ != nullptr; }
 
+  void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
+    cudaFreeHost(page_locked_buffer_);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+
   template <typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
                            uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
@@ -429,14 +527,15 @@ class BatchPrefillHandler {
     AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
     request_indices_ =
         allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_vec.size(), 16);
-    void* request_indices_h_ = host_buffer_;
+    void* request_indices_h_ = page_locked_buffer_;
     tile_indices_ = allocator.aligned_alloc<void*>(sizeof(IdType) * tile_indices_vec.size(), 16);
-    void* tile_indices_h_ = (char*)host_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
+    void* tile_indices_h_ =
+        (char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
     std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_);
     std::copy(tile_indices_vec.begin(), tile_indices_vec.end(), (IdType*)tile_indices_h_);
     size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_;
 
-    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, host_buffer_, num_bytes_to_copy,
+    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy,
                                          cudaMemcpyHostToDevice, stream_));
 
     return cudaSuccess;
@@ -462,15 +561,15 @@ class BatchPrefillHandler {
         num_qo_tiles_(0U),
         forward_started_(false),
         stream_(nullptr) {
-    cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
   }
   ~BatchPrefillHandler() {
     EndForward();
-    cudaFreeHost(host_buffer_);
+    cudaFreeHost(page_locked_buffer_);
   }
 
  private:
-  void* host_buffer_;
+  void* page_locked_buffer_;
   void* request_indices_;
   void* tile_indices_;
   uint32_t num_frags_x_;