
Commit f71ef6b

CUDA: add attention sinks for tile and wmma

Port of ggml-org/llama.cpp#15178

1 parent 8ae6cc3 commit f71ef6b
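
An attention sink is an extra per-head logit that enters the softmax denominator but contributes no value row. Folded into the kernels' online softmax, it amounts to one additional rescaling step per head after the main KQ loop. Below is a minimal scalar C++ sketch of that step, for orientation only; the function and variable names are illustrative and do not appear in the kernels.

#include <algorithm>
#include <cmath>
#include <vector>

// Fold a per-head sink logit into an online-softmax state: the sink joins the
// denominator but adds nothing to the V-accumulator, which is only rescaled.
static void apply_attention_sink(float sink, float & kq_max, float & kq_sum,
                                 std::vector<float> & vkq) {
    const float kq_max_new = std::max(kq_max, sink);
    const float scale      = std::exp(kq_max - kq_max_new); // rescale the old running state
    kq_max = kq_max_new;
    kq_sum = kq_sum*scale + std::exp(sink - kq_max_new);    // sink enters the denominator
    for (float & v : vkq) {
        v *= scale;                                         // accumulated V rows are only rescaled
    }
}

The tile and wmma kernels below apply this same adjustment, once per head, to their running kqmax/kqsum (respectively KQ_max/KQ_rowsum) values and to the VKQ accumulators.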

4 files changed: +115 -15 lines

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 30 additions & 4 deletions
@@ -64,10 +64,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);

     const int stride_KV2 = nb11 / sizeof(half2);

@@ -257,6 +258,31 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }

+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 31 additions & 4 deletions
@@ -64,10 +64,11 @@ static __global__ void flash_attn_tile_ext_f32(
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);

     const int stride_K2 = nb11 / sizeof(half2);
     const int stride_V2 = nb12 / sizeof(half2);
@@ -262,6 +263,32 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }

+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;

ggml/src/ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 53 additions & 5 deletions
@@ -88,11 +88,12 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+    const float * Q_f    = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half  * K_h    = (const half  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half  * V_h    = (const half  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half2 * mask2  = (const half2 *)  maskh;
+    const float * sinksf = (const float *)  sinks;

     const int stride_Q = nb01 / sizeof(float);
     const int stride_K = nb11 / sizeof(half);
@@ -385,7 +386,54 @@

         __syncthreads();
     }
+
+    // Apply attention sinks
+    if (sinksf && blockIdx.y == 0) {
+        const float sinkf = sinksf[head];
+        const half  sinkh = __float2half(sinkf);

+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+                KQ_max_f[j0/nwarps] = kqmax_new;
+
+                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
+                }
+            } else {
+                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+                half kqmax_new = fmaxf(kqmax_old, sinkh);
+                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);
+
+                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+                const half val = hexp(sinkh - kqmax_new);
+                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+                }
+            }
+        }
+
+        __syncthreads();
+    }
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 2 deletions
@@ -461,14 +461,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K    = dst->src[1];
     const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
-    const ggml_tensor * sinks = dst->src[4];

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[3];

     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD || (sinks && !fp16_mma_available(cc))) {
+    if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
