
Commit

fix a problem
yuguo-Jack committed Jun 12, 2024
1 parent 15e6938 commit 0e35e8f
Showing 1 changed file with 67 additions and 33 deletions.
100 changes: 67 additions & 33 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -256,32 +256,23 @@ void FlashAttnUnpaddedGradBaseKernel(
}

#ifdef PADDLE_WITH_HIP
std::initializer_list<int64_t> dk_dv_input_shape = {
total_k, num_heads, head_size};
#endif

std::initializer_list<int64_t> dk_dv_shape = {total_k, num_heads, head_size};
#else
std::initializer_list<int64_t> dk_dv_shape = {
total_k, num_heads_k, num_heads / num_heads_k, head_size};
#endif

DenseTensor *kdk = dk, *kdv = dv;
DenseTensor dk_tmp;
if (!dk || !is_mha) {
#ifdef PADDLE_WITH_HIP
dk_tmp.Resize(dk_dv_input_shape);
#else
dk_tmp.Resize(dk_dv_shape);
#endif
ctx.template Alloc<T>(&dk_tmp);
kdk = &dk_tmp;
}

DenseTensor dv_tmp;
if (!dv || !is_mha) {
#ifdef PADDLE_WITH_HIP
dv_tmp.Resize(dk_dv_input_shape);
#else
dv_tmp.Resize(dk_dv_shape);
#endif
ctx.template Alloc<T>(&dv_tmp);
kdv = &dv_tmp;
}
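Note on the allocation above: on the ROCm/HIP path the flash-attention backward writes dK/dV with one slice per query head, so for grouped-query attention (num_heads != num_heads_k) a temporary in that flat layout is allocated and only reshaped when the reduction runs; the CUDA path instead allocates the temporary already split into num_heads_k groups. A minimal sketch of that shape choice follows; the helper name pick_tmp_grad_shape is invented for illustration and is not a Paddle API.

#include <cstdint>
#include <vector>

// Illustrative only: mirrors the shape logic above for the unpadded kernel.
std::vector<int64_t> pick_tmp_grad_shape(int64_t total_k,
                                         int64_t num_heads,
                                         int64_t num_heads_k,
                                         int64_t head_size,
                                         bool hip_path) {
  if (hip_path) {
    // HIP: keep the per-query-head layout the backward kernel produces.
    return {total_k, num_heads, head_size};
  }
  // CUDA: pre-split the query heads into groups of num_heads / num_heads_k
  // so a later sum over axis 2 collapses each group onto its KV head.
  return {total_k, num_heads_k, num_heads / num_heads_k, head_size};
}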
@@ -424,9 +415,19 @@ void FlashAttnUnpaddedGradBaseKernel(
#ifdef PADDLE_WITH_HIP
if (dk->meta().is_contiguous())
phi::SumKernel<T, Context>(
ctx, dk_tmp.Resize(dk_dv_shape), {2}, dk->type(), false, dk);
ctx,
dk_tmp.Resize(
{total_k, num_heads_k, num_heads / num_heads_k, head_size}),
{2},
dk->type(),
false,
dk);
else
kvReduceForGQA<T, Context>(ctx, dk_tmp.Resize(dk_dv_shape), dk);
kvReduceForGQA<T, Context>(
ctx,
dk_tmp.Resize(
{total_k, num_heads_k, num_heads / num_heads_k, head_size}),
dk);
#else
if (dk->meta().is_contiguous())
phi::SumKernel<T, Context>(ctx, dk_tmp, {2}, dk->type(), false, dk);
@@ -438,9 +439,19 @@ void FlashAttnUnpaddedGradBaseKernel(
#ifdef PADDLE_WITH_HIP
if (dv->meta().is_contiguous())
phi::SumKernel<T, Context>(
ctx, dv_tmp.Resize(dk_dv_shape), {2}, dv->type(), false, dv);
ctx,
dv_tmp.Resize(
{total_k, num_heads_k, num_heads / num_heads_k, head_size}),
{2},
dv->type(),
false,
dv);
else
kvReduceForGQA<T, Context>(ctx, dv_tmp.Resize(dk_dv_shape), dv);
kvReduceForGQA<T, Context>(
ctx,
dv_tmp.Resize(
{total_k, num_heads_k, num_heads / num_heads_k, head_size}),
dv);
#else
if (dv->meta().is_contiguous())
phi::SumKernel<T, Context>(ctx, dv_tmp, {2}, dv->type(), false, dv);
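What the HIP branch above does, in words: dk_tmp and dv_tmp hold one gradient slice per query head; viewing them as [total_k, num_heads_k, num_heads / num_heads_k, head_size] and summing over axis 2 accumulates the gradients of every query head in a group onto its shared KV head (the SumKernel path covers contiguous outputs, kvReduceForGQA the strided ones). A plain CPU reference of that reduction is sketched below; the names and the row-major layout are assumptions for illustration, not Paddle code.

#include <cstdint>
#include <vector>

// Reference GQA gradient reduction (illustrative, CPU, row-major buffers).
// grad_full: [total_k, num_heads, head_size] as written by the backward pass.
// grad_kv:   [total_k, num_heads_k, head_size], zero-initialized output.
void reduce_gqa_grad(const std::vector<float>& grad_full,
                     std::vector<float>* grad_kv,
                     int64_t total_k,
                     int64_t num_heads,
                     int64_t num_heads_k,
                     int64_t head_size) {
  const int64_t groups = num_heads / num_heads_k;  // query heads per KV head
  for (int64_t t = 0; t < total_k; ++t) {
    for (int64_t hk = 0; hk < num_heads_k; ++hk) {
      for (int64_t g = 0; g < groups; ++g) {
        const int64_t h = hk * groups + g;  // query-head index in grad_full
        for (int64_t d = 0; d < head_size; ++d) {
          (*grad_kv)[(t * num_heads_k + hk) * head_size + d] +=
              grad_full[(t * num_heads + h) * head_size + d];
        }
      }
    }
  }
}

Summing axis 2 of the reshaped tensor is exactly this loop, because the reshape does not move data: element (t, hk, g, d) of the 4-D view is element (t, hk * groups + g, d) of the original per-query-head layout.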
@@ -644,12 +655,13 @@ void FlashAttnGradBaseKernel(
bool is_mha = (num_heads == num_heads_k);

#ifdef PADDLE_WITH_HIP
std::initializer_list<int64_t> dk_dv_input_shape = {
std::initializer_list<int64_t> dk_dv_shape = {
batch_size, seqlen_k, num_heads, head_size};
#endif

#else
std::initializer_list<int64_t> dk_dv_shape = {
batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
#endif

DenseTensor* kdq = dq;
DenseTensor dq_tmp;
if (!dq) {
@@ -661,22 +673,14 @@
DenseTensor *kdk = dk, *kdv = dv;
DenseTensor dk_tmp;
if (!dk || !is_mha) {
#ifdef PADDLE_WITH_HIP
dk_tmp.Resize(dk_dv_input_shape);
#else
dk_tmp.Resize(dk_dv_shape);
#endif
ctx.template Alloc<T>(&dk_tmp);
kdk = &dk_tmp;
}

DenseTensor dv_tmp;
if (!dv || !is_mha) {
#ifdef PADDLE_WITH_HIP
dv_tmp.Resize(dk_dv_input_shape);
#else
dv_tmp.Resize(dk_dv_shape);
#endif
ctx.template Alloc<T>(&dv_tmp);
kdv = &dv_tmp;
}
@@ -832,10 +836,25 @@ void FlashAttnGradBaseKernel(
if (dk) {
#ifdef PADDLE_WITH_HIP
if (dk->meta().is_contiguous())
phi::SumKernel<T, Context>(
ctx, dk_tmp.Resize(dk_dv_shape), {3}, dk->type(), false, dk);
phi::SumKernel<T, Context>(ctx,
dk_tmp.Resize({batch_size,
seqlen_k,
num_heads_k,
num_heads / num_heads_k,
head_size}),
{3},
dk->type(),
false,
dk);
else
kvReduceBatchedForGQA<T, Context>(ctx, dk_tmp.Resize(dk_dv_shape), dk);
kvReduceBatchedForGQA<T, Context>(
ctx,
dk_tmp.Resize({batch_size,
seqlen_k,
num_heads_k,
num_heads / num_heads_k,
head_size}),
dk);
#else
if (dk->meta().is_contiguous())
phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
@@ -847,10 +866,25 @@
if (dv) {
#ifdef PADDLE_WITH_HIP
if (dv->meta().is_contiguous())
phi::SumKernel<T, Context>(
ctx, dv_tmp.Resize(dk_dv_shape), {3}, dv->type(), false, dv);
phi::SumKernel<T, Context>(ctx,
dv_tmp.Resize({batch_size,
seqlen_k,
num_heads_k,
num_heads / num_heads_k,
head_size}),
{3},
dv->type(),
false,
dv);
else
kvReduceBatchedForGQA<T, Context>(ctx, dv_tmp.Resize(dk_dv_shape), dv);
kvReduceBatchedForGQA<T, Context>(
ctx,
dv_tmp.Resize({batch_size,
seqlen_k,
num_heads_k,
num_heads / num_heads_k,
head_size}),
dv);
#else
if (dv->meta().is_contiguous())
phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
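The batched (padded) kernel repeats the same pattern with leading [batch_size, seqlen_k] dimensions, which is why the reduced axis moves from {2} to {3}. The small self-contained check below illustrates that contract; the dimensions and variable names are made up for the example and are not taken from Paddle.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t batch = 1, seqlen_k = 2, num_heads = 4, num_heads_k = 2,
                head_size = 3;
  const int64_t groups = num_heads / num_heads_k;
  // dK as the backward kernel writes it: one slice per *query* head.
  std::vector<float> dk_full(batch * seqlen_k * num_heads * head_size, 1.0f);
  // dK as the caller expects it: one slice per *KV* head.
  std::vector<float> dk(batch * seqlen_k * num_heads_k * head_size, 0.0f);
  // View dk_full as [batch, seqlen_k, num_heads_k, groups, head_size] and sum
  // over axis 3, mirroring the SumKernel call with axes {3} above.
  for (int64_t r = 0; r < batch * seqlen_k; ++r)
    for (int64_t hk = 0; hk < num_heads_k; ++hk)
      for (int64_t g = 0; g < groups; ++g)
        for (int64_t d = 0; d < head_size; ++d)
          dk[(r * num_heads_k + hk) * head_size + d] +=
              dk_full[((r * num_heads_k + hk) * groups + g) * head_size + d];
  // Each KV head ends up with the sum over its `groups` query heads.
  for (float v : dk) assert(v == static_cast<float>(groups));
  return 0;
}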
