Skip to content

Commit ed20304

Browse files
authored
bugfix: fix pybind class bindings (#255)
Previously we bound a factory method as the init function for each C++ class in pybind11; the factory returns the wrapper by value rather than by reference. As a result, the destructor of the handlers introduced in #253 was triggered twice, leading to segmentation faults. This PR bypasses the factory methods and binds the C++ constructors directly.
1 parent 426266c commit ed20304

File tree

2 files changed

+9
-18
lines changed

2 files changed

+9
-18
lines changed

python/csrc/flashinfer_ops.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3737
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
3838
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
3939
"BatchDecodeWithPagedKVCachePyTorchWrapper")
40-
.def(py::init(&BatchDecodeWithPagedKVCachePyTorchWrapper::Create))
40+
.def(py::init<unsigned int>())
4141
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
4242
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
4343
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
4444
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
4545
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
46-
.def(py::init(&BatchPrefillWithPagedKVCachePyTorchWrapper::Create))
46+
.def(py::init<unsigned int>())
4747
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
4848
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
4949
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward);
5050
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
5151
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
52-
.def(py::init(&BatchPrefillWithRaggedKVCachePyTorchWrapper::Create))
52+
.def(py::init<unsigned int>())
5353
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
5454
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
5555
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward);

python/csrc/flashinfer_ops.h

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
6565

6666
class BatchDecodeWithPagedKVCachePyTorchWrapper {
6767
public:
68-
static BatchDecodeWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
69-
return BatchDecodeWithPagedKVCachePyTorchWrapper(layout);
70-
}
7168
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
7269
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
7370
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
@@ -78,19 +75,16 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
7875
torch::Tensor paged_kv_last_page_len,
7976
unsigned int pos_encoding_mode, float sm_scale,
8077
float rope_scale, float rope_theta, bool return_lse);
81-
82-
private:
8378
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout)
8479
: kv_layout_(flashinfer::QKVLayout(layout)) {}
80+
81+
private:
8582
flashinfer::BatchDecodeHandler handler_;
8683
flashinfer::QKVLayout kv_layout_;
8784
};
8885

8986
class BatchPrefillWithPagedKVCachePyTorchWrapper {
9087
public:
91-
static BatchPrefillWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
92-
return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
93-
}
9488
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
9589
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
9690
unsigned int head_dim);
@@ -102,19 +96,16 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
10296
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
10397
float sm_scale, float rope_scale, float rope_theta,
10498
bool return_lse);
105-
106-
private:
10799
BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout)
108100
: kv_layout_(flashinfer::QKVLayout(layout)) {}
101+
102+
private:
109103
flashinfer::BatchPrefillHandler handler_;
110104
flashinfer::QKVLayout kv_layout_;
111105
};
112106

113107
class BatchPrefillWithRaggedKVCachePyTorchWrapper {
114108
public:
115-
static BatchPrefillWithRaggedKVCachePyTorchWrapper Create(unsigned int layout) {
116-
return BatchPrefillWithRaggedKVCachePyTorchWrapper(layout);
117-
}
118109
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
119110
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
120111
unsigned int head_dim);
@@ -124,10 +115,10 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
124115
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
125116
float sm_scale, float rope_scale, float rope_theta,
126117
bool return_lse);
127-
128-
private:
129118
BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout)
130119
: kv_layout_(flashinfer::QKVLayout(layout)) {}
120+
121+
private:
131122
flashinfer::BatchPrefillHandler handler_;
132123
flashinfer::QKVLayout kv_layout_;
133124
};

0 commit comments

Comments
 (0)