Skip to content

Commit 7aadc0d

Browse files
authored
Add dtype checks for q-kv tensors (#280)
Right now, we require the q and kv tensors to have the same dtype, but this is not enforced, which can lead to cryptic memory errors in case of a misconfiguration. This PR adds dtype checks so that mismatched q/kv dtypes are rejected with a clear error.
1 parent 1092e7e commit 7aadc0d

File tree

4 files changed

+13
-0
lines changed

4 files changed

+13
-0
lines changed

python/csrc/batch_decode.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
3333
CHECK_SHAPE(k_padded, v_padded);
3434
CHECK_EQ(q.size(0), k_padded.size(0));
3535
CHECK_EQ(q.size(2), k_padded.size(3));
36+
CHECK_EQ(q.scalar_type(), k_padded.scalar_type());
37+
CHECK_EQ(q.scalar_type(), v_padded.scalar_type());
3638
unsigned int batch_size = q.size(0);
3739
unsigned int num_qo_heads = q.size(1);
3840
unsigned int head_dim = q.size(2);
@@ -206,6 +208,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
206208
CHECK_DIM(1, paged_kv_last_page_len); // (B,)
207209
CHECK_DIM(1, paged_kv_indptr); // (B+1,)
208210
CHECK_DIM(1, paged_kv_indices); // (nnz,)
211+
CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type());
209212
// (num_max_pages, 2, H_kv, page_size, head_dim) for HND
210213
// (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
211214
CHECK_DIM(5, paged_kv_data);

python/csrc/batch_prefill.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
7070
CHECK_DIM(1, paged_kv_indptr); // (B + 1,)
7171
CHECK_DIM(1, paged_kv_indices); // (nnz_kv,)
7272
CHECK_DIM(1, paged_kv_last_page_len); // (B,)
73+
CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type());
7374
int64_t batch_size = qo_indptr.size(0) - 1;
7475
int64_t nnz_qo = q.size(0);
7576
int64_t num_qo_heads = q.size(1);
@@ -173,6 +174,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
173174
CHECK_DIM(1, paged_kv_last_page_len); // (B,)
174175
CHECK_DIM(1, custom_mask); // (nnz_qk,)
175176
CHECK_DIM(1, qk_indptr); // (B + 1,)
177+
CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type());
176178
int64_t batch_size = qo_indptr.size(0) - 1;
177179
int64_t nnz_qo = q.size(0);
178180
int64_t num_qo_heads = q.size(1);
@@ -299,6 +301,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
299301
CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)
300302
CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)
301303
CHECK_DIM(1, kv_indptr); // (B + 1,)
304+
CHECK_EQ(q.scalar_type(), k.scalar_type());
305+
CHECK_EQ(q.scalar_type(), v.scalar_type());
302306
int64_t batch_size = qo_indptr.size(0) - 1;
303307
int64_t nnz_qo = q.size(0);
304308
int64_t num_qo_heads = q.size(1);
@@ -382,6 +386,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
382386
CHECK_DIM(1, kv_indptr); // (B + 1,)
383387
CHECK_DIM(1, custom_mask); // (nnz_qk,)
384388
CHECK_DIM(1, qk_indptr); // (B + 1,)
389+
CHECK_EQ(q.scalar_type(), k.scalar_type());
390+
CHECK_EQ(q.scalar_type(), v.scalar_type());
385391
int64_t batch_size = qo_indptr.size(0) - 1;
386392
int64_t nnz_qo = q.size(0);
387393
int64_t num_qo_heads = q.size(1);

python/csrc/single_decode.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
3232
CHECK_DIM(3, v);
3333
CHECK_SHAPE(k, v);
3434
CHECK_EQ(q.size(1), k.size(2));
35+
CHECK_EQ(q.scalar_type(), k.scalar_type());
36+
CHECK_EQ(q.scalar_type(), v.scalar_type());
3537
unsigned int num_qo_heads = q.size(0);
3638
unsigned int head_dim = q.size(1);
3739
unsigned int kv_len, num_kv_heads;

python/csrc/single_prefill.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
3232
CHECK_DIM(3, v);
3333
CHECK_SHAPE(k, v);
3434
CHECK_EQ(q.size(2), k.size(2));
35+
CHECK_EQ(q.scalar_type(), k.scalar_type());
36+
CHECK_EQ(q.scalar_type(), v.scalar_type());
3537
unsigned int head_dim = q.size(2);
3638
unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
3739
QKVLayout kv_layout = static_cast<QKVLayout>(layout);

0 commit comments

Comments (0)