@@ -141,32 +141,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
141141 return DISPATCH_kv_layout (kv_layout_, KV_LAYOUT, [&] {
142142 return DISPATCH_pos_encoding_mode (
143143 PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
144- if (handler_->IsCUDAGraphMode ()) {
145- // NOTE(Zihao): use runtime dispatch because template function is not virtual
146- auto cuda_graph_handler_ =
147- dynamic_cast <CUDAGraphBatchDecodeHandler*>(handler_.get ());
148- cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched <
149- GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
150- c_type, nv_half, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
151- workspace_size_in_bytes,
152- static_cast <int32_t *>(indptr.data_ptr ()),
153- static_cast <int32_t *>(last_page_len.data_ptr ()),
154- batch_size, num_qo_heads, page_size);
155- TORCH_CHECK (status == cudaSuccess,
156- " BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error " ,
157- cudaGetErrorString (status));
158- } else {
159- cudaError_t status = handler_->BeginForwardDispatched <
160- GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
161- c_type, nv_half, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
162- workspace_size_in_bytes,
163- static_cast <int32_t *>(indptr.data_ptr ()),
164- static_cast <int32_t *>(last_page_len.data_ptr ()),
165- batch_size, num_qo_heads, page_size);
166- TORCH_CHECK (status == cudaSuccess,
167- " BatchDecodeWithPagedKVCache failed with error " ,
168- cudaGetErrorString (status));
169- }
144+ cudaError_t status =
145+ handler_->BeginForwardDispatched <GROUP_SIZE, HEAD_DIM, PageStorage::kIndices ,
146+ KV_LAYOUT, POS_ENCODING_MODE, c_type,
147+ nv_half, int32_t >(
148+ static_cast <void *>(workspace_buffer.data_ptr ()), workspace_size_in_bytes,
149+ static_cast <int32_t *>(indptr.data_ptr ()),
150+ static_cast <int32_t *>(last_page_len.data_ptr ()), batch_size, num_qo_heads,
151+ page_size);
152+ TORCH_CHECK (status == cudaSuccess,
153+ " BatchDecodeWithPagedKVCache failed with error " ,
154+ cudaGetErrorString (status));
170155 return true ;
171156 });
172157 });
@@ -180,32 +165,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
180165 return DISPATCH_kv_layout (kv_layout_, KV_LAYOUT, [&] {
181166 return DISPATCH_pos_encoding_mode (
182167 PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
183- if (handler_->IsCUDAGraphMode ()) {
184- // NOTE(Zihao): use runtime dispatch because template function is not virtual
185- auto cuda_graph_handler_ =
186- dynamic_cast <CUDAGraphBatchDecodeHandler*>(handler_.get ());
187- auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched <
188- GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
189- c_type, c_type, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
190- workspace_size_in_bytes,
191- static_cast <int32_t *>(indptr.data_ptr ()),
192- static_cast <int32_t *>(last_page_len.data_ptr ()),
193- batch_size, num_qo_heads, page_size);
194- TORCH_CHECK (status == cudaSuccess,
195- " BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error " ,
196- cudaGetErrorString (status));
197- } else {
198- cudaError_t status = handler_->BeginForwardDispatched <
199- GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
200- c_type, c_type, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
201- workspace_size_in_bytes,
202- static_cast <int32_t *>(indptr.data_ptr ()),
203- static_cast <int32_t *>(last_page_len.data_ptr ()),
204- batch_size, num_qo_heads, page_size);
205- TORCH_CHECK (status == cudaSuccess,
206- " BatchDecodeWithPagedKVCache failed with error " ,
207- cudaGetErrorString (status));
208- }
168+ cudaError_t status =
169+ handler_->BeginForwardDispatched <GROUP_SIZE, HEAD_DIM, PageStorage::kIndices ,
170+ KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
171+ int32_t >(
172+ static_cast <void *>(workspace_buffer.data_ptr ()), workspace_size_in_bytes,
173+ static_cast <int32_t *>(indptr.data_ptr ()),
174+ static_cast <int32_t *>(last_page_len.data_ptr ()), batch_size, num_qo_heads,
175+ page_size);
176+ TORCH_CHECK (status == cudaSuccess,
177+ " BatchDecodeWithPagedKVCache failed with error " ,
178+ cudaGetErrorString (status));
209179 return true ;
210180 });
211181 });
0 commit comments