@@ -71,14 +71,14 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
7171 return plan_info.ToVector ();
7272}
7373
74- std::vector< torch::Tensor> BatchPrefillWithRaggedKVCacheRun (
74+ torch::Tensor BatchPrefillWithRaggedKVCacheRun (
7575 unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
7676 torch::Tensor int_workspace_buffer, std::vector<int64_t > plan_info_vec, torch::Tensor q,
7777 torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
7878 std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
7979 torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
8080 int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
81- bool return_lse ) {
81+ std::optional<torch::Tensor> maybe_lse ) {
8282 PrefillPlanInfo plan_info;
8383 plan_info.FromVector (plan_info_vec);
8484 QKVLayout kv_layout = static_cast <QKVLayout>(layout);
@@ -98,10 +98,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
9898 auto device = float_workspace_buffer.device ();
9999 cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
100100 auto o = torch::empty_like (q, q.options ());
101- int64_t nnz_qo = q.size (0 );
102- torch::Tensor lse = torch::empty ({0 });
103- if (return_lse) {
104- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ().dtype (torch::kFloat32 ));
101+ if (maybe_lse) {
102+ const auto & lse = *maybe_lse;
103+ TORCH_CHECK (lse.size (0 ) == q.size (0 ), lse.size (0 ), q.size (0 ));
104+ TORCH_CHECK (lse.size (1 ) == q.size (1 ), lse.size (1 ), q.size (1 ));
105+ TORCH_CHECK (lse.dtype () == torch::kFloat32 , " lse must be float32" );
105106 }
106107
107108 void * float_buffer_ptr = float_workspace_buffer.data_ptr ();
@@ -140,7 +141,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
140141 : nullptr ,
141142 /* q_offset=*/ nullptr ,
142143 /* k_rope_pos_offset=*/ nullptr , static_cast <DTypeO*>(o.data_ptr ()),
143- /* lse=*/ return_lse ? static_cast <float *>(lse. data_ptr ()) : nullptr ,
144+ /* lse=*/ (maybe_lse ? static_cast <float *>(maybe_lse-> data_ptr ()) : nullptr ) ,
144145 /* alibi_slopes=*/ nullptr , num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
145146 kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
146147 rope_theta);
@@ -187,22 +188,18 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
187188 });
188189 });
189190
190- if (return_lse) {
191- return {o, lse};
192- } else {
193- return {o};
194- }
191+ return o;
195192}
196193
197- std::vector< torch::Tensor> BatchPrefillWithPagedKVCacheRun (
194+ torch::Tensor BatchPrefillWithPagedKVCacheRun (
198195 unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
199196 torch::Tensor int_workspace_buffer, std::vector<int64_t > plan_info_vec, torch::Tensor q,
200197 torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
201198 std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
202199 torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
203200 torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
204201 unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
205- float rope_scale, float rope_theta, bool return_lse ) {
202+ float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse ) {
206203 PrefillPlanInfo plan_info;
207204 plan_info.FromVector (plan_info_vec);
208205 QKVLayout kv_layout = static_cast <QKVLayout>(layout);
@@ -221,10 +218,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
221218
222219 cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
223220 auto o = torch::empty_like (q, q.options ());
224- int64_t nnz_qo = q.size (0 );
225- torch::Tensor lse = torch::empty ({0 });
226- if (return_lse) {
227- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ().dtype (torch::kFloat32 ));
221+ if (maybe_lse) {
222+ const auto & lse = *maybe_lse;
223+ TORCH_CHECK (lse.size (0 ) == q.size (0 ), lse.size (0 ), q.size (0 ));
224+ TORCH_CHECK (lse.size (1 ) == q.size (1 ), lse.size (1 ), q.size (1 ));
225+ TORCH_CHECK (lse.dtype () == torch::kFloat32 , " lse must be float32" );
228226 }
229227
230228 void * float_buffer_ptr = static_cast <void *>(float_workspace_buffer.data_ptr ());
@@ -277,7 +275,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
277275 maybe_qk_indptr.has_value () ? static_cast <IdType*>(maybe_qk_indptr->data_ptr ())
278276 : nullptr ,
279277 /* q_offset=*/ nullptr , static_cast <DTypeO*>(o.data_ptr ()),
280- /* lse=*/ return_lse ? static_cast <float *>(lse. data_ptr ()) : nullptr ,
278+ /* lse=*/ (maybe_lse ? static_cast <float *>(maybe_lse-> data_ptr ()) : nullptr ) ,
281279 /* alibi_slopes=*/ nullptr , num_qo_heads, q_stride_n, q_stride_h, window_left,
282280 logits_soft_cap, sm_scale, rope_scale, rope_theta);
283281
@@ -323,9 +321,5 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
323321 });
324322 });
325323
326- if (return_lse) {
327- return {o, lse};
328- } else {
329- return {o};
330- }
324+ return o;
331325}
0 commit comments