@@ -975,26 +975,6 @@ static __global__ void flash_attn_tile(
975975 }
976976 }
977977
978- if (gridDim .y == 1 ) {
979- #pragma unroll
980- for (int jc0 = 0 ; jc0 < cpw; ++jc0) {
981- #ifdef FAST_FP16_AVAILABLE
982- const half2 KQ_sum_jc_inv = make_half2 (1 .0f /KQ_sum[jc0], 1 .0f /KQ_sum[jc0]);
983- #pragma unroll
984- for (int i = 0 ; i < (DVp/2 )/warp_size; ++i) {
985- VKQ[jc0*((DVp/2 )/warp_size) + i] *= KQ_sum_jc_inv;
986- }
987- #else
988- const float KQ_sum_jc_inv = 1 .0f /KQ_sum[jc0];
989- #pragma unroll
990- for (int i = 0 ; i < (DVp/2 )/warp_size; ++i) {
991- VKQ[jc0*((DVp/2 )/warp_size) + i].x *= KQ_sum_jc_inv;
992- VKQ[jc0*((DVp/2 )/warp_size) + i].y *= KQ_sum_jc_inv;
993- }
994- #endif // FAST_FP16_AVAILABLE
995- }
996- }
997-
998978 // Write back results:
999979#pragma unroll
1000980 for (int jc0 = 0 ; jc0 < cpw; ++jc0) {
@@ -1007,6 +987,8 @@ static __global__ void flash_attn_tile(
1007987 return ;
1008988 }
1009989
990+ const float scale = gridDim .y == 1 ? 1 .0f /KQ_sum[jc0] : 1 .0f ;
991+
1010992 const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim .y + blockIdx .y ;
1011993
1012994#ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +999,8 @@ static __global__ void flash_attn_tile(
1017999#pragma unroll
10181000 for (int i1 = 0 ; i1 < cpy_ne_D; ++i1) {
10191001 tmp[i1] = __half22float2 (VKQ[jc0*((DVp/2 )/warp_size) + i0/warp_size + i1]);
1002+ tmp[i1].x *= scale;
1003+ tmp[i1].y *= scale;
10201004 }
10211005 if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx .x *cpy_ne_D < DV/2 ) {
10221006 ggml_cuda_memcpy_1<sizeof (tmp)>(&dst[j_dst_unrolled*DV + 2 *i0 + threadIdx .x *(2 *cpy_ne_D)], tmp);
@@ -1027,6 +1011,11 @@ static __global__ void flash_attn_tile(
10271011#pragma unroll
10281012 for (int i0 = 0 ; i0 < DVp; i0 += warp_size*cpy_ne_D) {
10291013 if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx .x *cpy_ne_D < DV) {
1014+ #pragma unroll
1015+ for (int i1 = 0 ; i1 < cpy_ne_D/2 ; ++i1) {
1016+ VKQ[jc0*((DVp/2 )/warp_size) + i0/(2 *warp_size) + i1].x *= scale;
1017+ VKQ[jc0*((DVp/2 )/warp_size) + i0/(2 *warp_size) + i1].y *= scale;
1018+ }
10301019 ggml_cuda_memcpy_1<cpy_ne_D*4 >(
10311020 &dst[j_dst_unrolled*DV + i0 + threadIdx .x *cpy_ne_D],
10321021 &VKQ[jc0*((DVp/2 )/warp_size) + i0/(2 *warp_size)]);
0 commit comments