@@ -8,14 +8,11 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
     if (GGML_CUDA_CC_IS_AMD(cc)) {
         switch (D) {
             case 64:
-                return 64;
+                return ncols <= 16 ? 32 : 64;
             case 128:
+                return ncols <= 16 ? 64 : warp_size;
             case 256:
-                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
-                    return ncols <= 16 ? 64 : 32;
-                } else {
-                    return 64;
-                }
+                return 64;
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -44,26 +41,17 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
             GGML_ABORT("fatal error");
             return -1;
     }
-    GGML_UNUSED(warp_size);
 }
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
     switch (D) {
         case 64:
-            return 64;
+            return ncols <= 16 ? 32 : 64;
         case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : warp_size;
         case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
-#else
             return 64;
-#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
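
Aside, not part of the patch: the host helper and the constexpr device helper have to stay in agreement, since the host value is what the launch code uses to pick grid/block shapes while the device value is presumably what sizes the kernel's compile-time shared-memory tiles, along the lines of this hypothetical sketch:

    // Hypothetical usage sketch; the kernel template and its parameters are
    // assumed here, only the helper name comes from the patch.
    template <int D, int ncols, int nwarps, int warp_size>
    __global__ void flash_attn_tile_sketch(/* ... */) {
        constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
        __shared__ float KQ[ncols][kq_stride]; // one tile of K*Q scores
        // ...
    }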
@@ -100,17 +88,9 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 2*warp_size : 128;
         case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return ncols <= 16 ? 64 : 256;
-#endif // defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 128 : 2*warp_size;
         default:
             return -1;
     }
@@ -216,21 +196,14 @@ static __global__ void flash_attn_tile(
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
-#if defined(GGML_USE_HIP)
-    constexpr int cpy_nb = 16;
-#else
-    constexpr int cpy_nb = 8;
-#endif // defined(GGML_USE_HIP) && defined(GCN)
-    constexpr int cpy_ne = cpy_nb / 4;
-
     __shared__ float KQ[ncols][kq_stride];
 #ifdef FAST_FP16_AVAILABLE
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
     half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
     float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
     float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
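
For readers unfamiliar with the "padded to avoid memory bank conflicts" comment: shared memory is split into 32 four-byte banks, and threads of a warp whose addresses map to the same bank serialize. Padding each row by one element (the `+ 1` here, `+ cpy_ne` before) keeps the row stride co-prime with the bank count, so column-wise accesses land in distinct banks. A minimal standalone sketch with assumed 32x32 tiles, not the kernel's actual layout:

    // Minimal demo of the padding trick (assumed sizes).
    __global__ void padding_demo() {
        __shared__ float tile_conflict[32][32];     // row stride 32 floats
        __shared__ float tile_padded  [32][32 + 1]; // row stride 33 floats
        tile_conflict[threadIdx.x][0] = 0.0f; // all 32 lanes hit bank 0: 32-way conflict
        tile_padded  [threadIdx.x][0] = 0.0f; // stride 33 spreads lanes over all banks
    }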
@@ -283,11 +256,11 @@ static __global__ void flash_attn_tile(
                 for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
                     const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                    KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                    KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
 #else
                     const float2 tmp_f2 = __half22float2(tmp_h2);
-                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
-                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
+                    KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
+                    KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
 #endif // FAST_FP16_AVAILABLE
                 }
             }
@@ -296,45 +269,42 @@ static __global__ void flash_attn_tile(
 
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
-            half2 K_k[kq_stride/warp_size][cpy_ne];
-            half2 Q_k[ncols/nwarps][cpy_ne];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
+            half2 K_k[kq_stride/warp_size];
+            half2 Q_k[ncols/nwarps];
 #else
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
-            float K_k[kq_stride/warp_size][cpy_ne];
-            float Q_k[ncols/nwarps][cpy_ne];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
+            float K_k[kq_stride/warp_size];
+            float Q_k[ncols/nwarps];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
                 const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
+                K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
 #else
-                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
+                K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1];
 #endif // FAST_FP16_AVAILABLE
             }
 #pragma unroll
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
+                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
 #else
-                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
+                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
 #endif // FAST_FP16_AVAILABLE
             }
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-#pragma unroll
-                    for (int k = 0; k < cpy_ne; ++k) {
-                        ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
-                    }
+                    ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
                 }
             }
         }
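
The single `ggml_cuda_mad` call above replaces the unrolled `cpy_ne`-wide inner loop with one multiply-accumulate per element. As a hedged approximation (the actual helper in llama.cpp's common.cuh has more overloads and architecture-specific paths), it behaves like:

    #include <cuda_fp16.h>

    // Sketch of the accumulation, not the actual definition of ggml_cuda_mad.
    static __device__ __forceinline__ void mad_sketch(float & acc, const float v, const float u) {
        acc += v*u;               // compiles to a fused multiply-add
    }
    static __device__ __forceinline__ void mad_sketch(half2 & acc, const half2 v, const half2 u) {
        acc = __hfma2(v, u, acc); // packed fp16 FMA: two values per instruction
    }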
@@ -375,54 +345,14 @@ static __global__ void flash_attn_tile(
         kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
 
         float kqsum_add = 0.0f;
-        if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
 #pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
-                const int i = i0 + 4*threadIdx.x;
+        for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+            const int i = i0 + threadIdx.x;
 
-                float4 val = *(const float4 *) &KQ[j][i];
-                val.x = expf(val.x - kqmax[j0/nwarps]);
-                val.y = expf(val.y - kqmax[j0/nwarps]);
-                val.z = expf(val.z - kqmax[j0/nwarps]);
-                val.w = expf(val.w - kqmax[j0/nwarps]);
-                kqsum_add += val.x + val.y + val.z + val.w;
-
-#ifdef FAST_FP16_AVAILABLE
-                const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
-                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
-#else
-                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
-#endif // FAST_FP16_AVAILABLE
-            }
-        } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
-#pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
-                const int i = i0 + 2*threadIdx.x;
-
-                float2 val = *(const float2 *) &KQ[j][i];
-                val.x = expf(val.x - kqmax[j0/nwarps]);
-                val.y = expf(val.y - kqmax[j0/nwarps]);
-                kqsum_add += val.x + val.y;
-#ifdef FAST_FP16_AVAILABLE
-                const half2 tmp = make_half2(val.x, val.y);
-                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
-#else
-                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
-#endif // FAST_FP16_AVAILABLE
-            }
-        } else {
-            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
-
-                const float diff = KQ[j][i] - kqmax[j0/nwarps];
-                const float val = expf(diff);
-                kqsum_add += val;
-#ifdef FAST_FP16_AVAILABLE
-                ((half *) KQ[j])[i] = val;
-#else
-                KQ[j][i] = val;
-#endif // FAST_FP16_AVAILABLE
-            }
+            const float diff = KQ[j][i] - kqmax[j0/nwarps];
+            const float val = expf(diff);
+            kqsum_add += val;
+            KQ[j][i] = val;
         }
         kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
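Context for the `kqsum` update above: the kernel evaluates softmax over K tiles in streaming fashion, so whenever a new tile raises the running row maximum, the sum accumulated under the old maximum has to be rescaled before the new tile's contribution is added. A worked sketch with hypothetical names:

    // Online softmax invariant: sum == sum over s of expf(s - m) for the
    // current running maximum m. Raising m_old to m_new multiplies every
    // stored term by expf(m_old - m_new), since
    //   expf(s - m_new) == expf(s - m_old) * expf(m_old - m_new)
    float m_new = fmaxf(m_old, tile_max);  // kqmax_new in the kernel
    float scale = expf(m_old - m_new);     // KQ_max_scale in the kernel
    sum = sum*scale + tile_sum;            // tile_sum plays the role of kqsum_add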
@@ -489,7 +419,8 @@ static __global__ void flash_attn_tile(
             const int j = j0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-            KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
+            const float tmp = KQ[j][k0 + k1];
+            KQ_k[j0/nwarps] = make_half2(tmp, tmp);
 #else
             KQ_k[j0/nwarps] = KQ[j][k0 + k1];
 #endif // FAST_FP16_AVAILABLE