Commit 24cc583

feat: support cuda graph for batched multi-query (prefill/append) attention (#277)
#275 is not complete; this PR pushes the remaining changes.
1 parent: 081a4c5
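In practice, the feature means the batched attention wrappers can be driven by the standard CUDA graph capture/replay loop: the begin_forward planning step stays eager on the host, and only the forward kernels are captured. A minimal sketch of that flow, assuming a flashinfer-style workflow (plan_batch and run_batch_attention are hypothetical stand-ins; only the cudaStream*/cudaGraph* calls are real CUDA runtime API):

#include <cuda_runtime.h>

// Hypothetical stand-ins for the wrapper's begin_forward / forward calls.
void plan_batch() { /* read host-side indptr, fill fixed-capacity scheduling buffers */ }
void run_batch_attention(cudaStream_t stream) { /* enqueue attention kernels on stream */ }

void capture_and_replay(cudaStream_t stream, int num_steps) {
  // Planning stays outside capture: it touches host memory and may allocate,
  // neither of which is legal inside a capturing (non-default) stream.
  plan_batch();

  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  run_batch_attention(stream);  // only stream-ordered work is recorded
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);

  // Replay: write the next step's inputs into the same device buffers the
  // capture saw, then relaunch the whole graph at fixed cost.
  for (int step = 0; step < num_steps; ++step) {
    cudaGraphLaunch(graph_exec, stream);
  }
  cudaStreamSynchronize(stream);
  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}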

File tree

11 files changed (+587, -575 lines)


include/flashinfer/attention/decode.cuh

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@
 #include "../utils.cuh"
 #include "../vec_dtypes.cuh"
 #include "cascade.cuh"
-#include "handler.cuh"
 #include "state.cuh"

 namespace flashinfer {

include/flashinfer/attention/handler.cuh

Lines changed: 107 additions & 102 deletions
Large diffs are not rendered by default.
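GitHub does not render this diff, and it is the bulk of the commit, so what follows is only an illustrative sketch of the constraint a CUDA-graph-compatible planner must satisfy, not the actual file contents: a captured graph freezes launch grids and buffer addresses, so the plan has to emit the same padded tile count every time and mark unused slots, instead of resizing anything per batch.

#include <cstdint>

// Illustrative only; the real logic lives in handler.cuh and differs.
// kTileSize is an assumed CTA tile width, not flashinfer's actual choice.
constexpr uint32_t kTileSize = 64;

struct PrefillPlanSketch {
  int32_t* request_indices;   // pre-allocated, fixed-capacity buffer
  int32_t* tile_indices;      // pre-allocated, fixed-capacity buffer
  uint32_t padded_num_tiles;  // constant across every BeginForward call
};

// Assumes the caller has already verified the workload fits max_num_tiles.
void PlanForGraph(PrefillPlanSketch& plan, const int32_t* qo_indptr,
                  uint32_t batch_size, uint32_t max_num_tiles) {
  uint32_t tile = 0;
  for (uint32_t i = 0; i < batch_size; ++i) {
    uint32_t len = static_cast<uint32_t>(qo_indptr[i + 1] - qo_indptr[i]);
    uint32_t num_tiles = (len + kTileSize - 1) / kTileSize;
    if (num_tiles == 0) num_tiles = 1;  // every request gets at least one tile
    for (uint32_t t = 0; t < num_tiles; ++t) {
      plan.request_indices[tile] = static_cast<int32_t>(i);
      plan.tile_indices[tile] = static_cast<int32_t>(t);
      ++tile;
    }
  }
  // Pad to capacity: the captured grid always launches max_num_tiles blocks,
  // and blocks that read -1 exit immediately inside the kernel.
  for (; tile < max_num_tiles; ++tile) {
    plan.request_indices[tile] = -1;
    plan.tile_indices[tile] = -1;
  }
  plan.padded_num_tiles = max_num_tiles;
}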

include/flashinfer/attention/prefill.cuh

Lines changed: 0 additions & 2 deletions
@@ -35,9 +35,7 @@
 #include "../pos_enc.cuh"
 #include "../utils.cuh"
 #include "cascade.cuh"
-#include "handler.cuh"
 #include "mask.cuh"
-#include "state.cuh"

 namespace flashinfer {
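Together with the identical decode.cuh hunk above, the kernel headers no longer pull in handler.cuh; presumably this untangles the kernels from the scheduling handler, so the CUDA-graph planning logic in the handler.cuh rewrite above can evolve without touching every kernel translation unit.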

include/flashinfer/sampling.cuh

Lines changed: 0 additions & 1 deletion
@@ -679,7 +679,6 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
   auto& temp_storage = reinterpret_cast<
       SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);

-  bool rejected = false;
   uint32_t pos = 0;
   for (pos = 0; pos < num_speculative_tokens; ++pos) {
     IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
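This hunk is incidental cleanup riding along with the CUDA graph change: judging by its deletion with no other edits, `rejected` was assigned but never read.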

python/csrc/batch_decode.cu

Lines changed: 22 additions & 52 deletions
@@ -141,32 +141,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
   return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
     return DISPATCH_pos_encoding_mode(
         PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
-          if (handler_->IsCUDAGraphMode()) {
-            // NOTE(Zihao): use runtime dispatch because template function is not virtual
-            auto cuda_graph_handler_ =
-                dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
-            cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
-                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
-                c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
-                                          workspace_size_in_bytes,
-                                          static_cast<int32_t*>(indptr.data_ptr()),
-                                          static_cast<int32_t*>(last_page_len.data_ptr()),
-                                          batch_size, num_qo_heads, page_size);
-            TORCH_CHECK(status == cudaSuccess,
-                        "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
-                        cudaGetErrorString(status));
-          } else {
-            cudaError_t status = handler_->BeginForwardDispatched<
-                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
-                c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
-                                          workspace_size_in_bytes,
-                                          static_cast<int32_t*>(indptr.data_ptr()),
-                                          static_cast<int32_t*>(last_page_len.data_ptr()),
-                                          batch_size, num_qo_heads, page_size);
-            TORCH_CHECK(status == cudaSuccess,
-                        "BatchDecodeWithPagedKVCache failed with error ",
-                        cudaGetErrorString(status));
-          }
+          cudaError_t status =
+              handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
+                                               KV_LAYOUT, POS_ENCODING_MODE, c_type,
+                                               nv_half, int32_t>(
+                  static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+                  static_cast<int32_t*>(indptr.data_ptr()),
+                  static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
+                  page_size);
+          TORCH_CHECK(status == cudaSuccess,
+                      "BatchDecodeWithPagedKVCache failed with error ",
+                      cudaGetErrorString(status));
           return true;
         });
   });
@@ -180,32 +165,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
   return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
     return DISPATCH_pos_encoding_mode(
         PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
-          if (handler_->IsCUDAGraphMode()) {
-            // NOTE(Zihao): use runtime dispatch because template function is not virtual
-            auto cuda_graph_handler_ =
-                dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
-            auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
-                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
-                c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
-                                         workspace_size_in_bytes,
-                                         static_cast<int32_t*>(indptr.data_ptr()),
-                                         static_cast<int32_t*>(last_page_len.data_ptr()),
-                                         batch_size, num_qo_heads, page_size);
-            TORCH_CHECK(status == cudaSuccess,
-                        "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
-                        cudaGetErrorString(status));
-          } else {
-            cudaError_t status = handler_->BeginForwardDispatched<
-                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
-                c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
-                                         workspace_size_in_bytes,
-                                         static_cast<int32_t*>(indptr.data_ptr()),
-                                         static_cast<int32_t*>(last_page_len.data_ptr()),
-                                         batch_size, num_qo_heads, page_size);
-            TORCH_CHECK(status == cudaSuccess,
-                        "BatchDecodeWithPagedKVCache failed with error ",
-                        cudaGetErrorString(status));
-          }
+          cudaError_t status =
+              handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
+                                               KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
+                                               int32_t>(
+                  static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+                  static_cast<int32_t*>(indptr.data_ptr()),
+                  static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
+                  page_size);
+          TORCH_CHECK(status == cudaSuccess,
+                      "BatchDecodeWithPagedKVCache failed with error ",
+                      cudaGetErrorString(status));
           return true;
         });
   });
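Both hunks collapse the old two-path dispatch into one call: CUDA-graph support is now a property of the base BatchDecodeHandler (constructed with max_batch_size and enable_cuda_graph, per the flashinfer_ops.h changes below) instead of a dynamic_cast to a CUDAGraphBatchDecodeHandler subclass. A rough sketch of the shape this implies; BeginForwardDispatched and IsCUDAGraphEnabled are names from this commit, everything else is illustrative:

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>

class BatchDecodeHandlerSketch {
 public:
  BatchDecodeHandlerSketch(size_t max_workspace_size_in_bytes,
                           uint32_t max_batch_size, bool enable_cuda_graph)
      : max_batch_size_(max_batch_size), enable_cuda_graph_(enable_cuda_graph) {}

  bool IsCUDAGraphEnabled() const { return enable_cuda_graph_; }

  cudaError_t BeginForwardDispatched(void* workspace, size_t workspace_bytes,
                                     int32_t* indptr, int32_t* last_page_len,
                                     uint32_t batch_size, uint32_t num_qo_heads,
                                     uint32_t page_size) {
    if (enable_cuda_graph_ && batch_size > max_batch_size_) {
      // Graph-mode buffers were sized for max_batch_size_ at construction;
      // a larger batch would move pointers and invalidate captured graphs.
      return cudaErrorInvalidValue;
    }
    // ...plan split-KV scheduling into the workspace; in graph mode, write
    // into the same fixed-address buffers each call instead of reallocating.
    return cudaSuccess;
  }

 private:
  uint32_t max_batch_size_;
  bool enable_cuda_graph_;
};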

python/csrc/flashinfer_ops.cu

Lines changed: 7 additions & 11 deletions
@@ -44,34 +44,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                         "BatchDecodeWithPagedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int>())
+      .def(py::init<unsigned int, unsigned int, unsigned int, bool>())
       .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
       .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
+      .def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
       .def("update_page_locked_buffer_size",
            &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
-  py::class_<CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper>(
-      m, "CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int>())
-      .def("begin_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
-      .def("end_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
-      .def("update_page_locked_buffer_size",
-           &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
-      .def("forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
   py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
       m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int>())
+      .def(py::init<unsigned int, unsigned int, bool>())
      .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
      .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
+      .def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
       .def("update_page_locked_buffer_size",
            &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
       .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
   py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
       m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int>())
+      .def(py::init<unsigned int, unsigned int, bool>())
       .def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
       .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
+      .def("is_cuda_graph_enabled",
+           &BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
       .def("update_page_locked_buffer_size",
            &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)

python/csrc/flashinfer_ops.h

Lines changed: 15 additions & 21 deletions
@@ -20,11 +20,6 @@
 #include <flashinfer/layout.cuh>
 #include <memory>

-// namespace flashinfer {
-// class BatchPrefillHandler;
-// class BatchDecodeHandler;
-// } // namespace flashinfer
-
 torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
                                           torch::Tensor tmp, unsigned int pos_encoding_mode,
                                           unsigned int layout, float sm_scale, float rope_scale,
@@ -84,6 +79,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
                     unsigned int pos_encoding_mode, torch::Tensor empty_data);
   void EndForward();
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
+  bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor paged_kv_data,
                                      torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
                                      torch::Tensor paged_kv_last_page_len,
@@ -93,31 +89,24 @@
       std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
       : handler_(handler_ptr), kv_layout_(kv_layout) {}
   BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
-                                            unsigned int max_workspace_size_in_bytes)
+                                            unsigned int max_workspace_size_in_bytes,
+                                            unsigned int max_batch_size, bool enable_cuda_graph)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
+        handler_(std::make_shared<flashinfer::BatchDecodeHandler>(
+            max_workspace_size_in_bytes, max_batch_size, enable_cuda_graph)) {}

  protected:
   std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
   flashinfer::QKVLayout kv_layout_;
 };

-class CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper
-    : public BatchDecodeWithPagedKVCachePyTorchWrapper {
- public:
-  CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
-                                                     unsigned int max_batch_size)
-      : BatchDecodeWithPagedKVCachePyTorchWrapper(
-            std::make_shared<flashinfer::CUDAGraphBatchDecodeHandler>(max_batch_size),
-            flashinfer::QKVLayout(layout)) {}
-};
-
 class BatchPrefillWithPagedKVCachePyTorchWrapper {
  public:
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
                     unsigned int head_dim);
   void EndForward();
+  bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -133,9 +122,11 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
                     unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale,
                     float rope_scale, float rope_theta, bool return_lse);
   BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout,
-                                             unsigned int max_workspace_size_in_bytes)
+                                             unsigned int max_workspace_size_in_bytes,
+                                             bool enable_cuda_graph)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
+        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
+                                                                   enable_cuda_graph)) {}

  private:
   std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
@@ -148,6 +139,7 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
                     unsigned int head_dim);
   void EndForward();
+  bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
                                      torch::Tensor v, torch::Tensor kv_indptr, bool causal,
@@ -162,9 +154,11 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
                                      bool allow_fp16_qk_reduction, float sm_scale,
                                      float rope_scale, float rope_theta, bool return_lse);
   BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout,
-                                              unsigned int max_workspace_size_in_bytes)
+                                              unsigned int max_workspace_size_in_bytes,
+                                              bool enable_cuda_graph)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
+        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
+                                                                   enable_cuda_graph)) {}

  private:
   std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
