Draft
33 commits
319c660
Split prefills and decode
LucasWilkinson Oct 9, 2025
e28bda0
workspace refactor
LucasWilkinson Oct 10, 2025
36168f4
wip
LucasWilkinson Oct 10, 2025
4a22155
cleanup
LucasWilkinson Oct 10, 2025
4399ec8
clean-up
LucasWilkinson Oct 11, 2025
2eceb39
booting
LucasWilkinson Oct 11, 2025
bb7ba88
Merge remote-tracking branch 'origin/main' into lwilkinson/split-pref…
LucasWilkinson Oct 11, 2025
337698c
cleanup
LucasWilkinson Oct 13, 2025
cd5edd1
Merge remote-tracking branch 'origin/main' into lwilkinson/split-pref…
LucasWilkinson Oct 13, 2025
5330f95
cleanup
LucasWilkinson Oct 13, 2025
376eeb9
cleanup
LucasWilkinson Oct 13, 2025
0fba98b
cleanup
LucasWilkinson Oct 13, 2025
3bda203
cleanup
LucasWilkinson Oct 13, 2025
d4d4522
cleanup
LucasWilkinson Oct 14, 2025
05b957b
cleanup
LucasWilkinson Oct 14, 2025
a4e28d5
fix
LucasWilkinson Oct 14, 2025
9e4c72c
Merge remote-tracking branch 'origin/main' into lwilkinson/split-pref…
LucasWilkinson Oct 14, 2025
e4957f3
wip
LucasWilkinson Oct 16, 2025
2105618
Merge remote-tracking branch 'origin/main' into lwilkinson/split-pref…
LucasWilkinson Oct 16, 2025
aee40b9
cleanup
LucasWilkinson Oct 16, 2025
5904207
fix
LucasWilkinson Oct 17, 2025
5bf86a9
fix server boot
LucasWilkinson Oct 17, 2025
6a6cee8
fix precommit
LucasWilkinson Oct 17, 2025
cb922d1
cleanup
LucasWilkinson Oct 17, 2025
0fbee26
add back kernels
LucasWilkinson Oct 17, 2025
2ccdf18
add back missing code
LucasWilkinson Oct 17, 2025
b44cf4c
cleanup
LucasWilkinson Oct 17, 2025
cc27394
rename
LucasWilkinson Oct 17, 2025
4b38ab2
Merge remote-tracking branch 'origin/main' into lwilkinson/split-pref…
LucasWilkinson Oct 18, 2025
cbcc9ad
Update vllm/v1/attention/backends/mla/flashmla_sparse.py
LucasWilkinson Oct 20, 2025
ac4b00d
Merge branch 'lwilkinson/split-prefills-and-decode' of https://github…
LucasWilkinson Oct 20, 2025
f1fe1fd
reserve MoE workspaces
LucasWilkinson Oct 23, 2025
ed47577
cleanup
LucasWilkinson Nov 4, 2025
11 changes: 10 additions & 1 deletion csrc/cache.h
@@ -1,6 +1,7 @@
#pragma once

#include <torch/all.h>
#include <c10/util/Optional.h>

#include <map>
#include <vector>
@@ -71,4 +72,12 @@ void cp_gather_indexer_k_quant_cache(
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]

torch::Tensor convert_req_index_to_global_index_and_upconvert_prefills(
torch::Tensor req_id, torch::Tensor block_table,
torch::Tensor token_indices, int64_t block_size,
const std::optional<torch::Tensor>& prefill_mask,
const std::optional<torch::Tensor>& prefill_seen,
const std::optional<torch::Tensor>& prefill_bf16_workspace,
const std::optional<torch::Tensor>& kv_cache);
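The declaration above is the new fused op this PR adds: it maps per-request top-k token indices onto global KV-cache slot indices and, when the optional prefill tensors are supplied, up-converts each newly seen prefill token from the fp8 cache into a bf16 workspace. A minimal sketch of calling it from Python, assuming the binding registered in csrc/torch_bindings.cpp further down is exposed under torch.ops._C (as for vLLM's other custom ops) and using the decode-only path where every optional tensor is None:

import torch

num_tokens, topk, block_size = 4, 2048, 64
req_id = torch.zeros(num_tokens, dtype=torch.int32, device="cuda")
block_table = torch.arange(16, dtype=torch.int32, device="cuda").view(1, 16)
token_indices = torch.randint(0, 512, (num_tokens, topk),
                              dtype=torch.int32, device="cuda")

# Decode-only path: no prefill mask/seen/workspace/kv_cache needed.
out = torch.ops._C.convert_req_index_to_global_index_and_upconvert_prefills(
    req_id, block_table, token_indices, block_size, None, None, None, None)
# `out` has the same shape as token_indices; invalid entries are -1.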
277 changes: 275 additions & 2 deletions csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
@@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
@@ -1238,7 +1240,278 @@ void indexer_k_quant_and_cache(
CALL_INDEXER_K_QUANT_AND_CACHE);
}

namespace vllm {

// Device function to cooperatively upconvert a single token from fp8 to bf16.
// Requires blockDim.x >= 576 so that the fp8 dequant (512 threads) and the
// rope copy (64 threads) are fully covered and can run in parallel.
__device__ void upconvert_single_token(
const uint8_t* __restrict__ src_cache,
__nv_bfloat16* __restrict__ dst_workspace, int32_t token_index,
int64_t block_stride, int64_t entry_stride, int32_t block_size,
int32_t head_dim) {
const int64_t block_idx = token_index / block_size;
const int64_t block_offset = token_index % block_size;
const uint8_t* token_ptr =
src_cache + block_idx * block_stride + block_offset * entry_stride;
__nv_bfloat16* dst_ptr =
dst_workspace + token_index * static_cast<int64_t>(head_dim);

const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr =
reinterpret_cast<const float*>(token_ptr + 512); // 4 tiles of 128
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
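  // Per-token cache entry layout implied by the pointer math above (it matches
  // the 656-byte entries documented for kv_cache in the kernel below):
  //   bytes [  0, 512): 512 fp8 (e4m3) "no-PE" values, 4 tiles of 128
  //   bytes [512, 528):   4 float32 per-tile dequant scales
  //   bytes [528, 656):  64 bf16 RoPE values, copied through unchanged
  // so each upconverted token occupies 512 + 64 = 576 bf16 elements of
  // head_dim in the bf16 workspace.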

const int tid = threadIdx.x;

// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
// Threads 0-511: handle fp8 dequantization
// Threads 512-575: handle rope copy
// Threads 576+: idle

if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Threads 576-1023 are idle during upconvert
}

// Fused kernel: convert per-request indices to global slots and upconvert
// unique prefill tokens
__global__ void convert_req_index_to_global_index_and_upconvert_prefills_kernel(
const int32_t* __restrict__ req_id, // [num_tokens]
const int32_t* __restrict__ block_table, // [num_requests,
// max_num_blocks_per_req]
const int32_t* __restrict__ token_indices, // [num_tokens, NUM_TOPK_TOKENS]
int32_t* __restrict__ out, // [num_tokens, NUM_TOPK_TOKENS]
const int32_t* __restrict__ prefill_mask, // [num_tokens] or nullptr
int32_t* __restrict__ prefill_seen, // [prefill_seen_size] or nullptr
__nv_bfloat16* __restrict__ prefill_bf16_workspace, // [num_slots,
// head_dim] or nullptr
const uint8_t* __restrict__ kv_cache, // [num_blocks, block_size, 656] or
// nullptr
int num_topk_tokens, int block_size, int max_num_blocks_per_req,
int bt_stride0, int bt_stride1, int ti_stride0, int ti_stride1,
int out_stride0, int out_stride1, int prefill_seen_size,
int64_t kv_block_stride, int64_t kv_entry_stride, int32_t head_dim,
bool has_prefill) {
const int token_id = blockIdx.x;
const int tid = threadIdx.x;

// Shared memory for batching upconvert operations
__shared__ int32_t tokens_to_upconvert[1024]; // One slot per thread
__shared__ int32_t num_tokens_to_upconvert;
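  // Each thread can queue at most one slot per wave (it must first win the
  // atomicCAS on prefill_seen below), so the 1024 entries, one per thread of
  // the 1024-thread block, can never overflow.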

  // Outer loop over topk_indices - processed in waves. The bound is rounded up
  // to a multiple of blockDim.x so that every thread executes the same number
  // of iterations: the body contains __syncthreads() and a block-cooperative
  // upconvert, so no thread may leave the loop early.
  const int num_topk_rounded =
      (num_topk_tokens + blockDim.x - 1) / blockDim.x * blockDim.x;
  for (int indice_id = tid; indice_id < num_topk_rounded;
       indice_id += blockDim.x) {
    const bool active = indice_id < num_topk_tokens;

    // Initialize shared counter for this wave
    if (threadIdx.x == 0) {
      num_tokens_to_upconvert = 0;
    }
    __syncthreads();

    if (active) {
      // Load request id for this token
      const int req = req_id[token_id];

      // Load token index
      const int ti_offset = token_id * ti_stride0 + indice_id * ti_stride1;
      const int tok = token_indices[ti_offset];

      // Check if token is invalid
      bool is_invalid = tok < 0;

      // Compute block id and in-block offset
      const int block_id = tok / block_size;
      const int inblock_off = tok % block_size;

      // Guard block_table access (block_id can be negative when tok < 0)
      const bool valid_block =
          block_id >= 0 && block_id < max_num_blocks_per_req;
      int base = 0;
      if (valid_block) {
        const int bt_offset = req * bt_stride0 + block_id * bt_stride1;
        base = block_table[bt_offset];
      }
      is_invalid = is_invalid || !valid_block;

      // Compute output value
      const int out_val = is_invalid ? -1 : (base * block_size + inblock_off);

      // Store result
      const int out_offset = token_id * out_stride0 + indice_id * out_stride1;
      out[out_offset] = out_val;

      // Handle prefill unique tracking - queue tokens for upconversion
      if (has_prefill && prefill_mask != nullptr && !is_invalid) {
        const int is_prefill = prefill_mask[token_id];
        if (is_prefill != 0 && out_val >= 0 && out_val < prefill_seen_size) {
          // Optimistic coherent read from L2 to skip atomics when already seen
          int seen = __ldcg(prefill_seen + out_val);

          if (!seen) {
            // Try to claim this slot with an atomic CAS
            seen = atomicCAS(prefill_seen + out_val, 0, 1);

            if (!seen) {
              // We won the race - queue this token for upconversion
              int idx = atomicAdd(&num_tokens_to_upconvert, 1);
              tokens_to_upconvert[idx] = out_val;
            }
          }
        }
      }
    }

    __syncthreads();

// Cooperatively upconvert all queued tokens
if (num_tokens_to_upconvert > 0 && kv_cache != nullptr &&
prefill_bf16_workspace != nullptr) {
for (int i = 0; i < num_tokens_to_upconvert; i++) {
int32_t token_index = tokens_to_upconvert[i];
// All threads cooperate on upconverting this token
upconvert_single_token(kv_cache, prefill_bf16_workspace, token_index,
kv_block_stride, kv_entry_stride, block_size,
head_dim);
// CRITICAL: Sync after each token to prevent races
__syncthreads();
}
}

// CRITICAL: Sync before next outer loop iteration to prevent thread 0
// from resetting the counter while other threads are still reading it
__syncthreads();
}
}

} // namespace vllm

// Host function to launch the fused convert + upconvert kernel
torch::Tensor convert_req_index_to_global_index_and_upconvert_prefills(
torch::Tensor req_id, // int32 [num_tokens]
torch::Tensor block_table, // int32 [num_requests, max_num_blocks_per_req]
torch::Tensor token_indices, // int32 [num_tokens, NUM_TOPK_TOKENS]
int64_t block_size, // KV cache block size
const std::optional<torch::Tensor>& prefill_mask, // int32 [num_tokens]
const std::optional<torch::Tensor>&
prefill_seen, // int32 [prefill_seen_size]
const std::optional<torch::Tensor>&
prefill_bf16_workspace, // bf16 [num_slots, head_dim]
const std::optional<torch::Tensor>&
kv_cache // uint8 [num_blocks, block_size, 656]
) {
constexpr int THREADS_PER_BLOCK = 1024;
constexpr int MIN_THREADS_FOR_UPCONVERT = 576; // 512 fp8 + 64 rope
static_assert(
THREADS_PER_BLOCK >= MIN_THREADS_FOR_UPCONVERT,
"Need at least 576 threads for parallel fp8 dequant + rope copy");

// Validate input tensors
TORCH_CHECK(req_id.is_cuda(), "req_id must be a CUDA tensor");
TORCH_CHECK(block_table.is_cuda(), "block_table must be a CUDA tensor");
TORCH_CHECK(token_indices.is_cuda(), "token_indices must be a CUDA tensor");
TORCH_CHECK(req_id.dtype() == torch::kInt32, "req_id must be int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(token_indices.dtype() == torch::kInt32,
"token_indices must be int32");

// Ensure contiguous
req_id = req_id.contiguous();
block_table = block_table.contiguous();
token_indices = token_indices.contiguous();

// Extract dimensions
const int num_tokens = req_id.size(0);
const int num_topk_tokens = token_indices.size(1);
const int max_num_blocks_per_req = block_table.size(1);

// Create output tensor
auto out = torch::empty_like(token_indices);

// Extract strides
const int bt_stride0 = block_table.stride(0);
const int bt_stride1 = block_table.stride(1);
const int ti_stride0 = token_indices.stride(0);
const int ti_stride1 = token_indices.stride(1);
const int out_stride0 = out.stride(0);
const int out_stride1 = out.stride(1);

// Handle optional prefill tensors
bool has_prefill = prefill_mask.has_value();
const int32_t* prefill_mask_ptr = nullptr;
int32_t* prefill_seen_ptr = nullptr;
__nv_bfloat16* prefill_bf16_workspace_ptr = nullptr;
const uint8_t* kv_cache_ptr = nullptr;
int prefill_seen_size = 0;
int64_t kv_block_stride = 0;
int64_t kv_entry_stride = 0;
int32_t head_dim = 0;

if (has_prefill) {
TORCH_CHECK(
prefill_mask.has_value() && prefill_seen.has_value() &&
prefill_bf16_workspace.has_value() && kv_cache.has_value(),
"All prefill tensors must be provided together for fused kernel");

auto& pfm = prefill_mask.value();
auto& pfs = prefill_seen.value();
auto& pbw = prefill_bf16_workspace.value();
auto& kvc = kv_cache.value();

TORCH_CHECK(pfm.is_cuda(), "prefill_mask must be a CUDA tensor");
TORCH_CHECK(pfs.is_cuda(), "prefill_seen must be a CUDA tensor");
TORCH_CHECK(pbw.is_cuda(), "prefill_bf16_workspace must be a CUDA tensor");
TORCH_CHECK(kvc.is_cuda(), "kv_cache must be a CUDA tensor");
TORCH_CHECK(pfm.is_contiguous(), "prefill_mask must be contiguous");
TORCH_CHECK(pfs.is_contiguous(), "prefill_seen must be contiguous");
TORCH_CHECK(pbw.is_contiguous(),
"prefill_bf16_workspace must be contiguous");
TORCH_CHECK(kvc.is_contiguous(), "kv_cache must be contiguous");
TORCH_CHECK(pbw.dtype() == torch::kBFloat16,
"prefill_bf16_workspace must be bfloat16");
TORCH_CHECK(kvc.dtype() == torch::kUInt8, "kv_cache must be uint8");

prefill_mask_ptr = pfm.data_ptr<int32_t>();
prefill_seen_ptr = pfs.data_ptr<int32_t>();
prefill_bf16_workspace_ptr =
reinterpret_cast<__nv_bfloat16*>(pbw.data_ptr());
kv_cache_ptr = kvc.data_ptr<uint8_t>();
prefill_seen_size = pfs.size(0);
kv_block_stride = kvc.stride(0);
kv_entry_stride = kvc.stride(1);
head_dim = pbw.size(1);
}

// Get CUDA stream
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Launch kernel with 1024 threads per block
dim3 grid(num_tokens);
dim3 block(THREADS_PER_BLOCK);

vllm::convert_req_index_to_global_index_and_upconvert_prefills_kernel<<<
grid, block, 0, stream>>>(
req_id.data_ptr<int32_t>(), block_table.data_ptr<int32_t>(),
token_indices.data_ptr<int32_t>(), out.data_ptr<int32_t>(),
prefill_mask_ptr, prefill_seen_ptr, prefill_bf16_workspace_ptr,
kv_cache_ptr, num_topk_tokens, block_size, max_num_blocks_per_req,
bt_stride0, bt_stride1, ti_stride0, ti_stride1, out_stride0, out_stride1,
prefill_seen_size, kv_block_stride, kv_entry_stride, head_dim,
has_prefill);

return out;
}

// Macro to dispatch the kernel based on the data type.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
<<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, \
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -60,6 +60,14 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);

torch::Tensor convert_req_index_to_global_index_and_upconvert_prefills(
torch::Tensor req_id, torch::Tensor block_table,
torch::Tensor token_indices, int64_t block_size,
const std::optional<torch::Tensor>& prefill_mask,
const std::optional<torch::Tensor>& prefill_seen,
const std::optional<torch::Tensor>& prefill_bf16_workspace,
const std::optional<torch::Tensor>& kv_cache);

void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
14 changes: 14 additions & 0 deletions csrc/torch_bindings.cpp
@@ -89,6 +89,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);

ops.def(
"convert_req_index_to_global_index_and_upconvert_prefills("
" Tensor req_id,"
" Tensor block_table,"
" Tensor token_indices,"
" int block_size,"
" Tensor? prefill_mask,"
" Tensor? prefill_seen,"
" Tensor? prefill_bf16_workspace,"
" Tensor? kv_cache) -> Tensor");
ops.impl("convert_req_index_to_global_index_and_upconvert_prefills",
torch::kCUDA,
&convert_req_index_to_global_index_and_upconvert_prefills);

ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "