Commit 0372acc

feat: enable head_dim=256 for attention kernels (#132)
As mentioned in #130, the kernels for `head_dim=256` are not compiled by default. This PR exposes these attention kernels in the pip wheels and adds unit tests and benchmarks for `head_dim=256`.
1 parent a346b27 commit 0372acc

19 files changed: +415, -297 lines
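
To make the user-facing effect concrete, here is a minimal sketch of driving the paged-KV batch prefill wrapper with `head_dim=256` through the updated Python API. The `begin_forward` argument list (including the new trailing `head_dim`) follows the `python/flashinfer/prefill.py` changes below; the tensor shapes, the 16 MB workspace size, the `"NHD"` layout string, and the `forward` call are illustrative assumptions for this sketch rather than part of this commit.

```python
import torch
import flashinfer

num_qo_heads = num_kv_heads = 8
head_dim = 256  # the newly exposed head dimension
page_size, max_num_pages = 16, 64

# Workspace buffer size is an assumption for this sketch.
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")

qo_indptr = torch.tensor([0, 33, 44, 55, 66], dtype=torch.int32, device="cuda")
paged_kv_indptr = torch.tensor([0, 17, 29, 44, 48], dtype=torch.int32, device="cuda")
paged_kv_indices = torch.arange(48, dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([1, 7, 14, 4], dtype=torch.int32, device="cuda")

# head_dim is now a required argument of begin_forward (this commit).
wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
)

q = torch.randn(66, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_data = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)
o = wrapper.forward(q, kv_data, causal=True)
wrapper.end_forward()
```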

include/flashinfer/handler.cuh

Lines changed: 3 additions & 2 deletions
@@ -187,7 +187,8 @@ class BatchPrefillHandler {
 
   template <typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
-                           uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads) {
+                           uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                           uint32_t head_dim) {
     if (num_qo_heads % num_kv_heads != 0) {
       std::ostringstream err_msg;
       err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -197,7 +198,7 @@ class BatchPrefillHandler {
     uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
     std::vector<IdType> request_indices_h, tile_indices_h;
     std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) =
-        split_qo_indptr(qo_indptr, batch_size, gqa_group_size, stream_);
+        split_qo_indptr(qo_indptr, batch_size, gqa_group_size, head_dim, stream_);
     AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
     request_indices_ =
         allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_h.size(), 16);

include/flashinfer/permuted_smem.cuh

Lines changed: 6 additions & 8 deletions
@@ -63,27 +63,25 @@ struct smem_t {
   template <uint32_t step_size>
   static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset,
                                                                       uint32_t step_idx) {
+    static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size");
     if constexpr (step_size == 2) {
       return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8;
     } else if constexpr (step_size == 4) {
       return (offset ^ 0x4) + (step_idx % 2 == 1) * 8;
-    } else if constexpr (step_size % 8 == 0) {
-      return offset + step_size;
     } else {
-      // Note(Zihao): not implemented yet.
-      return 0;
+      // step_size % 8 == 0
+      return offset + step_size;
     }
   }
 
   template <uint32_t step_size, uint32_t row_stride>
   static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) {
+    static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size");
     if constexpr (step_size == 4) {
       return (offset ^ 0x4) + step_size * row_stride;
-    } else if constexpr (step_size % 8 == 0) {
-      return offset + step_size * row_stride;
     } else {
-      // NOTE(Zihao): not implemented yet.
-      return 0;
+      // step_size % 8 == 0
+      return offset + step_size * row_stride;
     }
   }
 
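
For readers skimming the diff, the offset arithmetic above can be mirrored in plain Python. This sketch only restates the branches shown in the diff (the standalone function packaging is illustrative) and highlights the design change: unsupported `step_size` values are now rejected at compile time by the `static_assert` instead of silently returning 0.

```python
def advance_offset_by_column(offset: int, step_idx: int, step_size: int) -> int:
    # Python mirror of smem_t::advance_offset_by_column after this commit.
    # The CUDA version enforces the same constraint with a static_assert.
    assert step_size == 2 or step_size == 4 or step_size % 8 == 0, "Unsupported step size"
    if step_size == 2:
        return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8
    elif step_size == 4:
        return (offset ^ 0x4) + (step_idx % 2 == 1) * 8
    else:
        # step_size % 8 == 0: a plain linear advance, as in the CUDA code.
        return offset + step_size
```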

include/flashinfer/prefill.cuh

Lines changed: 263 additions & 169 deletions
Large diffs are not rendered by default.

include/flashinfer/utils.cuh

Lines changed: 41 additions & 51 deletions
@@ -81,40 +81,49 @@
     __VA_ARGS__ \
   }
 
-#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
-  if (num_frags_x == 1) { \
-    constexpr size_t NUM_FRAGS_X = 1; \
-    __VA_ARGS__ \
-  } else if (num_frags_x == 2) { \
-    constexpr size_t NUM_FRAGS_X = 2; \
-    __VA_ARGS__ \
-  } else { \
-    std::cerr << "Unsupported num_frags_x: " << num_frags_x << std::endl; \
+#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
+  if (num_frags_x == 1) { \
+    constexpr size_t NUM_FRAGS_X = 1; \
+    __VA_ARGS__ \
+  } else if (num_frags_x == 2) { \
+    constexpr size_t NUM_FRAGS_X = 2; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported num_frags_x: " << num_frags_x; \
+    throw std::invalid_argument(err_msg.str()); \
   }
 
-#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
-  if (max_frags_z == 4) { \
-    constexpr size_t NUM_FRAGS_Z = 4; \
-    __VA_ARGS__ \
-  } else if (max_frags_z == 2) { \
-    constexpr size_t NUM_FRAGS_Z = 2; \
-    __VA_ARGS__ \
-  } else { \
-    std::cerr << "Unsupported max_frags_z: " << max_frags_z << std::endl; \
+#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
+  if (max_frags_z >= 4) { \
+    constexpr size_t NUM_FRAGS_Z = 4; \
+    __VA_ARGS__ \
+  } else if (max_frags_z >= 2) { \
+    constexpr size_t NUM_FRAGS_Z = 2; \
+    __VA_ARGS__ \
+  } else if (max_frags_z >= 1) { \
+    constexpr size_t NUM_FRAGS_Z = 1; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported max_frags_z: " << max_frags_z; \
+    throw std::invalid_argument(err_msg.str()); \
   }
 
-#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
-  if (group_size == 1) { \
-    constexpr size_t GROUP_SIZE = 1; \
-    __VA_ARGS__ \
-  } else if (group_size == 4) { \
-    constexpr size_t GROUP_SIZE = 4; \
-    __VA_ARGS__ \
-  } else if (group_size == 8) { \
-    constexpr size_t GROUP_SIZE = 8; \
-    __VA_ARGS__ \
-  } else { \
-    std::cerr << "Unsupported group_size: " << group_size << std::endl; \
+#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
+  if (group_size == 1) { \
+    constexpr size_t GROUP_SIZE = 1; \
+    __VA_ARGS__ \
+  } else if (group_size == 4) { \
+    constexpr size_t GROUP_SIZE = 4; \
+    __VA_ARGS__ \
+  } else if (group_size == 8) { \
+    constexpr size_t GROUP_SIZE = 8; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported group_size: " << group_size; \
+    throw std::invalid_argument(err_msg.str()); \
   }
 
 #define DISPATCH_CAUSAL(causal, CAUSAL, ...) \
@@ -169,25 +178,6 @@
     } \
   }
 
-#define DISPATCH_HEAD_DIM_PREFILL(head_dim, HEAD_DIM, ...) \
-  switch (head_dim) { \
-    case 64: { \
-      constexpr size_t HEAD_DIM = 64; \
-      __VA_ARGS__ \
-      break; \
-    } \
-    case 128: { \
-      constexpr size_t HEAD_DIM = 128; \
-      __VA_ARGS__ \
-      break; \
-    } \
-    default: { \
-      std::ostringstream err_msg; \
-      err_msg << "Unsupported head_dim: " << head_dim; \
-      throw std::invalid_argument(err_msg.str()); \
-    } \
-  }
-
 #define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \
   switch (rotary_mode) { \
     case RotaryMode::kNone: { \
@@ -222,7 +212,7 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
 
 template <typename IdType>
 std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(
-    IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size,
+    IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim,
     cudaStream_t stream = nullptr) {
   constexpr uint32_t num_warps = 4;
   std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices, tile_indices;
@@ -235,7 +225,7 @@ std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(
 
   const uint32_t total_q_len = qo_indptr_h[batch_size];
   const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
-  const uint32_t num_frags_x = avg_len_greater_than_64 ? 2 : 1;
+  const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1;
   const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16;
   uint32_t num_qo_tiles = 0;
 
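
The behavioral core of this file's change is the tile-size heuristic in `split_qo_indptr`: with `head_dim=256` the kernel now always uses the smaller query tile (`num_frags_x = 1`), presumably to stay within shared-memory and register budgets for the wider head. A small Python sketch of just that heuristic (the standalone function and example values are illustrative; the arithmetic follows the diff):

```python
def choose_num_frags_x(qo_indptr, gqa_group_size, head_dim, num_warps=4):
    # qo_indptr: host-side exclusive prefix sum of per-request query lengths.
    batch_size = len(qo_indptr) - 1
    total_q_len = qo_indptr[batch_size]
    avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size
    # After this commit: head_dim == 256 always takes the num_frags_x = 1 path.
    num_frags_x = 2 if (head_dim < 256 and avg_len_greater_than_64) else 1
    num_rows_per_cta = num_frags_x * num_warps * 16
    return num_frags_x, num_rows_per_cta

choose_num_frags_x([0, 100, 300], gqa_group_size=1, head_dim=128)  # -> (2, 128)
choose_num_frags_x([0, 100, 300], gqa_group_size=1, head_dim=256)  # -> (1, 64)
```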

include/flashinfer/wrapper.cuh

Lines changed: 3 additions & 3 deletions
@@ -122,7 +122,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapper(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
     bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
@@ -142,8 +142,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
         return BatchPrefillWithPagedKVCacheWrapperDispatched<
             page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
             ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-            handler, q, qo_indptr, paged_kv, o, lse, rope_scale, rope_theta,
-            stream);
+            handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale,
+            rope_theta, stream);
       })})})})});
   return cudaSuccess;
 }

python/csrc/batch_prefill.cu

Lines changed: 14 additions & 16 deletions
@@ -19,11 +19,9 @@
 
 using namespace flashinfer;
 
-void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
-                                                              torch::Tensor qo_indptr,
-                                                              unsigned int batch_size,
-                                                              unsigned int num_qo_heads,
-                                                              unsigned int num_kv_heads) {
+void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -37,9 +35,10 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_.SetCUDAStream(torch_current_stream);
 
-  cudaError_t status = handler_.BeginForward(
-      static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-      static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
+  cudaError_t status =
+      handler_.BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
+                            workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                            batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }
@@ -140,11 +139,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   }
 }
 
-void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
-                                                               torch::Tensor qo_indptr,
-                                                               unsigned int batch_size,
-                                                               unsigned int num_qo_heads,
-                                                               unsigned int num_kv_heads) {
+void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -158,9 +155,10 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_.SetCUDAStream(torch_current_stream);
 
-  cudaError_t status = handler_.BeginForward(
-      static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-      static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
+  cudaError_t status =
+      handler_.BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
+                            workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                            batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }

python/csrc/flashinfer_ops.h

Lines changed: 4 additions & 2 deletions
@@ -79,7 +79,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
     return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
   }
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+                    unsigned int head_dim);
   void EndForward();
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -101,7 +102,8 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
     return BatchPrefillWithRaggedKVCachePyTorchWrapper(layout);
  }
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+                    unsigned int head_dim);
   void EndForward();
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
                                      torch::Tensor v, torch::Tensor kv_indptr, bool causal,

python/flashinfer/cascade.py

Lines changed: 6 additions & 1 deletion
@@ -578,7 +578,8 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     ...     paged_kv_indices,
     ...     paged_kv_last_page_len,
     ...     num_qo_heads,
-    ...     num_kv_heads
+    ...     num_kv_heads,
+    ...     head_dim,
     ... )
     >>> outputs = []
     >>> for i in range(num_layers):
@@ -641,6 +642,7 @@ def begin_forward(
         paged_kv_last_page_len: torch.Tensor,
         num_qo_heads: int,
         num_kv_heads: int,
+        head_dim: int,
     ):
         r"""Create auxiliary data structures for shared-prefix batch prefill/append
         attention for multiple forward calls within the same prefill/append step.
@@ -660,6 +662,8 @@ def begin_forward(
             The number of query/output heads.
         num_kv_heads : int
            The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.
 
         Notes
         -----
@@ -679,6 +683,7 @@ def begin_forward(
             paged_kv_last_page_len,
             num_qo_heads,
             num_kv_heads,
+            head_dim,
         )
 
     def end_forward(self):

python/flashinfer/prefill.py

Lines changed: 22 additions & 4 deletions
@@ -298,7 +298,8 @@ class BatchPrefillWithPagedKVCacheWrapper:
     ...     paged_kv_indices,
     ...     paged_kv_last_page_len,
     ...     num_qo_heads,
-    ...     num_kv_heads
+    ...     num_kv_heads,
+    ...     head_dim
     ... )
     >>> outputs = []
     >>> for i in range(num_layers):
@@ -365,6 +366,7 @@ def begin_forward(
         paged_kv_last_page_len: torch.Tensor,
         num_qo_heads: int,
         num_kv_heads: int,
+        head_dim: int,
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
@@ -384,6 +386,8 @@ def begin_forward(
             The number of query/output heads.
         num_kv_heads : int
            The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.
 
         Notes
         -----
@@ -401,7 +405,12 @@ def begin_forward(
         self._paged_kv_indices = paged_kv_indices
         self._paged_kv_last_page_len = paged_kv_last_page_len
         self._wrapper.begin_forward(
-            self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+            self._workspace_buffer,
+            qo_indptr,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
         )
 
     def end_forward(self):
@@ -571,7 +580,8 @@ class BatchPrefillWithRaggedKVCacheWrapper:
     ...     qo_indptr,
     ...     kv_indptr,
     ...     num_qo_heads,
-    ...     num_kv_heads
+    ...     num_kv_heads,
+    ...     head_dim
     ... )
     >>> outputs = []
     >>> for i in range(num_layers):
@@ -635,6 +645,7 @@ def begin_forward(
         kv_indptr: torch.Tensor,
         num_qo_heads: int,
         num_kv_heads: int,
+        head_dim: int,
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
@@ -649,6 +660,8 @@ def begin_forward(
             The number of query/output heads.
         num_kv_heads : int
            The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.
 
         Notes
         -----
@@ -664,7 +677,12 @@ def begin_forward(
         self._qo_indptr = qo_indptr
         self._kv_indptr = kv_indptr
         self._wrapper.begin_forward(
-            self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+            self._workspace_buffer,
+            qo_indptr,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
        )
 
     def end_forward(self):
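
Analogously to the paged-KV example near the top, here is a minimal sketch of the ragged-KV wrapper with the new `head_dim` argument. The `begin_forward` argument order follows the docstring above; the workspace size, the `"NHD"` layout string, and the `forward` call are assumptions of this sketch.

```python
import torch
import flashinfer

num_qo_heads = num_kv_heads = 8
head_dim = 256
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")

qo_indptr = torch.tensor([0, 33, 44, 55, 66], dtype=torch.int32, device="cuda")
kv_indptr = qo_indptr.clone()

# head_dim is the new trailing argument introduced by this commit.
wrapper.begin_forward(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)

nnz = 66  # total number of query/key rows across the batch
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
o = wrapper.forward(q, k, v, causal=True)
wrapper.end_forward()
```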
