[Performance] Using user-allocated workspace for batch decode/prefill handlers (#88)

To avoid the overhead of allocating/destroying memory at every step. (A sketch of the new calling pattern follows the changed-files summary below.)
yzh119 authored Jan 26, 2024
1 parent 9f49803 commit 51b88d2
Showing 12 changed files with 183 additions and 96 deletions.
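For orientation, here is a minimal sketch of the calling pattern this commit introduces on the C++ side, mirroring the changes to src/bench_batch_decode.cu below: the caller allocates one workspace buffer up front and passes it to BeginForward at every step, so the handler no longer issues cudaMallocAsync/cudaFreeAsync internally. The driver function, its step loop, and the half/kNHD/kNone template choices are illustrative assumptions, not part of this commit; only the BeginForward argument order and the 32 MB workspace size follow the diff.

```cpp
#include <cstdint>
#include <vector>
#include <cuda_fp16.h>
#include <thrust/device_vector.h>
#include <flashinfer/handler.cuh>

using namespace flashinfer;

// Sketch only: drive several decode steps with a single, reusable workspace.
void run_decode_steps(BatchDecodeHandler& handler, std::vector<int32_t>& kv_indptr_h,
                      std::vector<int32_t>& kv_last_page_len_h, uint32_t batch_size,
                      uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                      uint32_t page_size, int num_steps) {
  // One device allocation for the whole loop (32 MB, as in the tests/bench below).
  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
  thrust::device_vector<char> buffer(workspace_size_in_bytes);

  for (int step = 0; step < num_steps; ++step) {
    // BeginForward now only carves sub-buffers out of `buffer`; no per-step cudaMallocAsync.
    handler.BeginForward<PageStorage::kIndices, QKVLayout::kNHD, half, half, int32_t>(
        (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
        kv_indptr_h.data(), kv_last_page_len_h.data(), batch_size, num_qo_heads,
        num_kv_heads, head_dim, page_size, RotaryMode::kNone);
    // ... launch the batch decode kernel for this step ...
    // EndForward just resets the cached pointers; nothing is freed.
    handler.EndForward();
  }
}
```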
127 changes: 69 additions & 58 deletions include/flashinfer/handler.cuh
@@ -28,6 +28,24 @@

namespace flashinfer {

struct AlignedAlloactor {
void* ptr;
size_t space;
AlignedAlloactor(void* buf, size_t space) : ptr(buf), space(space) {}
template <typename T>
T* aligned_alloc(size_t size, size_t alignment) {
if (std::align(alignment, size, ptr, space)) {
T* result = reinterpret_cast<T*>(ptr);
ptr = (char*)ptr + size;
space -= size;
return result;
} else {
throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
}
return nullptr;
}
};

class BatchDecodeHandler {
public:
template <typename DType>
@@ -36,57 +54,35 @@ class BatchDecodeHandler {
}
template <typename IdType>
IdType* GetNewIndPtr() const {
- return (IdType*)int_buffer_;
+ return (IdType*)new_indptr_;
}
template <typename IdType>
IdType* GetNewLastPageLen() const {
- if (int_buffer_ != nullptr) {
- return ((IdType*)int_buffer_) + batch_size_after_partition_ + 1;
- } else {
- return nullptr;
- }
+ return (IdType*)new_last_page_len_;
}
template <typename IdType>
IdType* GetChunkIndPtr() const {
- if (int_buffer_ != nullptr) {
- return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + 1;
- } else {
- return nullptr;
- }
+ return (IdType*)chunk_indptr_;
}
template <typename IdType>
IdType* GetBatchIdxMap() const {
- if (int_buffer_ != nullptr) {
- return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ +
- batch_size_before_partition_ + 2;
- } else {
- return nullptr;
- }
+ return (IdType*)batch_idx_map_;
}
template <typename IdType>
IdType* GetChunkStartPos() const {
- if (int_buffer_ != nullptr) {
- return ((IdType*)int_buffer_) + 3 * batch_size_after_partition_ +
- batch_size_before_partition_ + 2;
- } else {
- return nullptr;
- }
+ return (IdType*)chunk_start_pos_;
}
template <typename IdType>
IdType* GetSeqLengthsBeforePartition() const {
- if (int_buffer_ != nullptr) {
- return ((IdType*)int_buffer_) + 4 * batch_size_after_partition_ +
- batch_size_before_partition_ + 2;
- } else {
- return nullptr;
- }
+ return (IdType*)seq_lengths_before_partition_;
}

template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
- cudaError_t BeginForward(IdType* indptr, IdType* last_page_len, uint32_t batch_size,
- uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
- uint32_t page_size, RotaryMode rotary_mode) {
+ cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
+ IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads,
+ uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+ RotaryMode rotary_mode) {
batch_size_before_partition_ = batch_size;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
@@ -97,10 +93,20 @@
num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
batch_size_after_partition_ = new_batch_size;
if (tmp_size > 0) {
- FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_));
- FLASHINFER_CUDA_CALL(cudaMallocAsync(
- &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2),
- stream_));
+ AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
+ float_buffer_ = allocator.aligned_alloc<void*>(tmp_size, 16);
+ new_indptr_ =
+ allocator.aligned_alloc<void*>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+ new_last_page_len_ =
+ allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
+ chunk_indptr_ =
+ allocator.aligned_alloc<void*>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
+ batch_idx_map_ =
+ allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
+ chunk_start_pos_ =
+ allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
+ seq_lengths_before_partition_ =
+ allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetChunkIndPtr<IdType>(),
@@ -115,14 +121,13 @@
forward_started_ = false;
batch_size_before_partition_ = 0;
batch_size_after_partition_ = 0;
- if (float_buffer_ != nullptr) {
- FLASHINFER_CUDA_CALL(cudaFreeAsync(float_buffer_, stream_));
- float_buffer_ = nullptr;
- }
- if (int_buffer_ != nullptr) {
- FLASHINFER_CUDA_CALL(cudaFreeAsync(int_buffer_, stream_));
- int_buffer_ = nullptr;
- }
+ float_buffer_ = nullptr;
+ new_indptr_ = nullptr;
+ new_last_page_len_ = nullptr;
+ chunk_indptr_ = nullptr;
+ batch_idx_map_ = nullptr;
+ chunk_start_pos_ = nullptr;
+ seq_lengths_before_partition_ = nullptr;
return cudaSuccess;
}

@@ -139,7 +144,12 @@
BatchDecodeHandler()
: batch_size_after_partition_(0U),
float_buffer_(nullptr),
- int_buffer_(nullptr),
+ new_indptr_(nullptr),
+ new_last_page_len_(nullptr),
+ chunk_indptr_(nullptr),
+ batch_idx_map_(nullptr),
+ chunk_start_pos_(nullptr),
+ seq_lengths_before_partition_(nullptr),
forward_started_(false),
stream_(nullptr) {}
~BatchDecodeHandler() { EndForward(); }
@@ -148,7 +158,12 @@
uint32_t batch_size_before_partition_;
uint32_t batch_size_after_partition_;
void* float_buffer_;
- void* int_buffer_;
+ void* new_indptr_;
+ void* new_last_page_len_;
+ void* chunk_indptr_;
+ void* batch_idx_map_;
+ void* chunk_start_pos_;
+ void* seq_lengths_before_partition_;
bool forward_started_;
cudaStream_t stream_;
};
@@ -172,8 +187,8 @@ class BatchPrefillHandler {
bool IsForwardStarted() const { return request_indices_ != nullptr; }

template <typename IdType>
- cudaError_t BeginForward(IdType* qo_indptr, uint32_t batch_size, uint32_t num_qo_heads,
- uint32_t num_kv_heads) {
+ cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
+ uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -184,8 +199,10 @@
std::vector<IdType> request_indices_h, tile_indices_h;
std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) =
split_qo_indptr(qo_indptr, batch_size, gqa_group_size, stream_);
- FLASHINFER_CUDA_CALL(cudaMalloc(&request_indices_, sizeof(IdType) * request_indices_h.size()));
- FLASHINFER_CUDA_CALL(cudaMalloc(&tile_indices_, sizeof(IdType) * tile_indices_h.size()));
+ AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
+ request_indices_ =
+ allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_h.size(), 16);
+ tile_indices_ = allocator.aligned_alloc<void*>(sizeof(IdType) * tile_indices_h.size(), 16);
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, request_indices_h.data(),
sizeof(IdType) * request_indices_h.size(),
cudaMemcpyHostToDevice, stream_));
@@ -199,14 +216,8 @@
forward_started_ = false;
num_frags_x_ = 0U;
num_qo_tiles_ = 0U;
- if (request_indices_ != nullptr) {
- FLASHINFER_CUDA_CALL(cudaFreeAsync(request_indices_, stream_));
- request_indices_ = nullptr;
- }
- if (tile_indices_ != nullptr) {
- FLASHINFER_CUDA_CALL(cudaFreeAsync(tile_indices_, stream_));
- tile_indices_ = nullptr;
- }
+ request_indices_ = nullptr;
+ tile_indices_ = nullptr;
return cudaSuccess;
}

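Note: the AlignedAlloactor added above (spelling as in the source) is a bump allocator over the caller-provided workspace: each aligned_alloc call aligns the running pointer with std::align, returns a typed slice, and advances past it. A host-only sketch of the same carving pattern, with the BumpAllocator name and all sizes made up rather than taken from the handler, might look like:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>     // std::align
#include <stdexcept>
#include <vector>

// Bump-pointer carving of one pre-allocated buffer into aligned, typed slices;
// this mirrors how BatchDecodeHandler::BeginForward uses AlignedAlloactor.
struct BumpAllocator {
  void* ptr;
  size_t space;
  template <typename T>
  T* aligned_alloc(size_t bytes, size_t alignment) {
    if (std::align(alignment, bytes, ptr, space) == nullptr) {
      throw std::runtime_error("out of workspace memory");
    }
    T* out = static_cast<T*>(ptr);
    ptr = static_cast<char*>(ptr) + bytes;
    space -= bytes;
    return out;
  }
};

int main() {
  std::vector<char> workspace(1 << 20);  // one upfront allocation, reused every step
  BumpAllocator alloc{workspace.data(), workspace.size()};

  // Illustrative sizes; the real handler sizes these from the partitioned batch.
  const size_t batch_size = 128;
  float* float_scratch = alloc.aligned_alloc<float>(4096 * sizeof(float), 16);
  int32_t* new_indptr = alloc.aligned_alloc<int32_t>((batch_size + 1) * sizeof(int32_t), 16);
  int32_t* new_last_page_len = alloc.aligned_alloc<int32_t>(batch_size * sizeof(int32_t), 16);

  std::printf("remaining workspace: %zu bytes\n", alloc.space);
  (void)float_scratch; (void)new_indptr; (void)new_last_page_len;
  return 0;
}
```

Because every temporary lives inside the one workspace, EndForward only resets the cached pointers instead of freeing device memory.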
11 changes: 8 additions & 3 deletions python/csrc/batch_decode.cu
@@ -109,21 +109,26 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache_return_lse(
}

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
- torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size,
- unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
- unsigned int page_size, unsigned int rotary_mode, torch::Tensor empty_data) {
+ torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
+ unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+ unsigned int head_dim, unsigned int page_size, unsigned int rotary_mode,
+ torch::Tensor empty_data) {
// NOTE(zihao): not necessary to be CUDA tensor
CHECK_CONTIGUOUS(indptr);
CHECK_CONTIGUOUS(last_page_len);
CHECK_CONTIGUOUS(workspace_buffer);
CHECK_DIM(1, indptr);
CHECK_DIM(1, last_page_len);
CHECK_DIM(1, workspace_buffer);
CHECK_EQ(indptr.scalar_type(), torch::kInt32);
CHECK_EQ(indptr.scalar_type(), torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
SWITCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, c_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, RotaryMode(rotary_mode));
12 changes: 9 additions & 3 deletions python/csrc/batch_prefill.cu
@@ -87,19 +87,25 @@ torch::Tensor batch_prefill_with_paged_kv_cache(
return o;
}

- void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor qo_indptr,
+ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
+ torch::Tensor qo_indptr,
unsigned int batch_size,
unsigned int num_qo_heads,
unsigned int num_kv_heads) {
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_CONTIGUOUS(workspace_buffer);
CHECK_EQ(num_qo_heads % num_kv_heads, 0);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, workspace_buffer);

// TODO(Zihao): support dispatching to different index data types.
CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();

- cudaError_t status = handler_.BeginForward(static_cast<int32_t*>(qo_indptr.data_ptr()),
- batch_size, num_qo_heads, num_kv_heads);
+ cudaError_t status = handler_.BeginForward(
+ static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+ static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
11 changes: 6 additions & 5 deletions python/csrc/flashinfer_ops.h
@@ -57,9 +57,10 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
static BatchDecodeWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
return BatchDecodeWithPagedKVCachePyTorchWrapper(layout);
}
- void BeginForward(torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size,
- unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
- unsigned int page_size, unsigned int rotary_mode, torch::Tensor empty_data);
+ void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
+ torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
+ unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
+ unsigned int rotary_mode, torch::Tensor empty_data);
void EndForward();
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor paged_kv_data,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
@@ -78,8 +79,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
static BatchPrefillWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
}
- void BeginForward(torch::Tensor qo_indptr, unsigned int batch_size, unsigned int num_qo_heads,
- unsigned int num_kv_heads);
+ void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
+ unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
void EndForward();
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
17 changes: 14 additions & 3 deletions python/flashinfer/ops/__init__.py
@@ -526,13 +526,17 @@ class BatchDecodeWithPagedKVCacheWrapper:
the lifecycle of these data structures.
"""

- def __init__(self, kv_layout: str = "NHD"):
+ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
_check_kv_layout(kv_layout)
self.kv_layout = kv_layout
self.workspace_buffer = workspace_buffer
self._wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper(
getattr(TensorLayout, kv_layout)
)

def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
self.workspace_buffer = new_workspace_buffer

def begin_forward(
self,
indptr: torch.Tensor,
@@ -553,6 +557,7 @@ def begin_forward(
# NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
empty_data = torch.empty(0, dtype=getattr(torch, data_type))
self._wrapper.begin_forward(
self.workspace_buffer,
indptr,
last_page_len,
batch_size,
@@ -630,21 +635,27 @@
class BatchPrefillWithPagedKVCacheWrapper:
r"""Wrapper class of batch_prefill_with_paged_kv_cache kernel."""

- def __init__(self, kv_layout: str = "NHD"):
+ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
_check_kv_layout(kv_layout)
self.kv_layout = kv_layout
self.workspace_buffer = workspace_buffer
self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper(
getattr(TensorLayout, kv_layout)
)

def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
self.workspace_buffer = new_workspace_buffer

def begin_forward(
self,
qo_indptr: torch.Tensor,
batch_size: int,
num_qo_heads: int,
num_kv_heads: int,
):
- self._wrapper.begin_forward(qo_indptr, batch_size, num_qo_heads, num_kv_heads)
+ self._wrapper.begin_forward(
+ self.workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+ )

def end_forward(self):
self._wrapper.end_forward()
3 changes: 2 additions & 1 deletion python/tests/test_batch_decode_kernels.py
@@ -54,7 +54,8 @@ def test_batch_decode_with_paged_kv_cache(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
).to(0)

- wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(kv_layout)
+ workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
+ wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
wrapper.begin_forward(
kv_indptr,
kv_last_page_len,
5 changes: 4 additions & 1 deletion python/tests/test_batch_prefill_kernels.py
@@ -58,7 +58,10 @@ def test_batch_prefill_with_paged_kv_cache(
).to(0)

if use_wrapper:
- wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(kv_layout)
+ workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
+ wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+ workspace_buffer, kv_layout
+ )
wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads)
o = wrapper.forward(
q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len
9 changes: 8 additions & 1 deletion src/bench_batch_decode.cu
@@ -72,8 +72,11 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
BatchDecodeHandler handler;

if (cooperative) {
size_t workspace_size_in_bytes = 32 * 1024 * 1024;
thrust::device_vector<char> buffer(workspace_size_in_bytes);
// begin forward
handler.BeginForward<PageStorage::kIndices, kv_layout, T, T, int32_t>(
(void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads,
head_dim, page_size, rotary_mode);
state.exec([&](nvbench::launch&) {
@@ -144,7 +147,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
"Read");
state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");
BatchPrefillHandler handler;
- handler.BeginForward(qo_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads);
+ size_t workspace_size_in_bytes = 32 * 1024 * 1024;
+ thrust::device_vector<char> buffer(workspace_size_in_bytes);
+
+ handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
+ qo_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads);

state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
cudaError_t status = BatchPrefillWithPagedKVCacheWrapper(
