File tree: 2 files changed, +16 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -105,8 +105,8 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
105105 static_cast <int32_t *>(paged_kv_indices.data_ptr ()),
106106 static_cast <int32_t *>(paged_kv_indptr.data_ptr ()),
107107 static_cast <int32_t *>(paged_kv_last_page_len.data_ptr ()));
108- return DISPATCH_group_size (num_qo_heads / num_kv_heads, [&] {
109- return DISPATCH_head_dim (head_dim, [&] {
108+ bool success = DISPATCH_group_size (num_qo_heads / num_kv_heads, [&] {
109+ bool success = DISPATCH_head_dim (head_dim, [&] {
110110 DISPATCH_CAUSAL (causal, CAUSAL, {
111111 DISPATCH_ALLOW_FP16_QK_REDUCTION (allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
112112 DISPATCH_POS_ENCODING_MODE (PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, {
@@ -127,7 +127,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
127127 });
128128 return true ;
129129 });
130+ TORCH_CHECK (success, " BatchPrefillWithPagedKVCache failed to dispatch head_dim " , head_dim);
131+ return success;
130132 });
133+ TORCH_CHECK (success, " BatchPrefillWithPagedKVCache failed to dispatch group_size " ,
134+ num_qo_heads / num_kv_heads);
131135 });
132136 return true ;
133137 });
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,8 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
5555 }
5656
5757 bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE (q.scalar_type (), c_type, [&] {
58- return DISPATCH_group_size (num_qo_heads / num_kv_heads, [&] {
59- return DISPATCH_head_dim (head_dim, [&] {
58+ bool success = DISPATCH_group_size (num_qo_heads / num_kv_heads, [&] {
59+ bool success = DISPATCH_head_dim (head_dim, [&] {
6060 DISPATCH_CAUSAL (causal, CAUSAL, {
6161 DISPATCH_LAYOUT (kv_layout, KV_LAYOUT, {
6262 DISPATCH_ALLOW_FP16_QK_REDUCTION (allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
@@ -80,7 +80,15 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
8080 });
8181 return true ;
8282 });
83+ TORCH_CHECK (success,
84+ " SinglePrefillWithKVCache kernel launch failed, error: unknown head_dim " ,
85+ head_dim);
86+ return success;
8387 });
88+ TORCH_CHECK (success,
89+ " SinglePrefillWithKVCache kernel launch failed, error: unknown group_size " ,
90+ num_qo_heads / num_kv_heads);
91+ return success;
8492 });
8593
8694 TORCH_CHECK (success, " SinglePrefillWithKVCache kernel launch failed, error: unknown dtype" );
You can’t perform that action at this time.
0 commit comments