Skip to content

Commit d8d99d1

Browse files
CUDA: fix numerical issue in tile FA kernel
1 parent d4d465b commit d8d99d1

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -975,26 +975,6 @@ static __global__ void flash_attn_tile(
975975
}
976976
}
977977

978-
if (gridDim.y == 1) {
979-
#pragma unroll
980-
for (int jc0 = 0; jc0 < cpw; ++jc0) {
981-
#ifdef FAST_FP16_AVAILABLE
982-
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
983-
#pragma unroll
984-
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
985-
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
986-
}
987-
#else
988-
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
989-
#pragma unroll
990-
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
991-
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
992-
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
993-
}
994-
#endif // FAST_FP16_AVAILABLE
995-
}
996-
}
997-
998978
// Write back results:
999979
#pragma unroll
1000980
for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -1007,6 +987,8 @@ static __global__ void flash_attn_tile(
1007987
return;
1008988
}
1009989

990+
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
991+
1010992
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
1011993

1012994
#ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +999,8 @@ static __global__ void flash_attn_tile(
1017999
#pragma unroll
10181000
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
10191001
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
1002+
tmp[i1].x *= scale;
1003+
tmp[i1].y *= scale;
10201004
}
10211005
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
10221006
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1027,6 +1011,11 @@ static __global__ void flash_attn_tile(
10271011
#pragma unroll
10281012
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
10291013
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
1014+
#pragma unroll
1015+
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
1016+
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
1017+
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
1018+
}
10301019
ggml_cuda_memcpy_1<cpy_ne_D*4>(
10311020
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
10321021
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);

0 commit comments

Comments
 (0)