-
Notifications
You must be signed in to change notification settings - Fork 163
/
bench_batch_decode.cu
188 lines (173 loc) · 9.48 KB
/
bench_batch_decode.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/device_vector.h>
#include <cstddef>
#include <cstdint>
#include <nvbench/nvbench.cuh>
#include <string>
#include <vector>
#include "flashinfer_ops.cuh"
#include "utils.h"
using utils::vec_bytes;
using namespace flashinfer;
// K/V cache layout handed to every paged_kv_t built by the benchmarks in this file.
constexpr auto kv_layout = QKVLayout::kNHD;
/*!
 * \brief nvbench benchmark of BatchDecodeWithPagedKVCacheWrapper on a paged KV cache.
 *
 * Every request in the batch has the same KV length `seqlen` and one query row
 * (q and o hold batch_size * num_qo_heads * head_dim elements). Request i owns the
 * contiguous page range [i * pages_per_seq, (i + 1) * pages_per_seq).
 *
 * nvbench axes read: seqlen, batch_size, page_size, num_qo_heads, num_kv_heads.
 * The state is skipped (not measured) when an axis value is unusable or when the
 * planning step or a launch reports a CUDA error.
 *
 * \tparam T   element type of q and o.
 * \tparam TKV element type of the K/V cache.
 */
template <typename T, typename TKV>
void bench_flashinfer_batch_decode(nvbench::state& state) {
  constexpr size_t head_dim = 128;
  constexpr auto pos_encoding_mode = PosEncodingMode::kNone;
  const size_t seqlen = state.get_int64("seqlen");
  const size_t batch_size = state.get_int64("batch_size");
  const size_t page_size = state.get_int64("page_size");
  const size_t num_qo_heads = state.get_int64("num_qo_heads");
  const size_t num_kv_heads = state.get_int64("num_kv_heads");
  // Axis values can be overridden on the command line: page_size == 0 would divide by
  // zero and seqlen == 0 would underflow the last-page-length computation below.
  if (seqlen == 0 || batch_size == 0 || page_size == 0) {
    state.skip("seqlen, batch_size and page_size must all be positive");
    return;
  }

  // KV cache: request i owns pages [i * pages_per_seq, (i + 1) * pages_per_seq).
  const size_t pages_per_seq = (seqlen + page_size - 1) / page_size;
  const size_t num_pages = pages_per_seq * batch_size;
  // Tokens stored in each request's final page, in [1, page_size].
  const int32_t last_page_len = static_cast<int32_t>((seqlen - 1) % page_size + 1);
  std::vector<int32_t> kv_indptr_host{0};
  std::vector<int32_t> kv_indices_host;
  std::vector<int32_t> kv_last_page_len_host;
  kv_indptr_host.reserve(batch_size + 1);
  kv_indices_host.reserve(num_pages);
  kv_last_page_len_host.reserve(batch_size);
  for (size_t i = 0; i < batch_size; ++i) {
    for (size_t p = 0; p < pages_per_seq; ++p) {
      kv_indices_host.push_back(static_cast<int32_t>(i * pages_per_seq + p));
    }
    kv_indptr_host.push_back(kv_indptr_host.back() + static_cast<int32_t>(pages_per_seq));
    kv_last_page_len_host.push_back(last_page_len);
  }
  thrust::device_vector<TKV> k_data(num_pages * num_kv_heads * page_size * head_dim);
  thrust::device_vector<TKV> v_data(num_pages * num_kv_heads * page_size * head_dim);
  thrust::device_vector<int32_t> kv_indptr(kv_indptr_host);
  thrust::device_vector<int32_t> kv_indices(kv_indices_host);
  thrust::device_vector<int32_t> kv_last_page_len(kv_last_page_len_host);
  paged_kv_t<TKV, int32_t> paged_kv(
      num_kv_heads, page_size, head_dim, batch_size, kv_layout,
      thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()),
      thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()),
      thrust::raw_pointer_cast(kv_last_page_len.data()));

  // Query / output: one row per request.
  thrust::device_vector<T> q(batch_size * num_qo_heads * head_dim);
  thrust::device_vector<T> o(batch_size * num_qo_heads * head_dim);

  // Reported traffic: q, the whole K and V caches and the page-table arrays are read;
  // o is written.
  state.add_global_memory_reads<uint8_t>(vec_bytes(q) + vec_bytes(k_data) + vec_bytes(v_data) +
                                             vec_bytes(kv_indptr) + vec_bytes(kv_indices) +
                                             vec_bytes(kv_last_page_len),
                                         "Read");
  state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");

  BatchDecodeHandler handler;
  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
  thrust::device_vector<char> float_buffer(float_workspace_size_in_bytes);
  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
  thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);
  // Plan once, outside the timed region. The returned status used to be discarded, so a
  // failed plan would still have been "benchmarked"; skip the state instead.
  const cudaError_t plan_status = BatchDecodeHandlerPlan<T, TKV, T, int32_t>(
      &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
      kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads,
      head_dim, page_size, pos_encoding_mode);
  if (plan_status != cudaSuccess) {
    state.skip("BatchDecodeHandlerPlan failed: " + std::string(cudaGetErrorString(plan_status)));
    return;
  }

  state.exec([&](nvbench::launch&) {
    cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<T, TKV, T, int32_t>(
        &handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv,
        thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
    if (status != cudaSuccess) {
      state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
    }
  });
}
/*!
 * \brief nvbench benchmark of a decode workload executed through the batch *prefill*
 *        path (BatchPrefillWithPagedKVCacheWrapper), for comparison with the decode kernel.
 *
 * The workload matches bench_flashinfer_batch_decode: every request has KV length
 * `seqlen` and exactly one query row (qo_indptr = 0, 1, ..., batch_size), with request i
 * owning pages [i * pages_per_seq, (i + 1) * pages_per_seq). Attention is non-causal.
 *
 * nvbench axes read: seqlen, batch_size, page_size, num_qo_heads, num_kv_heads.
 * The state is skipped (not measured) when an axis value is unusable or when the
 * planning step or a launch reports a CUDA error.
 *
 * \tparam T   element type of q and o.
 * \tparam TKV element type of the K/V cache.
 */
template <typename T, typename TKV>
void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
  constexpr size_t head_dim = 128;
  constexpr auto pos_encoding_mode = PosEncodingMode::kNone;
  const size_t seqlen = state.get_int64("seqlen");
  const size_t batch_size = state.get_int64("batch_size");
  const size_t page_size = state.get_int64("page_size");
  const size_t num_qo_heads = state.get_int64("num_qo_heads");
  const size_t num_kv_heads = state.get_int64("num_kv_heads");
  // Axis values can be overridden on the command line: page_size == 0 would divide by
  // zero and seqlen == 0 would underflow the last-page-length computation below.
  if (seqlen == 0 || batch_size == 0 || page_size == 0) {
    state.skip("seqlen, batch_size and page_size must all be positive");
    return;
  }

  // KV cache: request i owns pages [i * pages_per_seq, (i + 1) * pages_per_seq).
  const size_t pages_per_seq = (seqlen + page_size - 1) / page_size;
  const size_t num_pages = pages_per_seq * batch_size;
  // Tokens stored in each request's final page, in [1, page_size].
  const int32_t last_page_len = static_cast<int32_t>((seqlen - 1) % page_size + 1);
  std::vector<int32_t> kv_indptr_host{0};
  std::vector<int32_t> kv_indices_host;
  std::vector<int32_t> kv_last_page_len_host;
  // One query row per request, so the query offsets are simply 0, 1, ..., batch_size.
  std::vector<int32_t> qo_indptr_h{0};
  kv_indptr_host.reserve(batch_size + 1);
  kv_indices_host.reserve(num_pages);
  kv_last_page_len_host.reserve(batch_size);
  qo_indptr_h.reserve(batch_size + 1);
  for (size_t i = 0; i < batch_size; ++i) {
    for (size_t p = 0; p < pages_per_seq; ++p) {
      kv_indices_host.push_back(static_cast<int32_t>(i * pages_per_seq + p));
    }
    kv_indptr_host.push_back(kv_indptr_host.back() + static_cast<int32_t>(pages_per_seq));
    kv_last_page_len_host.push_back(last_page_len);
    qo_indptr_h.push_back(qo_indptr_h.back() + 1);
  }
  thrust::device_vector<TKV> k_data(num_pages * num_kv_heads * page_size * head_dim);
  thrust::device_vector<TKV> v_data(num_pages * num_kv_heads * page_size * head_dim);
  thrust::device_vector<int32_t> kv_indptr(kv_indptr_host);
  thrust::device_vector<int32_t> kv_indices(kv_indices_host);
  thrust::device_vector<int32_t> kv_last_page_len(kv_last_page_len_host);
  paged_kv_t<TKV, int32_t> paged_kv(
      num_kv_heads, page_size, head_dim, batch_size, kv_layout,
      thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()),
      thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()),
      thrust::raw_pointer_cast(kv_last_page_len.data()));

  // Query / output: one row per request.
  thrust::device_vector<T> q(batch_size * num_qo_heads * head_dim);
  thrust::device_vector<T> o(batch_size * num_qo_heads * head_dim);
  thrust::device_vector<int32_t> qo_indptr_d(qo_indptr_h);

  // Reported traffic: q, the whole K and V caches, the page-table arrays and the query
  // offsets (passed to the kernel below) are read; o is written.
  state.add_global_memory_reads<uint8_t>(vec_bytes(q) + vec_bytes(k_data) + vec_bytes(v_data) +
                                             vec_bytes(kv_indptr) + vec_bytes(kv_indices) +
                                             vec_bytes(kv_last_page_len) + vec_bytes(qo_indptr_d),
                                         "Read");
  state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");

  BatchPrefillHandler handler;
  size_t float_workspace_size_in_bytes = 128 * 1024 * 1024;
  thrust::device_vector<char> float_buffer(float_workspace_size_in_bytes);
  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
  thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);
  // Plan once, outside the timed region; skip the state rather than timing a failed plan.
  const cudaError_t plan_status = handler.Plan<T, int32_t>(
      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
      qo_indptr_h.data(), kv_indptr_host.data(), /*total_num_rows=*/batch_size, batch_size,
      num_qo_heads, num_kv_heads, head_dim, page_size);
  if (plan_status != cudaSuccess) {
    state.skip("BatchPrefillHandler::Plan failed: " +
               std::string(cudaGetErrorString(plan_status)));
    return;
  }

  state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
    cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, TKV, T, int32_t>(
        &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()),
        /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()),
        /*lse=*/nullptr, num_qo_heads,
        /*causal=*/false, pos_encoding_mode);
    // Previously discarded; report failures the same way the decode benchmark does.
    if (status != cudaSuccess) {
      state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
    }
  });
}
// Two-level stringification so macro arguments are expanded before being turned into
// string literals.
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

// Parameter sweep shared by both benchmarks:
//   seqlen       32 .. 65536 (powers of two)
//   batch_size   1 .. 1024 (powers of two)
//   page_size    16
//   num_qo_heads 32
//   num_kv_heads 32 (equal to num_qo_heads) and 4 (8 query heads per KV head)
// Expands to a chain of member calls, so it must follow a benchmark-builder expression.
#define FLASHINFER_BATCH_DECODE_AXES                                                       \
  .add_int64_axis("seqlen",                                                                \
                  {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536})   \
      .add_int64_axis("batch_size", {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024})         \
      .add_int64_axis("page_size", {16})                                                   \
      .add_int64_axis("num_qo_heads", {32})                                                \
      .add_int64_axis("num_kv_heads", {32, 4})

// Registers bench_flashinfer_batch_decode<dtype, dtypeKV> with nvbench.
// The generated variable name includes BOTH dtypes so that the same query dtype can be
// instantiated with several KV dtypes without redefining the variable. The run-time
// benchmark name is unchanged (dtype and dtypeKV are concatenated without a separator).
#define BENCH_FLASHINFER_BATCH_DECODE(dtype, dtypeKV)                                      \
  auto bench_flashinfer_batch_decode_##dtype##_##dtypeKV =                                  \
      bench_flashinfer_batch_decode<dtype, dtypeKV>;                                        \
  NVBENCH_BENCH(bench_flashinfer_batch_decode_##dtype##_##dtypeKV)                          \
      .set_name("bench_flashinfer_batch_decode_" STR(dtype) STR(dtypeKV))                   \
          FLASHINFER_BATCH_DECODE_AXES

// Registers bench_flashinfer_batch_decode_with_prefill<dtype, dtypeKV> with nvbench,
// using the same sweep and naming scheme as BENCH_FLASHINFER_BATCH_DECODE.
#define BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(dtype, dtypeKV)                         \
  auto bench_flashinfer_batch_decode_with_prefill_##dtype##_##dtypeKV =                     \
      bench_flashinfer_batch_decode_with_prefill<dtype, dtypeKV>;                           \
  NVBENCH_BENCH(bench_flashinfer_batch_decode_with_prefill_##dtype##_##dtypeKV)             \
      .set_name("bench_flashinfer_batch_decode_with_prefill_" STR(dtype) STR(dtypeKV))      \
          FLASHINFER_BATCH_DECODE_AXES

BENCH_FLASHINFER_BATCH_DECODE(half, half);
BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(half, half);