@@ -31,8 +31,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
3131 CHECK_GQA_HEAD_DIVISIBLE (num_qo_heads, num_kv_heads);
3232 CHECK_DIM (1 , qo_indptr);
3333 CHECK_DIM (1 , workspace_buffer);
34- qo_indptr = qo_indptr.to (torch::kCPU ). to (torch::kInt32 );
35- paged_kv_indptr = paged_kv_indptr.to (torch::kCPU ). to (torch::kInt32 );
34+ qo_indptr = qo_indptr.to (torch::dtype (torch:: kInt32 ). device (torch::kCPU ) );
35+ paged_kv_indptr = paged_kv_indptr.to (torch::dtype (torch:: kInt32 ). device (torch::kCPU ) );
3636 auto device = workspace_buffer.device ();
3737 size_t workspace_size_in_bytes = workspace_buffer.size (0 ) * workspace_buffer.element_size ();
3838 cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
@@ -111,7 +111,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
111111 torch::Tensor o = torch::empty_like (q, q.options ());
112112 torch::Tensor lse = torch::empty ({0 });
113113 if (return_lse) {
114- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ()). to (torch::kFloat32 );
114+ lse = torch::empty ({nnz_qo, num_qo_heads}, q.options (). dtype (torch::kFloat32 ) );
115115 }
116116 MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone ;
117117 TORCH_CHECK (logits_soft_cap >= 0 .f , " logits_soft_cap must be non-negative" );
@@ -226,7 +226,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
226226 torch::Tensor o = torch::empty_like (q, q.options ());
227227 torch::Tensor lse = torch::empty ({0 });
228228 if (return_lse) {
229- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ()). to (torch::kFloat32 );
229+ lse = torch::empty ({nnz_qo, num_qo_heads}, q.options (). dtype (torch::kFloat32 ) );
230230 }
231231 constexpr MaskMode MASK_MODE = MaskMode::kCustom ;
232232 TORCH_CHECK (logits_soft_cap >= 0 .f , " logits_soft_cap must be non-negative" );
@@ -288,8 +288,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
288288 CHECK_GQA_HEAD_DIVISIBLE (num_qo_heads, num_kv_heads);
289289 CHECK_DIM (1 , qo_indptr);
290290 CHECK_DIM (1 , workspace_buffer);
291- qo_indptr = qo_indptr.to (torch::kCPU ). to (torch::kInt32 );
292- kv_indptr = kv_indptr.to (torch::kCPU ). to (torch::kInt32 );
291+ qo_indptr = qo_indptr.to (torch::dtype (torch:: kInt32 ). device (torch::kCPU ) );
292+ kv_indptr = kv_indptr.to (torch::dtype (torch:: kInt32 ). device (torch::kCPU ) );
293293 size_t workspace_size_in_bytes = workspace_buffer.size (0 ) * workspace_buffer.element_size ();
294294 auto device = workspace_buffer.device ();
295295 cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
@@ -354,7 +354,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
354354 torch::Tensor o = torch::empty_like (q, q.options ());
355355 torch::Tensor lse = torch::empty ({0 });
356356 if (return_lse) {
357- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ()). to (torch::kFloat32 );
357+ lse = torch::empty ({nnz_qo, num_qo_heads}, q.options (). dtype (torch::kFloat32 ) );
358358 }
359359
360360 MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone ;
@@ -452,7 +452,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
452452 torch::Tensor o = torch::empty_like (q, q.options ());
453453 torch::Tensor lse = torch::empty ({0 });
454454 if (return_lse) {
455- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ()). to ( torch::kFloat32 );
455+ lse = torch::empty ({nnz_qo, num_qo_heads}, q.options (). dtype (( torch::kFloat32 )) );
456456 }
457457
458458 constexpr MaskMode MASK_MODE = MaskMode::kCustom ;
0 commit comments