@@ -88,11 +88,12 @@ static __global__ void flash_attn_ext_f16(
88
88
constexpr int kqar = sizeof (KQ_acc_t)/sizeof (half);
89
89
90
90
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
91
- const float * Q_f = (const float *) (Q + nb02* blockIdx .y + nb01*ic0);
92
- const half * K_h = (const half *) (K + nb12*(blockIdx .y / gqa_ratio));
93
- const half * V_h = (const half *) (V + nb22*(blockIdx .y / gqa_ratio)); // K and V have same shape
94
- const half * maskh = (const half *) mask + (nb31/sizeof (half))* ic0;
95
- const half2 * mask2 = (const half2 *) mask + (nb31/sizeof (half))*(ic0/2 );
91
+ const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
92
+ const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
93
+ const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
94
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
95
+ const half2 * mask2 = (const half2 *) maskh;
96
+ const float * sinksf = (const float *) sinks;
96
97
97
98
const int stride_Q = nb01 / sizeof (float );
98
99
const int stride_K = nb11 / sizeof (half);
@@ -385,7 +386,54 @@ static __global__ void flash_attn_ext_f16(
385
386
386
387
__syncthreads ();
387
388
}
389
+
390
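+ // Apply attention sinks (only in blocks with blockIdx.y == 0): raise the running KQ max to at least sinksf[head],
+ // rescale the KQ row sum and the VKQ accumulator to the new max, and add exp(sink - max) to the row sum only.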
391
+ if (sinksf && blockIdx .y == 0 ) {
392
+ const float sinkf = sinksf[head];
393
+ const half sinkh = __float2half (sinkf);
388
394
395
+ #pragma unroll
396
+ for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
397
+ const int j = j0 + threadIdx .y ;
398
+
399
+ if (std::is_same<KQ_acc_t, float >::value) {
400
+ float kqmax_new = fmaxf (KQ_max_f[j0/nwarps], sinkf);
401
+
402
+ const float KQ_max_scale = expf (KQ_max_f[j0/nwarps] - kqmax_new);
403
+ KQ_max_f[j0/nwarps] = kqmax_new;
404
+
405
+ KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf (sinkf - KQ_max_f[j0/nwarps]);
406
+
407
+ const half2 scale_h2 = make_half2 (KQ_max_scale, KQ_max_scale);
408
+ #pragma unroll
409
+ for (int i0 = 0 ; i0 < D/2 ; i0 += warp_size) {
410
+ const int i = i0 + threadIdx .x ;
411
+ if (i0 + warp_size > D/2 && i >= D/2 ) break ;
412
+ VKQ2[j*(D_padded/2 ) + i] *= scale_h2;
413
+ }
414
+ } else {
415
+ half kqmax_old = __low2half (KQ_max_h2[j0/nwarps]);
416
+ half kqmax_new = fmaxf (kqmax_old, sinkh);
417
+ KQ_max_h2[j0/nwarps] = __half2half2 (kqmax_new);
418
+
419
+ const half KQ_max_scale_h = hexp (kqmax_old - kqmax_new);
420
+ const half2 KQ_max_scale = __half2half2 (KQ_max_scale_h);
421
+
422
+ KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
423
+ const half val = hexp (sinkh - kqmax_new);
424
+ KQ_rowsum_h2[j0/nwarps].x = __hadd (KQ_rowsum_h2[j0/nwarps].x , val);
425
+
426
+ #pragma unroll
427
+ for (int i0 = 0 ; i0 < D/2 ; i0 += warp_size) {
428
+ const int i = i0 + threadIdx .x ;
429
+ if (i0 + warp_size > D/2 && i >= D/2 ) break ;
430
+ VKQ2[j*(D_padded/2 ) + i] *= KQ_max_scale;
431
+ }
432
+ }
433
+ }
434
+
435
+ __syncthreads ();
436
+ }
389
437
#pragma unroll
390
438
for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
391
439
const int j_VKQ = j0 + threadIdx .y ;
0 commit comments