[Core] Support sparse KV cache framework #5752

Closed
74 changes: 59 additions & 15 deletions csrc/attention/attention_kernels.cu
@@ -107,7 +107,9 @@ __device__ void paged_attention_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const std::string& sparse_cache_type,
float* __restrict__ attention_scores) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
@@ -303,6 +305,14 @@ __device__ void paged_attention_kernel(
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);

if (attention_scores != nullptr && !mask &&
physical_block_number != 0) {
attention_scores[seq_idx * BLOCK_SIZE * max_num_blocks_per_seq *
num_heads +
(token_idx - start_token_idx) * num_heads +
head_idx] = logits[token_idx - start_token_idx];
}
}
}
}
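The flat index used in the score write above implies a per-sequence slab of shape [max_num_blocks_per_seq * BLOCK_SIZE, num_heads] inside the attention_scores buffer, with the token index taken relative to the sequence start. A minimal host-side sketch of the same offset arithmetic; the helper name and int64 types are illustrative, not part of this PR:

// Hypothetical helper mirroring the kernel's attention_scores index; the
// assumed flat layout is [num_seqs, max_num_blocks_per_seq * block_size,
// num_heads].
inline int64_t attention_score_offset(int64_t seq_idx, int64_t token_idx,
                                      int64_t head_idx, int64_t num_heads,
                                      int64_t max_num_blocks_per_seq,
                                      int64_t block_size) {
  const int64_t padded_seq_len = max_num_blocks_per_seq * block_size;
  return seq_idx * padded_seq_len * num_heads + token_idx * num_heads +
         head_idx;
}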
@@ -515,15 +525,17 @@ __global__ void paged_attention_v1_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const std::string& sparse_cache_type,
float* __restrict__ attention_scores) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, sparse_cache_type, attention_scores);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -551,14 +563,16 @@ __global__ void paged_attention_v2_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const std::string& sparse_cache_type,
float* __restrict__ attention_scores) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, sparse_cache_type, attention_scores);
}

// Grid: (num_heads, num_seqs).
@@ -573,7 +587,8 @@ __global__ void paged_attention_v2_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int max_num_partitions, const std::string& sparse_cache_type,
float* __restrict__ attention_scores) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
@@ -684,7 +699,8 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, sparse_cache_type, \
attention_scores_ptr);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
@@ -697,7 +713,8 @@ void paged_attention_v1_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const std::string& sparse_cache_type, torch::Tensor& attention_scores) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -722,6 +739,17 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

torch::Device cache_device = key_cache.device();
TORCH_CHECK(cache_device.is_cuda());
torch::Device cpu_device = attention_scores.device();
TORCH_CHECK(cpu_device.is_cpu());
torch::Tensor attention_scores_tensor = attention_scores.to(cache_device);
float* attention_scores_ptr = nullptr;
if (attention_scores_tensor.size(0) > 0) {
attention_scores_ptr =
reinterpret_cast<float*>(attention_scores_tensor.data_ptr());
}

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
@@ -764,6 +792,7 @@ void paged_attention_v1_launcher(
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
attention_scores.copy_(attention_scores_tensor.to(cpu_device));
}
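The launcher expects attention_scores to arrive as a CPU tensor, stages a copy on the cache device for the kernel, and writes the results back with the copy_ call right before the closing brace above. A hedged caller-side sketch, with the shape inferred from the kernel's indexing and all names illustrative:

// The scores buffer is a CPU float tensor; its assumed layout is
// [num_seqs, max_num_blocks_per_seq * block_size, num_heads].
torch::Tensor attention_scores = torch::zeros(
    {num_seqs, max_num_blocks_per_seq * block_size, num_heads},
    torch::dtype(torch::kFloat32).device(torch::kCPU));
// A tensor whose first dimension is 0 leaves attention_scores_ptr as nullptr,
// so the kernel skips score collection entirely.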

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
@@ -772,7 +801,8 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
sparse_cache_type, attention_scores);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
Expand Down Expand Up @@ -818,7 +848,8 @@ void paged_attention_v1(
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const std::string& sparse_cache_type, torch::Tensor& attention_scores) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);

DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
@@ -835,12 +866,13 @@ void paged_attention_v1(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step, \
sparse_cache_type, attention_scores_ptr); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
max_num_partitions, sparse_cache_type, attention_scores_ptr);

template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
@@ -853,7 +885,8 @@ void paged_attention_v2_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const std::string& sparse_cache_type, torch::Tensor& attention_scores) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -881,6 +914,14 @@ void paged_attention_v2_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

torch::Device cache_device = key_cache.device();
TORCH_CHECK(cache_device.is_cuda());
torch::Device cpu_device = attention_scores.device();
TORCH_CHECK(cpu_device.is_cpu());
torch::Tensor attention_scores_tensor = attention_scores.to(cache_device);
float* attention_scores_ptr =
reinterpret_cast<float*>(attention_scores_tensor.data_ptr());

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
@@ -925,6 +966,7 @@ void paged_attention_v2_launcher(
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
attention_scores.copy_(attention_scores_tensor.to(cpu_device));
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
@@ -933,7 +975,8 @@ void paged_attention_v2_launcher(
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
sparse_cache_type, attention_scores);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
Expand Down Expand Up @@ -983,7 +1026,8 @@ void paged_attention_v2(
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const std::string& sparse_cache_type, torch::Tensor& attention_scores) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
9 changes: 9 additions & 0 deletions csrc/cache.h
@@ -27,6 +27,15 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);

void sparse_cache_copy(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping_src_tensor,
const torch::Tensor& block_mapping_dst_tensor,
const torch::Tensor& selection_index_src_tensor,
const torch::Tensor& selection_index_dst_tensor,
const int64_t num_heads, const int64_t head_size,
const int64_t block_size);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
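The header gives only the signature; from the implementation in cache_kernels.cu below, the block mappings are per-sequence tables of shape [num_seqs, total_num_blocks], the selection tensors are flat int64 arrays of length num_layers * num_seqs * total_num_blocks * block_size, and entries of -1 are skipped by the kernel. An illustrative call under those assumed shapes (example sizes only, not from this PR):

// key_caches / value_caches are assumed to be one CUDA tensor per layer in
// vLLM's paged layout; the index tensors below start on the CPU and are moved
// to the cache device by sparse_cache_copy itself.
const int64_t num_layers = key_caches.size();
const int64_t num_seqs = 4, total_num_blocks = 32;
const int64_t num_heads = 8, head_size = 128, block_size = 16;
auto idx_opts = torch::dtype(torch::kInt64).device(torch::kCPU);
torch::Tensor block_mapping_src =
    torch::full({num_seqs, total_num_blocks}, -1, idx_opts);
torch::Tensor block_mapping_dst =
    torch::full({num_seqs, total_num_blocks}, -1, idx_opts);
torch::Tensor selection_src = torch::full(
    {num_layers * num_seqs * total_num_blocks * block_size}, -1, idx_opts);
torch::Tensor selection_dst = torch::full(
    {num_layers * num_seqs * total_num_blocks * block_size}, -1, idx_opts);
sparse_cache_copy(key_caches, value_caches, block_mapping_src,
                  block_mapping_dst, selection_src, selection_dst, num_heads,
                  head_size, block_size);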
156 changes: 156 additions & 0 deletions csrc/cache_kernels.cu
@@ -315,6 +315,162 @@ void reshape_and_cache_flash(

namespace vllm {

// Grid: (num_layers, num_seqs, block_size * num_blocks)
template <typename scalar_t>
__global__ void sparse_cache_copy_kernel(
int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping_src,
const int64_t* __restrict__ block_mapping_dst,
const int64_t* __restrict__ selection_index_src,
const int64_t* __restrict__ selection_index_dst, const int num_heads,
const int head_size, const int x) {
const int layer_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int selected_pairs_idx = blockIdx.z;

const int num_layers = gridDim.x;
const int num_seqs = gridDim.y;
const int block_size_times_num = gridDim.z;
const int block_size = 16;

const int num_selected_index = layer_idx * num_seqs * block_size_times_num +
seq_idx * block_size_times_num +
selected_pairs_idx;

const int64_t src_token_idx = selection_index_src[num_selected_index];
const int64_t tgt_token_idx = selection_index_dst[num_selected_index];
if (src_token_idx < 0 ||
src_token_idx > num_layers * num_seqs * block_size_times_num) {
return;
}

scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);

for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) {
const int64_t src_token_layer_idx =
src_token_idx % (num_seqs * block_size_times_num);
const int64_t tgt_token_layer_idx =
tgt_token_idx % (num_seqs * block_size_times_num);

const int64_t block_idx =
block_mapping_src[src_token_layer_idx / block_size];
const int64_t block_offset = src_token_layer_idx % block_size;

const int64_t tgt_block_idx =
block_mapping_dst[tgt_token_layer_idx / block_size];
const int64_t tgt_block_offset = tgt_token_layer_idx % block_size;

if (block_idx == -1 || tgt_block_idx == -1) {
continue;
}

const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;

const int64_t src_key_idx =
block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
block_offset * x + x_offset;
const int64_t src_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
scalar_t tgt_key = key_cache[src_key_idx];
scalar_t tgt_value = value_cache[src_value_idx];

const int64_t tgt_key_idx =
tgt_block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
tgt_block_offset * x + x_offset;
const int64_t tgt_value_idx =
tgt_block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
tgt_block_offset;
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
}
}

} // namespace vllm
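The source and target offsets computed in the loop above follow vLLM's paged KV-cache layouts: keys are stored as [num_blocks, num_heads, head_size/x, block_size, x] and values as [num_blocks, num_heads, head_size, block_size]. A standalone sketch of the same arithmetic, assuming those layouts (helper names are illustrative):

// Flat element offsets for one (block, head, head_offset, block_offset)
// position, matching the expressions in sparse_cache_copy_kernel.
inline int64_t paged_key_offset(int64_t block_idx, int64_t head_idx,
                                int64_t head_offset, int64_t block_offset,
                                int64_t num_heads, int64_t head_size,
                                int64_t block_size, int64_t x) {
  const int64_t x_idx = head_offset / x;
  const int64_t x_off = head_offset % x;
  return block_idx * num_heads * (head_size / x) * block_size * x +
         head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
         block_offset * x + x_off;
}

inline int64_t paged_value_offset(int64_t block_idx, int64_t head_idx,
                                  int64_t head_offset, int64_t block_offset,
                                  int64_t num_heads, int64_t head_size,
                                  int64_t block_size) {
  return block_idx * num_heads * head_size * block_size +
         head_idx * head_size * block_size + head_offset * block_size +
         block_offset;
}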

void sparse_cache_copy(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping_src_tensor,
const torch::Tensor& block_mapping_dst_tensor,
const torch::Tensor& selection_index_src_tensor,
const torch::Tensor& selection_index_dst_tensor,
const int64_t num_heads, const int64_t head_size,
const int64_t block_size) {
const int num_seqs = block_mapping_src_tensor.size(0);
const int total_num_blocks = block_mapping_src_tensor.size(1);

const int num_layers = key_caches.size();
const int x = 8;

TORCH_CHECK(num_layers == value_caches.size());
TORCH_CHECK(selection_index_src_tensor.size(0) ==
num_layers * block_size * num_seqs * total_num_blocks);
if (num_layers == 0) {
return;
}
torch::Device cache_device = key_caches[0].device();
TORCH_CHECK(cache_device.is_cuda());

// Create data structures for the kernel.
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}

torch::Tensor flat_block_mapping_src_tensor =
torch::flatten(block_mapping_src_tensor);
torch::Tensor flat_block_mapping_dst_tensor =
torch::flatten(block_mapping_dst_tensor);

// Move the data structures to the GPU.
torch::Tensor key_cache_ptrs_tensor =
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor block_mapping_src_tensor_cuda =
flat_block_mapping_src_tensor.clone().to(cache_device);
torch::Tensor block_mapping_dst_tensor_cuda =
flat_block_mapping_dst_tensor.clone().to(cache_device);
torch::Tensor selection_index_src_tensor_cuda =
selection_index_src_tensor.clone().to(cache_device);
torch::Tensor selection_index_dst_tensor_cuda =
selection_index_dst_tensor.clone().to(cache_device);

// Launch the kernel.
dim3 grid(num_layers, num_seqs, block_size * total_num_blocks);
dim3 block(std::min(num_heads * head_size, int64_t(64)));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "sparse_cache_copy_kernel", ([&] {
vllm::sparse_cache_copy_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_src_tensor_cuda.data_ptr<int64_t>(),
block_mapping_dst_tensor_cuda.data_ptr<int64_t>(),
selection_index_src_tensor_cuda.data_ptr<int64_t>(),
selection_index_dst_tensor_cuda.data_ptr<int64_t>(), num_heads,
head_size, x);
}));
}
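The kernel addresses the selection tensors as one flat array indexed by (layer, sequence, slot), so a caller filling selection_index_src/dst would need the same flattening; slots_per_seq here is block_size * total_num_blocks, i.e. the gridDim.z of the launch above. A hypothetical helper for that index (not part of this PR):

// Mirrors the kernel's num_selected_index computation.
inline int64_t selection_flat_index(int64_t layer_idx, int64_t seq_idx,
                                    int64_t slot_idx, int64_t num_seqs,
                                    int64_t slots_per_seq) {
  return (layer_idx * num_seqs + seq_idx) * slots_per_seq + slot_idx;
}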

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,