diff --git a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
index 80310a74..b5e1ebcd 100644
--- a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
+++ b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -39,11 +39,19 @@ def get_cu_file_str(
     def get_insts(attention_variant):
         return "\n".join(
             [
-                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
+                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 """.format(
diff --git a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
index e26a7389..ad53dc31 100644
--- a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
+++ b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
@@ -40,11 +40,19 @@ def get_cu_file_str(
     def get_insts(attention_variant):
         return "\n".join(
             [
-                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
+                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 
-template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 """.format(
diff --git a/csrc/aot_extension_utils.h b/csrc/aot_extension_utils.h
index b701c289..9e8ad758 100644
--- a/csrc/aot_extension_utils.h
+++ b/csrc/aot_extension_utils.h
@@ -30,17 +30,6 @@
 #define DISPATCH_mask_mode(expr, const_expr, ...) \
   _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__))
 
-#define DISPATCH_BOOL(expr, const_expr, ...) \
-  [&]() -> bool {                           \
-    if (expr) {                             \
-      constexpr bool const_expr = true;     \
-      return __VA_ARGS__();                 \
-    } else {                                \
-      constexpr bool const_expr = false;    \
-      return __VA_ARGS__();                 \
-    }                                       \
-  }()
-
 #define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \
   [&]() -> bool {                                                                        \
     if (kv_dtype == q_dtype) {                                                           \
diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu
index 0d358e90..ea53eb12 100644
--- a/csrc/batch_prefill_sm90.cu
+++ b/csrc/batch_prefill_sm90.cu
@@ -29,14 +29,14 @@ namespace flashinfer {
 
 template
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
+          typename DTypeKV, typename DTypeO, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
     BatchPrefillRaggedParams& params, cudaStream_t stream);
 
 template
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
+          typename DTypeKV, typename DTypeO, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
     BatchPrefillPagedParams& params, cudaStream_t stream);
 
@@ -47,9 +47,9 @@ using namespace flashinfer;
 
 std::vector BatchPrefillWithKVCacheSM90Plan(
     unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
     at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, int64_t cuda_stream) {
+    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -61,12 +61,13 @@ std::vector BatchPrefillWithKVCacheSM90Plan(
 
   cudaStream_t stream = reinterpret_cast(cuda_stream);
 
-  cudaError_t status = PrefillSM90Plan(
-      float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
-      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
-      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(),
-      kv_indptr.data_ptr(), kv_len_arr.data_ptr(), batch_size, num_qo_heads,
-      num_kv_heads, head_dim, page_size, causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+  cudaError_t status =
+      PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
+                      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
+                      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(),
+                      kv_indptr.data_ptr(), kv_len_arr.data_ptr(), total_num_rows,
+                      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, causal,
+                      enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
 
   TORCH_CHECK(status == cudaSuccess,
               "PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
 
@@ -151,19 +152,23 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset);
+  bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
+
   return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
     return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
       return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
         return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
-          using AttentionVariant =
-              std::conditional_t;
-          cudaError_t status =
-              BatchPrefillWithRaggedKVCacheDispatched(params, stream);
-          TORCH_CHECK(status == cudaSuccess,
-                      "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
-                      cudaGetErrorString(status));
-          return true;
+          return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+            using AttentionVariant =
+                std::conditional_t;
+            cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
+                HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
+                params, stream);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
+                        cudaGetErrorString(status));
+            return true;
+          });
         });
       });
     });
@@ -259,20 +264,23 @@ void BatchPrefillWithPagedKVCacheSM90Run(
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset);
   params.kv_indices = static_cast(paged_kv_indices.data_ptr());
+  bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
 
   return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
     return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
       return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
         return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
-          using AttentionVariant =
-              std::conditional_t;
-          cudaError_t status =
-              BatchPrefillWithPagedKVCacheDispatched(params, stream);
-          TORCH_CHECK(status == cudaSuccess,
-                      "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
-                      cudaGetErrorString(status));
-          return true;
+          return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+            using AttentionVariant =
+                std::conditional_t;
+            cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
+                HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
+                params, stream);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
+                        cudaGetErrorString(status));
+            return true;
+          });
         });
       });
     });
diff --git a/csrc/flashinfer_ops_sm90.cu b/csrc/flashinfer_ops_sm90.cu
index cc3ac869..39c02c50 100644
--- a/csrc/flashinfer_ops_sm90.cu
+++ b/csrc/flashinfer_ops_sm90.cu
@@ -33,9 +33,9 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
 std::vector BatchPrefillWithKVCacheSM90Plan(
     unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
     at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, int64_t cuda_stream);
+    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheSM90Run(
     unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
diff --git a/csrc/pytorch_extension_utils.h b/csrc/pytorch_extension_utils.h
index 825d6431..dc830574 100644
--- a/csrc/pytorch_extension_utils.h
+++ b/csrc/pytorch_extension_utils.h
@@ -189,6 +189,17 @@
     return __VA_ARGS__();                   \
   }
 
+#define DISPATCH_BOOL(expr, const_expr, ...) \
+  [&]() -> bool {                           \
+    if (expr) {                             \
+      constexpr bool const_expr = true;     \
+      return __VA_ARGS__();                 \
+    } else {                                \
+      constexpr bool const_expr = false;    \
+      return __VA_ARGS__();                 \
+    }                                       \
+  }()
+
 inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name,
                         const char* b_name) {
   TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ",
diff --git a/flashinfer/jit/batch_prefill_sm90_templ.py b/flashinfer/jit/batch_prefill_sm90_templ.py
index 88bd2407..e36e8ded 100644
--- a/flashinfer/jit/batch_prefill_sm90_templ.py
+++ b/flashinfer/jit/batch_prefill_sm90_templ.py
@@ -44,13 +44,25 @@ def ragged_prefill_sm90_inst_templ(mask_mode: str) -> str:
 template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """
         + mask_mode
-        + r""", /*USE_SWA=*/true, AttentionVariant>(
+        + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>(
     RaggedParams& params,
     cudaStream_t stream);
 
 template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """
         + mask_mode
-        + r""", /*USE_SWA=*/false, AttentionVariant>(
+        + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>(
+    RaggedParams& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """
+        + mask_mode
+        + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>(
+    RaggedParams& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """
+        + mask_mode
+        + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>(
     RaggedParams& params,
     cudaStream_t stream);
 
@@ -77,13 +89,25 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
 template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """
         + mask_mode
-        + r""", /*USE_SWA=*/true, AttentionVariant>(
+        + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>(
+    PagedParams& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """
+        + mask_mode
+        + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>(
+    PagedParams& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """
+        + mask_mode
+        + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>(
     PagedParams& params,
     cudaStream_t stream);
 
 template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """
         + mask_mode
-        + r""", /*USE_SWA=*/false, AttentionVariant>(
+        + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>(
     PagedParams& params,
     cudaStream_t stream);
 
@@ -100,8 +124,8 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
 std::vector BatchPrefillWithKVCacheSM90Plan(
     bool causal, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
+    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
     bool enable_cuda_graph, int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
@@ -115,8 +139,9 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
       float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
       int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{ dtype_idx }}>(),
-      kv_indptr.data_ptr<{{ dtype_idx }}>(), kv_len_arr.data_ptr<{{ dtype_idx }}>(), batch_size, num_qo_heads,
-      num_kv_heads, {{ head_dim }}, page_size, causal, enable_cuda_graph, sizeof({{dtype_o}}), stream);
+      kv_indptr.data_ptr<{{ dtype_idx }}>(), kv_len_arr.data_ptr<{{ dtype_idx }}>(),
+      total_num_rows, batch_size, num_qo_heads, num_kv_heads, {{ head_dim }}, page_size,
+      causal, enable_cuda_graph, sizeof({{dtype_o}}), stream);
 
   TORCH_CHECK(status == cudaSuccess,
               "PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
 
@@ -141,7 +166,7 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
 
 namespace flashinfer {
 
-template
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
 
@@ -224,16 +249,21 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
   params.head_indices =
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset);
+  bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
 
   DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
-    using AttentionVariant =
-        std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>;
-    cudaError_t status =
-        BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }},
-                                                AttentionVariant>(params, stream);
-    TORCH_CHECK(status == cudaSuccess,
-                "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
-                cudaGetErrorString(status));
+    DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+      using AttentionVariant =
+          std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>;
+      cudaError_t status =
+          BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }},
+                                                  SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(params, stream);
+      TORCH_CHECK(status == cudaSuccess,
+                  "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
+                  cudaGetErrorString(status));
+
+      return true;
+    });
   });
 }
 """,
@@ -255,7 +285,7 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
 
 namespace flashinfer {
 
-template
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
 
@@ -352,16 +382,20 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset);
   params.kv_indices = static_cast(paged_kv_indices.data_ptr());
+  bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
 
   DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
-    using AttentionVariant =
-        std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>;
-    cudaError_t status =
-        BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }},
-                                               AttentionVariant>(params, stream);
-    TORCH_CHECK(status == cudaSuccess,
-                "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
-                cudaGetErrorString(status));
+    DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+      using AttentionVariant =
+          std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>;
+      cudaError_t status =
+          BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }},
+                                                 SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(params, stream);
+      TORCH_CHECK(status == cudaSuccess,
+                  "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
+                  cudaGetErrorString(status));
+      return true;
+    });
   });
 }
 """,
@@ -370,8 +404,8 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str:
 std::vector BatchPrefillWithKVCacheSM90Plan(
     bool causal, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
+    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
     bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheSM90Run(
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index fd8252e3..8dd45ed7 100644
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -1106,13 +1106,13 @@ def __init__(
         self._kv_layout = kv_layout
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
+        self._int_workspace_buffer = torch.empty(
+            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
+        )
         if backend in ["fa3", "auto"]:
-            self._int_workspace_buffer = torch.empty(
-                (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device
-            )
-            # NOTE(Zihao): assume maximum accumulate kv length is 16M
+            # NOTE(Zihao): assume maximum accumulate kv length is 4M
             self._vector_sparse_indices_buffer = torch.empty(
-                (16 * 1024 * 1024,), dtype=torch.int32, device=self.device
+                (4 * 1024 * 1024,), dtype=torch.int32, device=self.device
             )
             # NOTE(Zihao): assume maximum batch size is 32768
             self._vector_sparse_indptr_buffer = torch.empty(
@@ -1121,10 +1121,6 @@ def __init__(
             self._kv_lens_buffer = torch.empty(
                 (32768,), dtype=torch.int32, device=self.device
             )
-        else:
-            self._int_workspace_buffer = torch.empty(
-                (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
-            )
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
             dtype=self._int_workspace_buffer.dtype,
@@ -1474,6 +1470,7 @@ def plan(
                 qo_indptr_host,
                 vector_sparse_indptr_host,
                 kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
@@ -1864,14 +1861,9 @@ def __init__(
         self._kv_layout = kv_layout
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
-        if backend in ["fa3", "auto"]:
-            self._int_workspace_buffer = torch.empty(
-                (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device
-            )
-        else:
-            self._int_workspace_buffer = torch.empty(
-                (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
-            )
+        self._int_workspace_buffer = torch.empty(
+            (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device
+        )
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True
         )
@@ -2147,6 +2139,7 @@ def plan(
                 qo_indptr_host,
                 kv_indptr_host,
                 kv_len_arr,
+                self._max_total_num_rows or total_num_rows,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py
index 7732ce05..1e308d5e 100644
--- a/flashinfer/sparse.py
+++ b/flashinfer/sparse.py
@@ -130,13 +130,13 @@ def __init__(
         """
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
+        self._int_workspace_buffer = torch.empty(
+            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
+        )
         if backend in ["fa3", "auto"]:
-            self._int_workspace_buffer = torch.empty(
-                (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device
-            )
-            # NOTE(Zihao): assume maximum accumulate kv length is 16M
+            # NOTE(Zihao): assume maximum accumulate kv length is 4M
             self._vector_sparse_indices_buffer = torch.empty(
-                (16 * 1024 * 1024,), dtype=torch.int32, device=self.device
+                (4 * 1024 * 1024,), dtype=torch.int32, device=self.device
             )
             # NOTE(Zihao): assume maximum batch size is 32768
             self._vector_sparse_indptr_buffer = torch.empty(
@@ -145,15 +145,9 @@ def __init__(
             self._kv_lens_buffer = torch.empty(
                 (32768,), dtype=torch.int32, device=self.device
             )
-        else:
-            self._int_workspace_buffer = torch.empty(
-                (8 * 1024 * 1024,),
-                dtype=torch.uint8,
-                device=float_workspace_buffer.device,
-            )
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
-            dtype=self._int_workspace_buffer.dtype,
+            dtype=torch.uint8,
             pin_memory=True,
         )
         self._use_cuda_graph = False
@@ -453,6 +447,7 @@ def plan(
                 qo_indptr_host,
                 vector_sparse_indptr_host,
                 kv_lens_arr_host,
+                M,  # total_num_rows
                 num_blocks_row,  # batch_size
                 num_qo_heads,
                 num_kv_heads,
diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh
index 708f80f3..1a1f0027 100644
--- a/include/flashinfer/attention/hopper/prefill_sm90.cuh
+++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh
@@ -291,7 +291,8 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(
   return cudaSuccess;
 }
 
-template
+template
 cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(
     BatchPrefillPagedParams& params,
@@ -303,7 +304,9 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(
   using CollectiveMainloop = SparseCollectiveMainloop;
   using CollectiveEpilogue = CollectiveEpilogue;
-  using Scheduler = BatchPrefillTileScheduler;
+  using Scheduler =
+      std::conditional_t,
+                         BatchPrefillPersistentTileScheduler>;
   typename CollectiveMainloop::Params mainloop_params =
       CollectiveMainloop::to_underlying_arguments(
           {params.q_ptr,
@@ -324,12 +327,12 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(
                           params.o_stride_h),  // layout_O
        params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads),  // layout_LSE
       });
-
   typename Scheduler::Arguments scheduler_args = {
       params.work_indptr, params.head_indices,
       params.qo_tile_indices, params.qo_indptr,
       params.kv_indptr, params.qo_lens,
-      params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)};
+      params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads),
+      params.num_qo_heads};
   typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
 
   // Get the ptr to kernel function.
@@ -348,13 +351,15 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(
   dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
   static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
   dim3 block_dims(ctaSize);
+  void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params};
   FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream));
 
   return cudaSuccess;
 }
 
-template
+template
 cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(
     BatchPrefillRaggedParams& params,
@@ -366,7 +371,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(
   using CollectiveMainloop = CollectiveMainloop;
   using CollectiveEpilogue = CollectiveEpilogue;
-  using Scheduler = BatchPrefillTileScheduler;
+  using Scheduler =
+      std::conditional_t,
+                         BatchPrefillPersistentTileScheduler>;
   typename CollectiveMainloop::Params mainloop_params =
       CollectiveMainloop::to_underlying_arguments(
           {params.q_ptr,
           get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n,
@@ -392,7 +399,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(
       params.work_indptr, params.head_indices,
       params.qo_tile_indices, params.qo_indptr,
       params.kv_indptr, params.qo_lens,
-      params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)};
+      params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads),
+      params.num_qo_heads};
   typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
 
   // Get the ptr to kernel function.
@@ -452,8 +460,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
+          typename DTypeKV, typename DTypeO, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
     BatchPrefillRaggedParams& params, cudaStream_t stream) {
   static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256);
@@ -466,27 +474,27 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
         AttentionKernelTraits,
-        LEFT_SLINDING_WINDOW, CAUSAL>(params, stream);
+        LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
   } else if constexpr (HEAD_DIM == 128) {
     BatchPrefillWithRaggedKVCacheKernelTraitsDispatched<
         AttentionKernelTraits,
-        LEFT_SLINDING_WINDOW, CAUSAL>(params, stream);
+        LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
   } else {
     // HEAD_DIM == 256;
     BatchPrefillWithRaggedKVCacheKernelTraitsDispatched<
         AttentionKernelTraits,
-        LEFT_SLINDING_WINDOW, CAUSAL>(params, stream);
+        LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
   }
   cudaError_t status = cudaGetLastError();
   return status;
 }
 
 template
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
+          typename DTypeKV, typename DTypeO, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
     BatchPrefillPagedParams& params, cudaStream_t stream) {
   static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256);
@@ -500,12 +508,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
         AttentionKernelTraits,
-        LEFT_SLINDING_WINDOW, CAUSAL>(params, stream);
+        LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
   } else if constexpr (HEAD_DIM == 128) {
     BatchPrefillWithPagedKVCacheKernelTraitsDispatched<
         AttentionKernelTraits,
-        LEFT_SLINDING_WINDOW, CAUSAL>(params, stream);
+        LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
   } else {
     // HEAD_DIM == 256;
     // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later
@@ -513,7 +521,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
         AttentionKernelTraits,
-        LEFT_SLINDING_WINDOW, CAUSAL>(params, stream);
+        LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
   }
   cudaError_t status = cudaGetLastError();
   return status;
diff --git a/include/flashinfer/attention/hopper/tile_scheduler.cuh b/include/flashinfer/attention/hopper/tile_scheduler.cuh
index 39610271..4637d3cf 100644
--- a/include/flashinfer/attention/hopper/tile_scheduler.cuh
+++ b/include/flashinfer/attention/hopper/tile_scheduler.cuh
@@ -78,13 +78,14 @@ struct SingleTileScheduler {
 };
 
 template
-struct BatchPrefillTileScheduler {
+struct BatchPrefillPersistentTileScheduler {
  public:
   // Host side kernel arguments
   struct Arguments {
     IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens,
         *kv_lens;
     cutlass::FastDivmod group_size_fastdiv;
+    int num_qo_heads;  // placeholder
   };
 
   // Device side kernel params
@@ -99,9 +100,7 @@ struct BatchPrefillTileScheduler {
             args.kv_indptr, args.qo_lens, args.kv_lens, args.group_size_fastdiv};
   }
 
-  static dim3 get_grid_dim(Arguments const& args, int num_sm) {
-    return {132U};  // 132
-  }
+  static dim3 get_grid_dim(Arguments const& args, int num_sm) { return {(unsigned)num_sm}; }
 
   struct WorkTileInfo {
     int q_tile_idx = 0;
@@ -126,7 +125,7 @@ struct BatchPrefillTileScheduler {
   };
 
   CUTLASS_DEVICE
-  BatchPrefillTileScheduler() {}
+  BatchPrefillPersistentTileScheduler() {}
 
   CUTLASS_DEVICE
   WorkTileInfo get_initial_work(Params const& params) const {
@@ -143,7 +142,7 @@ struct BatchPrefillTileScheduler {
               params.kv_indptr[work_idx],
               params.qo_lens[work_idx],
               params.kv_lens[work_idx],
-              0,
+              /*counter=*/0,
               ptr_begin,
               ptr_end};
     } else {
@@ -191,6 +190,117 @@ struct BatchPrefillTileScheduler {
   }
 };
 
+/*!
+ * \brief Tile scheduler that maps q/o head to blockIdx.y
+ */
+template
+struct BatchPrefillTileScheduler {
+ public:
+  // Host side kernel arguments
+  struct Arguments {
+    IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens,
+        *kv_lens;  // head_indices is a placeholder
+    cutlass::FastDivmod group_size_fastdiv;
+    int num_qo_heads;
+  };
+
+  // Device side kernel params
+  struct Params {
+    IdType *work_indptr, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, *kv_lens;
+    cutlass::FastDivmod group_size_fastdiv;
+    int num_qo_heads;
+  };
+
+  static Params to_underlying_arguments(Arguments const& args) {
+    return {args.work_indptr, args.qo_tile_indices, args.qo_indptr, args.kv_indptr,
+            args.qo_lens, args.kv_lens, args.group_size_fastdiv, args.num_qo_heads};
+  }
+
+  static dim3 get_grid_dim(Arguments const& args, int num_sm) {
+    return {(unsigned)num_sm, (unsigned)args.num_qo_heads};
+  }
+
+  struct WorkTileInfo {
+    int q_tile_idx = 0;
+    int qo_head_idx = 0;
+    int kv_head_idx = 0;
+    int qo_indptr = 0;
+    int kv_indptr = 0;
+    int qo_len = 0;
+    int kv_len = 0;
+    int counter = 0;
+    int ptr_begin = 0;
+    int ptr_end = 0;
+
+    CUTLASS_DEVICE
+    bool is_valid(Params const& params) const { return counter + ptr_begin < ptr_end; }
+
+    CUTLASS_DEVICE
+    auto get_block_coord(Params const& params) const {
+      return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr,
+                         kv_indptr, qo_len, kv_len};
+    }
+  };
+
+  CUTLASS_DEVICE
+  BatchPrefillTileScheduler() {}
+
+  CUTLASS_DEVICE
+  WorkTileInfo get_initial_work(Params const& params) const {
+    int ptr_begin = params.work_indptr[blockIdx.x];
+    int ptr_end = params.work_indptr[blockIdx.x + 1];
+    if (ptr_begin < ptr_end) {
+      int work_idx = ptr_begin;
+      int qo_head_idx = blockIdx.y;
+      int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx);
+      return {params.qo_tile_indices[work_idx],
+              /*qo_head_idx=*/qo_head_idx,
+              /*kv_head_idx=*/kv_head_idx,
+              params.qo_indptr[work_idx],
+              params.kv_indptr[work_idx],
+              params.qo_lens[work_idx],
+              params.kv_lens[work_idx],
+              /*counter=*/0,
+              ptr_begin,
+              ptr_end};
+    } else {
+      return {-1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end};
+    }
+  }
+
+  CUTLASS_DEVICE
+  void init_consumer() const {}
+
+  CUTLASS_DEVICE
+  void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
+
+  CUTLASS_DEVICE
+  void broadcast_next_work(WorkTileInfo& current_work) const {}
+
+  template
+  CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params,
+                                            WorkTileInfo const& current_work) const {
+    int work_idx = current_work.ptr_begin + current_work.counter + 1;
+    if (work_idx < current_work.ptr_end) {
+      return {params.qo_tile_indices[work_idx], current_work.qo_head_idx,
+              current_work.kv_head_idx, params.qo_indptr[work_idx],
+              params.kv_indptr[work_idx], params.qo_lens[work_idx],
+              params.kv_lens[work_idx], current_work.counter + 1,
+              current_work.ptr_begin, current_work.ptr_end};
+    } else {
+      return {-1,
+              -1,
+              -1,
+              -1,
+              -1,
+              -1,
+              current_work.counter + 1,
+              current_work.ptr_begin,
+              current_work.ptr_end};
+    }
+  }
+};
+
 }  // namespace flashinfer
 
 #endif  // FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_
diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index f8023171..b83da70d 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -742,6 +742,7 @@ struct PrefillPlanSM90Info {
   int64_t kv_len_offset;
   int64_t head_indices_offset;
   int64_t work_indptr_offset;
+  bool same_schedule_for_all_heads;
 
   PrefillPlanSM90Info()
       : qo_tile_indices_offset(0),
         qo_len_offset(0),
         kv_len_offset(0),
         head_indices_offset(0),
-        work_indptr_offset(0) {}
+        work_indptr_offset(0),
+        same_schedule_for_all_heads(false) {}
 
   // convert PrefillPlanSM90Info to std::vector
   std::vector ToVector() const {
-    return {qo_tile_indices_offset, qo_indptr_offset, kv_indptr_offset, qo_len_offset,
-            kv_len_offset, head_indices_offset, work_indptr_offset};
+    return {qo_tile_indices_offset, qo_indptr_offset,
+            kv_indptr_offset, qo_len_offset,
+            kv_len_offset, head_indices_offset,
+            work_indptr_offset, same_schedule_for_all_heads};
   }
 
   // From std::vector to PrefillPlanSM90Info
   void FromVector(const std::vector& vec) {
-    if (vec.size() != 7) {
+    if (vec.size() != 8) {
       std::ostringstream err_msg;
       err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be 8, but got " << vec.size();
       FLASHINFER_ERROR(err_msg.str());
@@ -772,6 +776,7 @@ struct PrefillPlanSM90Info {
     kv_len_offset = vec[4];
     head_indices_offset = vec[5];
     work_indptr_offset = vec[6];
+    same_schedule_for_all_heads = vec[7];
   }
 };
 
@@ -780,9 +785,10 @@ cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_b
                             void* int_buffer, void* page_locked_int_buffer,
                             size_t int_workspace_size_in_bytes, PrefillPlanSM90Info& plan_info,
                             IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h,
-                            uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
-                            uint32_t head_dim, uint32_t page_size, bool causal,
-                            bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) {
+                            uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads,
+                            uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                            bool causal, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
+                            cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -825,7 +831,11 @@ cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_b
       cta_kv_len(num_sm90_ctas, std::vector()),
       cta_head_indices(num_sm90_ctas, std::vector());
 
-  for (int qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
+  int max_num_works_per_head = ceil_div(total_num_rows, cta_tile_q);
+  plan_info.same_schedule_for_all_heads = max_num_works_per_head > 4096;
+
+  for (int qo_head_idx = 0;
+       qo_head_idx < (plan_info.same_schedule_for_all_heads ? 1 : num_qo_heads); ++qo_head_idx) {
     for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) {
       int num_qo_tiles = ceil_div(qo_len, cta_tile_q);
       for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) {
@@ -853,7 +863,7 @@ cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_b
   for (uint32_t i = 0; i < num_sm90_ctas; ++i) {
     work_indptr_vec[i + 1] = work_indptr_vec[i] + cta_qo_tile_indices[i].size();
   }
-  IdType total_num_works = work_indptr_vec[num_sm90_ctas];
+  int total_num_works = work_indptr_vec.back();
   auto qo_tile_indices_vec = flatten(cta_qo_tile_indices, total_num_works);
   auto qo_indptr_vec = flatten(cta_qo_indptr, total_num_works);
   auto kv_indptr_vec = flatten(cta_kv_indptr, total_num_works);
@@ -862,13 +872,16 @@ cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_b
   auto head_indices_vec = flatten(cta_head_indices, total_num_works);
 
   AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
-  const int max_total_num_works = 1048576;
-  if (total_num_works > max_total_num_works) {
-    std::ostringstream err_msg;
-    err_msg << "total_num_works " << total_num_works << " should be less than "
-            << max_total_num_works;
-    FLASHINFER_ERROR(err_msg.str());
+  int max_total_num_works;
+
+  if (enable_cuda_graph) {
+    max_total_num_works = plan_info.same_schedule_for_all_heads
+                              ? max_num_works_per_head
+                              : max_num_works_per_head * num_qo_heads;
+  } else {
+    max_total_num_works = total_num_works;
   }
+
   plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset(
       sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_tile_indices");
   plan_info.qo_indptr_offset = int_allocator.aligned_alloc_offset(
diff --git a/tests/test_hopper.py b/tests/test_hopper.py
index 1fbad5ff..6a6e12a0 100644
--- a/tests/test_hopper.py
+++ b/tests/test_hopper.py
@@ -175,7 +175,8 @@ def test_batch_paged_prefill(
     kv_indptr = torch.arange(
         0, batch_size * num_pages_per_request + 1, num_pages_per_request
     ).int()
-    kv_indices = torch.arange(0, batch_size * num_pages_per_request).int()
+    # NOTE(Zihao): pad 256 elements to avoid out-of-bound because we didn't check the boundary in the kernel
+    kv_indices = torch.arange(0, batch_size * num_pages_per_request + 256).int()
     last_page_len = torch.full((batch_size,), last_page_len, dtype=torch.int32)
 
     wrapper_sm80.plan(