  * limitations under the License.
  */

-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <algorithm>
@@ -808,16 +808,17 @@ void paged_attention_v1(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,   // [num_blocks, num_heads, head_size, block_size]
-    int num_kv_heads,  // [num_heads]
-    float scale,
+        value_cache,       // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads,  // [num_heads]
+    double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int block_size, int max_seq_len,
+    int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);

   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
@@ -972,16 +973,17 @@ void paged_attention_v2(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,   // [num_blocks, num_heads, head_size, block_size]
-    int num_kv_heads,  // [num_heads]
-    float scale,
+        value_cache,       // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads,  // [num_heads]
+    double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int block_size, int max_seq_len,
+    int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
@@ -990,4 +992,4 @@ void paged_attention_v2(
 #undef WARP_SIZE
 #undef MAX
 #undef MIN
-#undef DIVIDE_ROUND_UP
+#undef DIVIDE_ROUND_UP
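
Context for the signature changes above (an aside, not part of the diff): PyTorch's TORCH_LIBRARY registration derives an operator schema from the C++ signature, and that schema models Python scalars only as int64_t and double, so plain int and float parameters are rejected. Widening the scalar arguments and including torch/all.h instead of torch/extension.h (which drags in the pybind11-based Python headers) fits that pattern. The sketch below is a minimal illustration under those assumptions; the op name, namespace, and function are hypothetical and do not come from this diff.

    // Minimal sketch, assuming TORCH_LIBRARY schema inference. The names
    // demo_ops and scale_row are hypothetical. Declaring the scalar
    // parameters as int/float instead of int64_t/double would fail during
    // schema inference at registration time.
    #include <torch/all.h>
    #include <torch/library.h>

    // Scales one row of x in place and returns the tensor.
    torch::Tensor scale_row(torch::Tensor x, int64_t row, double scale) {
      x[row] *= scale;
      return x;
    }

    TORCH_LIBRARY(demo_ops, m) {
      m.def("scale_row", &scale_row);  // schema inferred from the C++ types
    }

From Python such an op would be reachable as torch.ops.demo_ops.scale_row(x, 0, 2.0), with int64_t and double mapping directly onto Python's int and float.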