
Commit d4146fb

bugfix: fix the compilation issue of pip wheels (#115)
This PR fixes #113: #69 changed the `BatchPrefillWithPagedKVCacheWrapperDispatched` signature, but `flashinfer_decl.h` was not updated accordingly, which broke compilation of the pip wheels. It also fixes some minor formatting issues from #111.
1 parent 1306d11 commit d4146fb
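
For context, this is the classic explicit-instantiation pitfall: an explicit template instantiation must repeat the template's current parameter list exactly, so when a dispatcher gains parameters, every instantiation declaration must be updated in lock-step. A minimal sketch with hypothetical names (not FlashInfer's actual declarations):

```cpp
// Minimal sketch of the failure mode (hypothetical names).
#include <cstdint>

template <typename T>
int DispatchSketch(T* q, int32_t* qo_indptr, int32_t* q_rope_position) {
  // Signature after a refactor added q_rope_position.
  return q_rope_position == nullptr ? 0 : 1;
}

// Stale explicit instantiation, analogous to the pre-fix flashinfer_decl.h,
// no longer matches any template and stops the build:
//   template int DispatchSketch<float>(float* q, int32_t* qo_indptr);

// Updated instantiation mirroring the new parameter list compiles cleanly:
template int DispatchSketch<float>(float* q, int32_t* qo_indptr,
                                   int32_t* q_rope_position);
```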

16 files changed (+44, -30 lines)


include/flashinfer/wrapper.cuh

Lines changed: 3 additions & 2 deletions
```diff
@@ -207,8 +207,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
           return BatchPrefillWithRaggedKVCacheWrapperDispatched<
               GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
               ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-              handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
-              num_kv_heads, rope_scale, rope_theta, stream);
+              handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr,
+              /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads,
+              rope_scale, rope_theta, stream);
         })})})})})});
   return cudaSuccess;
 }
```
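
Note that this call site fills the two new parameters with nullptr. Presumably (an assumption based on the argument names, not verified against the kernel sources) a null position array means RoPE positions fall back to each query's offset within its own request, recoverable from the ragged indptr. A sketch of that nullptr-as-default convention, with QueryRopePosition as a hypothetical helper:

```cpp
// Hedged sketch of the assumed nullptr-as-default convention; this helper
// is hypothetical and not part of FlashInfer.
#include <cstdint>

inline int32_t QueryRopePosition(const int32_t* q_rope_position,
                                 const int32_t* qo_indptr,
                                 int32_t request_idx, int32_t flat_idx) {
  // No explicit positions supplied: use the query's offset in its request.
  if (q_rope_position == nullptr) return flat_idx - qo_indptr[request_idx];
  return q_rope_position[flat_idx];  // explicit per-query position
}
```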

python/csrc/batch_prefill.cu

Lines changed: 1 addition & 0 deletions
```diff
@@ -216,6 +216,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
             &handler_, static_cast<c_type*>(q.data_ptr()),
             static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<c_type*>(k.data_ptr()),
             static_cast<c_type*>(v.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+            /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr,
             static_cast<c_type*>(o.data_ptr()),
             /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
             num_kv_heads, rope_scale, rope_theta,
```

python/csrc/cascade.cu

Lines changed: 9 additions & 9 deletions
```diff
@@ -44,11 +44,11 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());

   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v_a.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        MergeState(static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
-                   static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
-                   static_cast<c_type*>(v_merged.data_ptr()),
-                   static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
+    cudaError_t status = MergeState(
+        static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
+        static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
+        static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
+        seq_len, num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeState kernel launch failed: ", cudaGetErrorString(status));
     return true;
@@ -80,10 +80,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();

   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        MergeStateInPlace(static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
-                          static_cast<c_type*>(v_other.data_ptr()),
-                          static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
+    cudaError_t status = MergeStateInPlace(
+        static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
+        static_cast<c_type*>(v_other.data_ptr()), static_cast<float*>(s_other.data_ptr()), seq_len,
+        num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
     return true;
```
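
The reformatted calls invoke FlashInfer's attention-state merge kernels. For reference, a scalar sketch of the numerically stable log-sum-exp merge such a kernel computes per position and head, assuming (v, s) is a value/log-sum-exp pair (the real kernels operate on whole [seq_len, num_heads, head_dim] tensors on the GPU):

```cpp
// Scalar sketch of merging two attention states (v, s); assumed semantics,
// not the actual CUDA kernel.
#include <cmath>

void MergeStateScalar(float v_a, float s_a, float v_b, float s_b,
                      float& v_merged, float& s_merged) {
  float s_max = fmaxf(s_a, s_b);  // subtract the max for numerical stability
  float w_a = expf(s_a - s_max);
  float w_b = expf(s_b - s_max);
  v_merged = (v_a * w_a + v_b * w_b) / (w_a + w_b);  // softmax-weighted values
  s_merged = s_max + logf(w_a + w_b);                // merged log-sum-exp
}
```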

python/csrc/flashinfer_decl.h

Lines changed: 13 additions & 12 deletions
```diff
@@ -24,18 +24,19 @@
   template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched<                            \
       PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, \
       CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr,            \
+                             int32_t* q_rope_position,                                           \
                              paged_kv_t<PageStorage::kIndices, LAYOUT, T, int32_t> paged_kv, T* o, \
                              float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \
   }

-#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \
-                                       LAYOUT, ROTARY_MODE)                                      \
-  namespace flashinfer {                                                                         \
-  template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched<                           \
-      GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \
-      BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr,   \
-      T* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, float rope_scale,            \
-      float rope_theta, cudaStream_t stream);                                                    \
+#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION,   \
+                                       LAYOUT, ROTARY_MODE)                                        \
+  namespace flashinfer {                                                                           \
+  template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched<                             \
+      GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>(  \
+      BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr,     \
+      int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \
+      uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream);             \
   }

 #define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \
@@ -56,15 +57,15 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMode
           typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_kv_heads, const float rope_scale, const float rope_theta,
-    cudaStream_t stream);
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale,
+    const float rope_theta, cudaStream_t stream);

 template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
           RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
           typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     float rope_scale, float rope_theta, cudaStream_t stream);
```
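
Because these instantiations are stamped out by macros for every compiled configuration, one signature drift multiplies into errors across all generated translation units, which is exactly how the wheel build broke. A self-contained sketch of the pattern (hypothetical names):

```cpp
// Sketch of macro-driven explicit instantiation (hypothetical names): the
// macro body must track the template signature, or every expansion fails.
#include <cstdint>

template <int HEAD_DIM, typename T>
int PrefillSketch(T* q, int32_t* qo_indptr, int32_t* q_rope_position) {
  return q_rope_position == nullptr ? HEAD_DIM : -HEAD_DIM;
}

#define INST_PREFILL_SKETCH(T, HEAD_DIM)   \
  template int PrefillSketch<HEAD_DIM, T>( \
      T* q, int32_t* qo_indptr, int32_t* q_rope_position);

INST_PREFILL_SKETCH(float, 128)
INST_PREFILL_SKETCH(double, 64)
```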

python/csrc/page.cu

Lines changed: 4 additions & 3 deletions
```diff
@@ -73,9 +73,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
         num_heads, page_size, head_dim, batch_size, static_cast<c_type*>(kv_data.data_ptr()),
         static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
         static_cast<int32_t*>(kv_last_page_len.data_ptr()));
-    cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
-                                            static_cast<c_type*>(append_value.data_ptr()),
-                                            static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
+    cudaError_t status =
+        AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
+                           static_cast<c_type*>(append_value.data_ptr()),
+                           static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
     return true;
```
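
This hunk only reflows the AppendPagedKVCache call, but the surrounding indptr/indices/last-page-len vocabulary is worth unpacking. A hedged CPU sketch of appending one new entry per request to a paged KV cache (single head, scalar entries, illustrative layout and names; not FlashInfer's actual kernel):

```cpp
// CPU sketch of a paged-KV append (illustrative; assumes each request's
// last page still has a free slot).
#include <cstdint>
#include <vector>

void AppendOneTokenSketch(std::vector<std::vector<float>>& pages,  // page id -> entries
                          const std::vector<int32_t>& kv_indices,  // flattened page ids
                          const std::vector<int32_t>& kv_indptr,   // request -> range in kv_indices
                          std::vector<int32_t>& kv_last_page_len,  // valid entries in last page
                          const std::vector<float>& new_keys) {    // one new key per request
  int32_t batch_size = static_cast<int32_t>(kv_indptr.size()) - 1;
  for (int32_t r = 0; r < batch_size; ++r) {
    int32_t last_page = kv_indices[kv_indptr[r + 1] - 1];  // request r's last page
    pages[last_page].push_back(new_keys[r]);               // write the new entry
    ++kv_last_page_len[r];                                 // one more valid slot
  }
}
```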

python/csrc/pytorch_extension_utils.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #pragma once
-#include <torch/extension.h>
 #include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>

 #include "generated/dispatch.inc"
```

python/flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 from .decode import (
     single_decode_with_kv_cache,
     batch_decode_with_padded_kv_cache,
```

python/flashinfer/cascade.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional
 import torch
```

python/flashinfer/decode.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional, Union
 import torch
@@ -477,9 +478,9 @@ def begin_forward(
         # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
         empty_data = torch.empty(
             0,
-            dtype=getattr(torch, data_type)
-            if isinstance(data_type, str)
-            else data_type,
+            dtype=(
+                getattr(torch, data_type) if isinstance(data_type, str) else data_type
+            ),
         )
         self._wrapper.begin_forward(
             self._workspace_buffer,
```

python/flashinfer/page.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import torch

 try:
```
