
Make flashinfer kernels CUDA-graph friendly #187


@AgrawalAmey

Thanks for creating these awesome kernels! I am trying to get flashinfer kernels to work with CUDA graphs, but it appears that several parallelism decisions (block size, num_q_tiles, etc.) are made on the fly in the forward function based on the input data. This makes it difficult to capture flashinfer kernels in CUDA graphs in a generic manner. I think one solution would be to introduce a launcher kernel that factors in the input metadata and launches the actual CUDA kernel using dynamic parallelism (a rough sketch follows the list below). Toward that, these are the items I have identified:

1. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- handle return lse?
2. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- paged_kv_t batch_size should not live on the CPU side
3. BatchPrefillWithPagedKVCacheWrapperDispatched -- make this a CUDA device function or get rid of it
4. BatchPrefillWithPagedKVCacheWrapperDispatched -- num_frags_x, num_qo_tiles, and batch size need to be available on the device side
5. BatchPrefillWithPagedKVCacheWrapperDispatched -- do not access handler state directly in the function
6. BatchPrefillWithPagedKVCacheDispatched -- make this a CUDA device function
7. BatchPrefillWithPagedKVCacheDispatched -- put num_qo_tiles in device-accessible memory
8. BatchPrefillWithPagedKVCacheDispatched -- make validations GPU-friendly
9. Batch size should be an explicit input parameter rather than inferred from the length of indptr, so that inputs can be padded
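
To make the launcher-kernel idea concrete, here is a minimal sketch (not flashinfer code; `launcher_kernel`, `attention_kernel`, `PrefillParams`, and `LauncherMetadata` are hypothetical names) of how a launcher kernel could read the metadata from device memory and launch the real kernel via dynamic parallelism, so that the host-side launch captured into the CUDA graph has fixed parameters:

```cuda
// Hypothetical sketch only -- names and layout are illustrative, not the
// flashinfer API. Requires relocatable device code (nvcc -rdc=true) and
// linking cudadevrt for dynamic parallelism.
#include <cuda_runtime.h>

struct LauncherMetadata {
  int num_qo_tiles;  // written by a preparation kernel, lives in device memory
  int batch_size;    // explicit, so padded inputs are fine
};

struct PrefillParams {
  // q/k/v pointers, page table, indptr, etc. (omitted)
};

__global__ void attention_kernel(PrefillParams params) {
  // the real work; grid dimensions were decided on the device
}

__global__ void launcher_kernel(const LauncherMetadata* meta,
                                PrefillParams params) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    // Derive the grid from device-resident metadata and launch the actual
    // kernel with dynamic parallelism. The host only ever launches
    // launcher_kernel<<<1, 1>>>(...), which has fixed launch parameters and
    // can therefore be captured into a CUDA graph and replayed while the
    // metadata contents change between replays.
    dim3 grid(meta->num_qo_tiles, meta->batch_size);
    dim3 block(128);
    attention_kernel<<<grid, block>>>(params);
  }
}
```

The trade-off would be an extra device-side launch per call and the -rdc=true / cudadevrt requirement, but it keeps all shape-dependent decisions on the GPU, which is what CUDA graph capture needs.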

@yzh119 please let me know what the best way to proceed would be.
