@@ -17,7 +17,7 @@
  * limitations under the License.
  */
 
-#include <torch/all.h>
+#include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <algorithm>
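Note on the two CUDA headers kept above: ATen/cuda/CUDAContext.h supplies the current-stream accessor and c10/cuda/CUDAGuard.h the RAII device guard that kernel entry points in files like this typically use before launching. A minimal sketch of that pattern; the function name and the assumption that a tensor named query is available are illustrative:

// Sketch only: shows the usual guard-then-stream pattern these headers enable.
// Assumes a torch::Tensor `query` that lives on some CUDA device.
void launch_on_querys_device(const torch::Tensor& query) {
  // RAII guard: makes query's device current for this scope,
  // then restores the previous device on scope exit.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  // Stream on which subsequent kernel launches should be enqueued.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... <<<grid, block, shared_mem, stream>>> kernel launches go here ...
}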
@@ -808,17 +808,16 @@ void paged_attention_v1(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,   // [num_blocks, num_heads, head_size, block_size]
-    int64_t num_kv_heads,  // [num_heads]
-    double scale,
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,  // [num_heads]
+    float scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,     // [num_seqs]
-    int64_t block_size, int64_t max_seq_len,
+    int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
-    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
-    const int64_t blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
 
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
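DISPATCH_BY_KV_CACHE_DTYPE expands the launcher macro once for the query-dtype/cache-dtype pair chosen at runtime. A simplified sketch of that dispatch style follows; the macro name, argument shape, and supported type set here are illustrative, not vLLM's actual definition:

// Sketch of a kv-cache dtype dispatch macro: compares the runtime dtype
// against known scalar types and expands FN with the matching C++ types.
// Hypothetical names; the real macro also covers FP8 cache layouts.
#define SKETCH_DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)      \
  if (KV_DTYPE == "auto") {                                             \
    if (SRC_DTYPE == at::ScalarType::Float) {                           \
      FN(float, float);                                                 \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                     \
      FN(uint16_t, uint16_t);                                           \
    } else {                                                            \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ",        \
                  SRC_DTYPE);                                           \
    }                                                                   \
  } else {                                                              \
    TORCH_CHECK(false, "Unsupported data type of kv cache: ",           \
                KV_DTYPE);                                              \
  }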
@@ -973,17 +972,16 @@ void paged_attention_v2(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,   // [num_blocks, num_heads, head_size, block_size]
-    int64_t num_kv_heads,  // [num_heads]
-    double scale,
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,  // [num_heads]
+    float scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,     // [num_seqs]
-    int64_t block_size, int64_t max_seq_len,
+    int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
-    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
-    const int64_t blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
@@ -992,4 +990,4 @@ void paged_attention_v2(
 #undef WARP_SIZE
 #undef MAX
 #undef MIN
-#undef DIVIDE_ROUND_UP
+#undef DIVIDE_ROUND_UP
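torch/extension.h (switched to in the include hunk above) pulls in pybind11, under which plain int/float parameters like the ones in these two signatures bind directly. A minimal, hypothetical registration sketch, assuming both declarations are visible at this point in the translation unit; the docstrings are assumptions:

#include <torch/extension.h>

// Sketch: registers the two paged-attention entry points with pybind11.
// TORCH_EXTENSION_NAME is supplied by the PyTorch extension build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("paged_attention_v1", &paged_attention_v1,
        "Paged attention, single-pass kernel.");
  m.def("paged_attention_v2", &paged_attention_v2,
        "Paged attention, two-pass kernel with a reduction step.");
}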