@@ -65,9 +65,6 @@ torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
6565
6666class BatchDecodeWithPagedKVCachePyTorchWrapper {
6767 public:
68- static BatchDecodeWithPagedKVCachePyTorchWrapper Create (unsigned int layout) {
69- return BatchDecodeWithPagedKVCachePyTorchWrapper (layout);
70- }
7168 void BeginForward (torch::Tensor workspace_buffer, torch::Tensor indptr,
7269 torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
7370 unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
@@ -78,19 +75,16 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
7875 torch::Tensor paged_kv_last_page_len,
7976 unsigned int pos_encoding_mode, float sm_scale,
8077 float rope_scale, float rope_theta, bool return_lse);
81-
82- private:
// Constructs the wrapper from a raw integer layout code: `layout` is converted
// to flashinfer::QKVLayout and stored in kv_layout_. handler_ is not named in
// the initializer list, so it is default-initialized.
// NOTE(review): which integer corresponds to which layout is not visible here;
// confirm against the flashinfer::QKVLayout definition.
// NOTE(review): single-argument and not `explicit`, so this also permits
// implicit conversion from unsigned int; confirm that is intended.
8378 BatchDecodeWithPagedKVCachePyTorchWrapper (unsigned int layout)
8479 : kv_layout_(flashinfer::QKVLayout(layout)) {}
80+
81+ private:
8582 flashinfer::BatchDecodeHandler handler_;
8683 flashinfer::QKVLayout kv_layout_;
8784};
8885
8986class BatchPrefillWithPagedKVCachePyTorchWrapper {
9087 public:
91- static BatchPrefillWithPagedKVCachePyTorchWrapper Create (unsigned int layout) {
92- return BatchPrefillWithPagedKVCachePyTorchWrapper (layout);
93- }
9488 void BeginForward (torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
9589 unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
9690 unsigned int head_dim);
@@ -102,19 +96,16 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
10296 unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
10397 float sm_scale, float rope_scale, float rope_theta,
10498 bool return_lse);
105-
106- private:
// Constructs the wrapper from a raw integer layout code: `layout` is converted
// to flashinfer::QKVLayout and stored in kv_layout_. handler_ is not named in
// the initializer list, so it is default-initialized.
// NOTE(review): which integer corresponds to which layout is not visible here;
// confirm against the flashinfer::QKVLayout definition.
// NOTE(review): single-argument and not `explicit`, so this also permits
// implicit conversion from unsigned int; confirm that is intended.
10799 BatchPrefillWithPagedKVCachePyTorchWrapper (unsigned int layout)
108100 : kv_layout_(flashinfer::QKVLayout(layout)) {}
101+
102+ private:
109103 flashinfer::BatchPrefillHandler handler_;
110104 flashinfer::QKVLayout kv_layout_;
111105};
112106
113107class BatchPrefillWithRaggedKVCachePyTorchWrapper {
114108 public:
115- static BatchPrefillWithRaggedKVCachePyTorchWrapper Create (unsigned int layout) {
116- return BatchPrefillWithRaggedKVCachePyTorchWrapper (layout);
117- }
118109 void BeginForward (torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
119110 unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
120111 unsigned int head_dim);
@@ -124,10 +115,10 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
124115 unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
125116 float sm_scale, float rope_scale, float rope_theta,
126117 bool return_lse);
127-
128- private:
// Constructs the wrapper from a raw integer layout code: `layout` is converted
// to flashinfer::QKVLayout and stored in kv_layout_. handler_ is not named in
// the initializer list, so it is default-initialized.
// NOTE(review): which integer corresponds to which layout is not visible here;
// confirm against the flashinfer::QKVLayout definition.
// NOTE(review): single-argument and not `explicit`, so this also permits
// implicit conversion from unsigned int; confirm that is intended.
129118 BatchPrefillWithRaggedKVCachePyTorchWrapper (unsigned int layout)
130119 : kv_layout_(flashinfer::QKVLayout(layout)) {}
120+
121+ private:
131122 flashinfer::BatchPrefillHandler handler_;
132123 flashinfer::QKVLayout kv_layout_;
133124};
0 commit comments