block_multihead_attention support V100 GQA #68104

Merged: 9 commits merged on Sep 11, 2024
50 changes: 32 additions & 18 deletions paddle/phi/kernels/fusion/gpu/block_attn.h
@@ -4038,10 +4038,10 @@ __global__ void fusedQKV_transpose_split_kernel(T *q_buf,
const int seq_len,
const int pre_cache_length,
const int token_num,
const int head_num,
const int q_head_num,
const int kv_head_num,
const int size_per_head) {
const int32_t hidden_size = head_num * size_per_head;
const int32_t fused_hidden_size = 3 * hidden_size;
const int fused_hidden_size = (q_head_num + 2 * kv_head_num) * size_per_head;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
@@ -4059,22 +4059,33 @@ __global__ void fusedQKV_transpose_split_kernel(T *q_buf,
if (seq_lens[target_batch_id] == 0) continue;
const int32_t seq_id = ori_token_idx % seq_len;

const int32_t qkv_id = bias_idx / hidden_size;
const int32_t head_id = (linear_index % hidden_size) / size_per_head;
const int32_t head_id = bias_idx / size_per_head;
const int32_t size_id = linear_index % size_per_head;

const int tmp_max_len_this_time =
max_len_this_time + (qkv_id == 0 ? 0 : pre_cache_length);
const int tmp_seq_id = qkv_id == 0 ? seq_id : seq_id + pre_cache_length;
const int write_idx =
target_batch_id * head_num * tmp_max_len_this_time * size_per_head +
head_id * tmp_max_len_this_time * size_per_head +
tmp_seq_id * size_per_head + size_id;
if (qkv_id == 0) {
max_len_this_time + (head_id < q_head_num ? 0 : pre_cache_length);
const int tmp_seq_id =
head_id < q_head_num ? seq_id : seq_id + pre_cache_length;

if (head_id < q_head_num) {
const int write_idx =
target_batch_id * q_head_num * tmp_max_len_this_time * size_per_head +
head_id * tmp_max_len_this_time * size_per_head +
tmp_seq_id * size_per_head + size_id;
phi::Store<T, VecSize>(src_vec, &q_buf[write_idx]);
} else if (qkv_id == 1) {
} else if (head_id < q_head_num + kv_head_num) {
const int write_idx =
target_batch_id * kv_head_num * tmp_max_len_this_time *
size_per_head +
(head_id - q_head_num) * tmp_max_len_this_time * size_per_head +
tmp_seq_id * size_per_head + size_id;
phi::Store<T, VecSize>(src_vec, &k_buf[write_idx]);
} else {
const int write_idx = target_batch_id * kv_head_num *
tmp_max_len_this_time * size_per_head +
(head_id - q_head_num - kv_head_num) *
tmp_max_len_this_time * size_per_head +
tmp_seq_id * size_per_head + size_id;
phi::Store<T, VecSize>(src_vec, &v_buf[write_idx]);
}
}
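
For readers tracing the index math in the rewritten kernel: with the fused layout of `(q_head_num + 2 * kv_head_num) * size_per_head` values per token, a flat `head_id` below `q_head_num` is written to the query buffer, the next `kv_head_num` heads go to the key buffer, and the remainder to the value buffer. Below is a minimal standalone Python sketch of that routing rule; the helper name and head counts are illustrative and not part of this PR.

```python
# Illustrative sketch (not PR code): route a flat head index from the fused
# QKV layout (q_head_num + 2 * kv_head_num heads per token) to its target
# buffer, mirroring the branch structure in fusedQKV_transpose_split_kernel.
def route_head(head_id: int, q_head_num: int, kv_head_num: int):
    if head_id < q_head_num:
        return "q_buf", head_id                             # query head
    elif head_id < q_head_num + kv_head_num:
        return "k_buf", head_id - q_head_num                # key head
    else:
        return "v_buf", head_id - q_head_num - kv_head_num  # value head


if __name__ == "__main__":
    # Hypothetical GQA configuration: 8 query heads sharing 2 KV heads.
    q_head_num, kv_head_num = 8, 2
    for head_id in range(q_head_num + 2 * kv_head_num):
        print(head_id, route_head(head_id, q_head_num, kv_head_num))
```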
@@ -4093,12 +4104,14 @@ void qkv_transpose_split(
const int *seq_lens,
const int token_num,
const int batch_size,
const int head_num,
const int q_head_num,
const int kv_head_num,
const int max_len_this_time,
const int seq_len,
const int pre_cache_length,
const int size_per_head) {
int32_t elem_cnt = token_num * head_num * size_per_head * 3;
int32_t elem_cnt = token_num * (q_head_num + kv_head_num * 2) * size_per_head;

constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(size_per_head % PackSize,
0,
@@ -4123,11 +4136,12 @@ void qkv_transpose_split(
seq_len,
pre_cache_length,
token_num,
head_num,
q_head_num,
kv_head_num,
size_per_head);
if (pre_key_cache) {
// stage 2: write pre_cache to kv_buf
elem_cnt = batch_size * head_num * pre_cache_length * size_per_head * 2;
elem_cnt = batch_size * q_head_num * pre_cache_length * size_per_head * 2;
pack_num = elem_cnt / PackSize;
GetNumBlocks(pack_num, &grid_size);
write_pre_cahe_to_kv_buffer<T, PackSize>
@@ -4138,7 +4152,7 @@ void qkv_transpose_split(
seq_lens,
batch_size,
pre_cache_length,
head_num,
kv_head_num,
size_per_head,
max_len_this_time,
elem_cnt);
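
As a quick sanity check on the new sizing (values below are hypothetical), the updated `elem_cnt` of `token_num * (q_head_num + 2 * kv_head_num) * size_per_head` reduces to the previous `token_num * head_num * size_per_head * 3` whenever the query and KV head counts coincide, i.e. the plain MHA case, and shrinks only the K/V portion under GQA:

```python
# Hypothetical sizes; only the arithmetic is of interest here.
token_num, size_per_head = 256, 128

# MHA: every query head has its own KV head, so old and new counts agree.
q_head_num = kv_head_num = 32
old_elem_cnt = token_num * q_head_num * size_per_head * 3
new_elem_cnt = token_num * (q_head_num + 2 * kv_head_num) * size_per_head
assert old_elem_cnt == new_elem_cnt

# GQA: fewer KV heads shrink only the K/V share of the fused buffer.
kv_head_num = 4
gqa_elem_cnt = token_num * (q_head_num + 2 * kv_head_num) * size_per_head
print(old_elem_cnt, gqa_elem_cnt)  # 3145728 1310720
```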
@@ -570,12 +570,7 @@ void DispatchWithDtype(
// Reshape fmha_buf back (to 2-D), to not affect following codes.
fmha_buf.Resize(fmha_shape);
} else {
// NOTE: not support gqa
if (sm < 80 && !use_pre_cache) {
if (q_num_head != kv_num_head) {
PADDLE_THROW(common::errors::Unimplemented(
"Only supported MHA on Volta/Turing(sm < 80) now."));
}
qkv_transpose_split<T>(
dev_ctx,
q_trans.data<T>(),
@@ -589,6 +584,7 @@ void DispatchWithDtype(
token_num,
bsz,
q_num_head,
kv_num_head,
max_enc_len_this_time_data,
max_seq_len,
pre_cache_length,
@@ -607,6 +603,7 @@ void DispatchWithDtype(
token_num,
bsz,
q_num_head,
kv_num_head,
max_enc_len_this_time_data,
max_seq_len,
pre_cache_length,
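
For context on the extra `kv_num_head` argument now threaded through `DispatchWithDtype`: in grouped-query attention `q_num_head` is a multiple of `kv_num_head`, and consecutive query heads share one KV head. The sketch below illustrates that grouping; the head counts are made up and the mapping shown is the conventional GQA scheme rather than code taken from this PR.

```python
# Illustrative GQA grouping: consecutive query heads share a single KV head.
q_num_head, kv_num_head = 8, 2
assert q_num_head % kv_num_head == 0
group_size = q_num_head // kv_num_head  # query heads per KV head

for q_head in range(q_num_head):
    kv_head = q_head // group_size
    print(f"query head {q_head} -> KV head {kv_head}")
```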
16 changes: 1 addition & 15 deletions test/legacy_test/test_block_multihead_attention_gqa.py
@@ -21,6 +21,7 @@
create_attn_mask,
get_cuda_version,
get_padding_offset,
is_sm_supported,
remove_padding,
)

@@ -34,21 +35,6 @@
np.random.seed(2024)


is_sm8x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 8
and paddle.device.cuda.get_device_capability()[1] >= 0
)

is_sm9x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 9
and paddle.device.cuda.get_device_capability()[1] >= 0
)

is_sm_supported = is_sm9x or is_sm8x


def naive_attention_impl(
query,
key,
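
The per-file capability flags deleted above are now imported as `is_sm_supported` from the shared test helpers. Based on the removed inline definition, that flag presumably reduces to the check below; this is a reconstruction for reference (the import path for `core` is an assumption), not the verbatim utility code.

```python
# Reconstruction of the removed inline check, now provided by the shared
# is_sm_supported helper: true on Ampere (sm 8.x) or Hopper (sm 9.x) GPUs
# when Paddle is built with CUDA. The `core` import path is an assumption.
import paddle
from paddle.framework import core

is_sm8x = (
    core.is_compiled_with_cuda()
    and paddle.device.cuda.get_device_capability()[0] == 8
)
is_sm9x = (
    core.is_compiled_with_cuda()
    and paddle.device.cuda.get_device_capability()[0] == 9
)
is_sm_supported = is_sm9x or is_sm8x
```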