@@ -105,40 +105,71 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
105105}
106106
107107std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward (
108- torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
109- torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
110- unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale,
111- float rope_theta, bool return_lse) {
108+ torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
109+ std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
110+ torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
111+ torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap,
112+ float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
112113 CHECK_INPUT (q);
113- CHECK_INPUT (paged_kv_data);
114+ bool paged_kv_defined = paged_kv_cache.has_value ();
115+ if (paged_kv_defined) {
116+ CHECK_INPUT (paged_kv_cache.value ());
117+ } else {
118+ CHECK_INPUT (paged_k_cache.value ());
119+ CHECK_INPUT (paged_v_cache.value ());
120+ }
114121 CHECK_INPUT (paged_kv_indptr);
115122 CHECK_INPUT (paged_kv_indices);
116123 CHECK_INPUT (paged_kv_last_page_len);
117124 auto device = q.device ();
118- CHECK_EQ (paged_kv_data.device (), device);
125+ if (paged_kv_defined) {
126+ CHECK_EQ (paged_kv_cache->device (), device);
127+ } else {
128+ CHECK_EQ (paged_k_cache->device (), device);
129+ CHECK_EQ (paged_v_cache->device (), device);
130+ }
119131 CHECK_EQ (paged_kv_indices.device (), device);
120132 CHECK_EQ (paged_kv_indptr.device (), device);
121133 CHECK_EQ (paged_kv_last_page_len.device (), device);
122134 CHECK_DIM (3 , q); // (B, H_qo, D)
123135 CHECK_DIM (1 , paged_kv_last_page_len); // (B,)
124136 CHECK_DIM (1 , paged_kv_indptr); // (B+1,)
125137 CHECK_DIM (1 , paged_kv_indices); // (nnz,)
126- // (num_max_pages, 2, H_kv, page_size, head_dim) for HND
127- // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
128- CHECK_DIM (5 , paged_kv_data);
138+ if (paged_kv_defined) {
139+ // (num_max_pages, 2, H_kv, page_size, head_dim) for HND
140+ // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
141+ CHECK_DIM (5 , paged_kv_cache.value ());
142+ } else {
143+ // (num_max_pages, H_kv, page_size, head_dim) for HND
144+ // (num_max_pages, page_size, H_kv, head_dim) for NHD
145+ CHECK_DIM (4 , paged_k_cache.value ());
146+ CHECK_DIM (4 , paged_v_cache.value ());
147+ }
129148 int64_t batch_size = q.size (0 );
130149 int64_t num_qo_heads = q.size (1 );
131150 int64_t head_dim = q.size (2 );
132151 int64_t num_kv_heads, page_size;
133- if (kv_layout_ == QKVLayout::kHND ) {
134- num_kv_heads = paged_kv_data.size (2 );
135- page_size = paged_kv_data.size (3 );
152+ if (paged_kv_defined) {
153+ CHECK_EQ (paged_kv_cache->size (1 ), 2 );
154+ CHECK_EQ (paged_kv_cache->size (4 ), head_dim);
155+ if (kv_layout_ == QKVLayout::kHND ) {
156+ num_kv_heads = paged_kv_cache->size (2 );
157+ page_size = paged_kv_cache->size (3 );
158+ } else {
159+ page_size = paged_kv_cache->size (2 );
160+ num_kv_heads = paged_kv_cache->size (3 );
161+ }
136162 } else {
137- page_size = paged_kv_data.size (2 );
138- num_kv_heads = paged_kv_data.size (3 );
163+ CHECK_EQ (paged_k_cache->size (3 ), head_dim);
164+ CHECK_EQ (paged_v_cache->size (3 ), head_dim);
165+ if (kv_layout_ == QKVLayout::kHND ) {
166+ num_kv_heads = paged_k_cache->size (1 );
167+ page_size = paged_k_cache->size (2 );
168+ } else {
169+ page_size = paged_k_cache->size (1 );
170+ num_kv_heads = paged_k_cache->size (2 );
171+ }
139172 }
140- CHECK_EQ (paged_kv_data.size (1 ), 2 );
141- CHECK_EQ (paged_kv_data.size (4 ), head_dim);
142173 CHECK_GE (paged_kv_indptr.size (0 ), batch_size + 1 );
143174 CHECK_GE (paged_kv_last_page_len.size (0 ), batch_size);
144175 // TODO(Zihao): support dispatching to different data types
@@ -159,7 +190,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
159190 logits_soft_cap > 0 .f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone ;
160191
161192 auto q_scalar_type = q.scalar_type ();
162- auto kv_scalar_type = paged_kv_data.scalar_type ();
193+ auto kv_scalar_type =
194+ paged_kv_defined ? paged_kv_cache->scalar_type () : paged_k_cache->scalar_type ();
163195
164196 if (q_scalar_type == kv_scalar_type) {
165197 DISPATCH_PYTORCH_DTYPE_TO_CTYPE (q_scalar_type, qkv_type, [&] {
@@ -169,7 +201,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
169201 PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
170202 paged_kv_t <PageStorage::kIndices , qkv_type, int32_t > paged_kv (
171203 num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
172- static_cast <qkv_type*>(paged_kv_data.data_ptr ()),
204+ static_cast <qkv_type*>(paged_kv_cache.has_value () ? paged_kv_cache->data_ptr ()
205+ : nullptr ),
206+ static_cast <qkv_type*>(paged_k_cache.has_value () ? paged_k_cache->data_ptr ()
207+ : nullptr ),
208+ static_cast <qkv_type*>(paged_v_cache.has_value () ? paged_v_cache->data_ptr ()
209+ : nullptr ),
173210 static_cast <int32_t *>(paged_kv_indices.data_ptr ()),
174211 static_cast <int32_t *>(paged_kv_indptr.data_ptr ()),
175212 static_cast <int32_t *>(paged_kv_last_page_len.data_ptr ()));
@@ -197,7 +234,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
197234 PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
198235 paged_kv_t <PageStorage::kIndices , kv_type, int32_t > paged_kv (
199236 num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
200- static_cast <kv_type*>(paged_kv_data.data_ptr ()),
237+ static_cast <kv_type*>(paged_kv_cache.has_value () ? paged_kv_cache->data_ptr ()
238+ : nullptr ),
239+ static_cast <kv_type*>(paged_k_cache.has_value () ? paged_k_cache->data_ptr ()
240+ : nullptr ),
241+ static_cast <kv_type*>(paged_v_cache.has_value () ? paged_v_cache->data_ptr ()
242+ : nullptr ),
201243 static_cast <int32_t *>(paged_kv_indices.data_ptr ()),
202244 static_cast <int32_t *>(paged_kv_indptr.data_ptr ()),
203245 static_cast <int32_t *>(paged_kv_last_page_len.data_ptr ()));
0 commit comments