Closed
Description
Thanks for creating these awesome kernels! I am trying to get the flashinfer kernels to work with CUDA graphs, but several parallelism decisions (block size, num_q_tiles, etc.) appear to be made on the fly in the forward function, based on the input data. This makes it difficult to capture flashinfer kernels in CUDA graphs in a generic manner. I think one solution would be to introduce a launcher kernel that factors in the input metadata and launches the actual CUDA kernel using dynamic parallelism. Toward that, these are the items I have identified:
1. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- handle return lse?
2. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- paged_kv_t batch_size should not be on cpu side
3. BatchPrefillWithPagedKVCacheWrapperDispatched -- make cuda device function or get rid of it
4. BatchPrefillWithPagedKVCacheWrapperDispatched -- num_frags_x, num_qo_tiles, batch size need to be
5. BatchPrefillWithPagedKVCacheWrapperDispatched -- do not access handler state directly in the function
6. BatchPrefillWithPagedKVCacheDispatched -- make cuda device function
7. BatchPrefillWithPagedKVCacheDispatched -- put num_qo_tiles on device accessible memory
8. BatchPrefillWithPagedKVCacheDispatched -- Make validations gpu friendly
9. Batch size should be an explicit input parameter rather than being derived from the length of indptr, so that inputs can be padded.
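To make item 9 concrete, here is a rough host-side sketch of the padding idea. It is not flashinfer code: `pad_indptr` and `count_qo_tiles` are hypothetical helpers, and the tile-count formula (sum of the ceiling of each request length divided by a fixed number of rows per tile) is an assumption used only for illustration. The point is that once the batch size is passed explicitly, the indptr buffer can keep a fixed, padded size across graph replays.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a flashinfer API): pad a CSR-style indptr array to
// max_batch_size + 1 entries by repeating the final offset, so that every
// padded request has length zero.
std::vector<int32_t> pad_indptr(const std::vector<int32_t>& indptr,
                                std::size_t max_batch_size) {
  assert(!indptr.empty() && indptr.size() - 1 <= max_batch_size);
  std::vector<int32_t> padded(indptr);
  padded.resize(max_batch_size + 1, indptr.back());
  return padded;
}

// Hypothetical helper: count query tiles from an explicit batch_size instead
// of from indptr.size() - 1. The formula is an assumption for illustration:
// each request contributes ceil(len / rows_per_tile) tiles.
int64_t count_qo_tiles(const std::vector<int32_t>& indptr,
                       std::size_t batch_size, int32_t rows_per_tile) {
  assert(rows_per_tile > 0 && batch_size + 1 <= indptr.size());
  int64_t tiles = 0;
  for (std::size_t i = 0; i < batch_size; ++i) {
    const int32_t len = indptr[i + 1] - indptr[i];
    tiles += (len + rows_per_tile - 1) / rows_per_tile;  // ceiling division
  }
  return tiles;
}
```

For example, padding `{0, 3, 10}` to a maximum batch size of 4 gives `{0, 3, 10, 10, 10}`. Calling `count_qo_tiles` on the padded array with the real batch size of 2 gives the same tile count as calling it with the padded batch size of 4, because the padded requests are empty.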
@yzh119 please let me know what would be the best way to proceed?