Skip to content

Commit 1116237

Browse files
authored
perf: Optimize tensor conversions in C++ code to avoid unnecessary copies (#366)
Small tweak to avoid unnecessary copying by combining `to` calls. Discovered during profiling.
1 parent 264082e commit 1116237

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

python/csrc/batch_decode.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
154154
torch::Tensor o = torch::empty_like(q);
155155
torch::Tensor lse;
156156
if (return_lse) {
157-
lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32);
157+
lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
158158
}
159159

160160
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

python/csrc/batch_prefill.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
3131
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
3232
CHECK_DIM(1, qo_indptr);
3333
CHECK_DIM(1, workspace_buffer);
34-
qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
35-
paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
34+
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
35+
paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
3636
auto device = workspace_buffer.device();
3737
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
3838
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
@@ -111,7 +111,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
111111
torch::Tensor o = torch::empty_like(q, q.options());
112112
torch::Tensor lse = torch::empty({0});
113113
if (return_lse) {
114-
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
114+
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
115115
}
116116
MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
117117
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
@@ -226,7 +226,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
226226
torch::Tensor o = torch::empty_like(q, q.options());
227227
torch::Tensor lse = torch::empty({0});
228228
if (return_lse) {
229-
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
229+
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
230230
}
231231
constexpr MaskMode MASK_MODE = MaskMode::kCustom;
232232
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
@@ -288,8 +288,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
288288
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
289289
CHECK_DIM(1, qo_indptr);
290290
CHECK_DIM(1, workspace_buffer);
291-
qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
292-
kv_indptr = kv_indptr.to(torch::kCPU).to(torch::kInt32);
291+
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
292+
kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
293293
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
294294
auto device = workspace_buffer.device();
295295
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
@@ -354,7 +354,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
354354
torch::Tensor o = torch::empty_like(q, q.options());
355355
torch::Tensor lse = torch::empty({0});
356356
if (return_lse) {
357-
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
357+
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
358358
}
359359

360360
MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
@@ -452,7 +452,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
452452
torch::Tensor o = torch::empty_like(q, q.options());
453453
torch::Tensor lse = torch::empty({0});
454454
if (return_lse) {
455-
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
455+
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype((torch::kFloat32)));
456456
}
457457

458458
constexpr MaskMode MASK_MODE = MaskMode::kCustom;

0 commit comments

Comments (0)