Skip to content

Commit d3984d9

Browse files
jeffbolznv and qnixsynapse
authored and committed
vulkan: support softmax/FA batch and broadcast (ggml-org#14449)
1 parent a522cda commit d3984d9

File tree

5 files changed

+121
-124
lines changed

5 files changed

+121
-124
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ struct vk_flash_attn_push_constants {
633633
uint32_t nev2;
634634
uint32_t nev3;
635635
uint32_t nem1;
636+
uint32_t nem2;
636637

637638
uint32_t nb01;
638639
uint32_t nb02;
@@ -643,7 +644,6 @@ struct vk_flash_attn_push_constants {
643644
uint32_t nb21;
644645
uint32_t nb22;
645646
uint32_t nb23;
646-
uint32_t nb31;
647647

648648
float scale;
649649
float max_bias;
@@ -658,6 +658,7 @@ struct vk_flash_attn_push_constants {
658658
uint32_t split_kv;
659659
uint32_t k_num;
660660
};
661+
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
661662

662663
struct vk_op_push_constants {
663664
uint32_t KX;
@@ -756,6 +757,14 @@ struct vk_op_rope_push_constants {
756757
struct vk_op_soft_max_push_constants {
757758
uint32_t KX;
758759
uint32_t KY;
760+
uint32_t ne00;
761+
uint32_t ne01;
762+
uint32_t ne02;
763+
uint32_t ne12;
764+
uint32_t ne13;
765+
uint32_t nb11;
766+
uint32_t nb12;
767+
uint32_t nb13;
759768
float scale;
760769
float max_bias;
761770
float m0;
@@ -6040,7 +6049,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
60406049
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
60416050

60426051
const uint32_t nem1 = mask ? mask->ne[1] : 0;
6043-
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
6052+
const uint32_t nem2 = mask ? mask->ne[2] : 0;
60446053

60456054
const uint32_t D = neq0;
60466055
uint32_t N = neq1;
@@ -6203,7 +6212,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62036212
// Try to use split_k when KV is large enough to be worth the overhead
62046213
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
62056214
// Try to run two workgroups per SM.
6206-
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
6215+
split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
62076216
if (split_k > 1) {
62086217
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
62096218
// of "align", so recompute split_k based on that.
@@ -6213,9 +6222,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62136222
}
62146223
}
62156224

6216-
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
6217-
// and the per-row m and L values (ne1 rows).
6218-
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
6225+
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
6226+
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
6227+
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
62196228
if (split_k_size > ctx->device->max_memory_allocation_size) {
62206229
GGML_ABORT("Requested preallocation size is too large");
62216230
}
@@ -6307,11 +6316,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
63076316
(uint32_t)neq2, (uint32_t)neq3,
63086317
(uint32_t)nek2, (uint32_t)nek3,
63096318
(uint32_t)nev2, (uint32_t)nev3,
6310-
nem1,
6319+
nem1, nem2,
63116320
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
63126321
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
63136322
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6314-
nbm1,
63156323
scale, max_bias, logit_softcap,
63166324
mask != nullptr, n_head_log2, m0, m1,
63176325
gqa_ratio, split_kv, split_k };
@@ -6334,13 +6342,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
63346342
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
63356343

63366344
ggml_vk_sync_buffers(subctx);
6337-
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
6345+
const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
63386346
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
63396347
{
63406348
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
63416349
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
63426350
},
6343-
pc2, { (uint32_t)ne1, 1, 1 });
6351+
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
63446352
} else {
63456353
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
63466354
{
@@ -7666,7 +7674,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
76667674
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
76677675
const uint32_t nrows_y = (uint32_t)src0->ne[1];
76687676

7669-
const uint32_t n_head_kv = nrows_x/nrows_y;
7677+
const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
7678+
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
7679+
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
7680+
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
7681+
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
7682+
7683+
const uint32_t n_head_kv = src0->ne[2];
76707684
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
76717685

76727686
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7675,6 +7689,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
76757689
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
76767690
ncols,
76777691
src1 != nullptr ? nrows_y : (uint32_t)0,
7692+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
7693+
ne12, ne13,
7694+
nb11, nb12, nb13,
76787695
scale, max_bias,
76797696
m0, m1,
76807697
n_head_log2,
@@ -10248,11 +10265,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1024810265
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1024910266
return false;
1025010267
}
10251-
// TODO: support broadcast
10252-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
10253-
if (op->src[0]->ne[3] != 1) {
10254-
return false;
10255-
}
1025610268
// It's straightforward to support different K/V dequant, but would
1025710269
// significantly increase the number of pipelines
1025810270
if (op->src[1]->type != op->src[2]->type) {
@@ -10413,13 +10425,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1041310425
case GGML_OP_DIAG_MASK_INF:
1041410426
return true;
1041510427
case GGML_OP_SOFT_MAX:
10416-
// TODO: support batching
10417-
if (op->src[0]->ne[3] != 1) {
10418-
return false;
10419-
}
10420-
// TODO: support broadcast
10421-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
10422-
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
1042310428
case GGML_OP_SOFT_MAX_BACK:
1042410429
case GGML_OP_ARGSORT:
1042510430
case GGML_OP_SUM:

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
#include "types.comp"
1212
#include "flash_attn_base.comp"
1313

14-
const uint32_t HSK_per_thread = HSK / D_split;
15-
const uint32_t HSV_per_thread = HSV / D_split;
14+
const uint32_t D_per_thread = D / D_split;
1615

1716
const uint32_t cols_per_iter = WorkGroupSize / D_split;
1817
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -30,7 +29,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
3029
// Rows index by Q's dimension 2, and the first N rows are valid.
3130
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
3231
{
33-
uint32_t offset = (iq2 + r) * HSV + c;
32+
uint32_t offset = (iq2 + r) * D + c;
3433
data_o[o_offset + offset] = D_TYPE(elem);
3534
return elem;
3635
}
@@ -39,7 +38,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
3938
shared vec4 tmpshv4[WorkGroupSize];
4039

4140
shared float masksh[Bc][Br];
42-
shared vec4 Qf[Br][HSK / 4];
41+
shared vec4 Qf[Br][D / 4];
4342

4443
void main() {
4544
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -54,18 +53,18 @@ void main() {
5453

5554
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
5655

57-
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
58-
uint32_t d = (idx + tid) % (HSK / 4);
59-
uint32_t r = (idx + tid) / (HSK / 4);
60-
if (r < Br && d < HSK / 4 &&
56+
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
57+
uint32_t d = (idx + tid) % (D / 4);
58+
uint32_t r = (idx + tid) / (D / 4);
59+
if (r < Br && d < D / 4 &&
6160
i * Br + r < N) {
6261
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
6362
}
6463
}
6564
barrier();
6665

67-
vec4 Of[Br][HSV_per_thread / 4];
68-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
66+
vec4 Of[Br][D_per_thread / 4];
67+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
6968
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
7069
Of[r][d] = vec4(0.0);
7170
}
@@ -101,8 +100,8 @@ void main() {
101100
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
102101
#endif
103102
uint32_t m_offset = 0;
104-
if (p.nem2 != 1 || p.nem3 != 1) {
105-
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
103+
if (p.nem2 != 1) {
104+
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
106105
}
107106

108107
[[dont_unroll]]
@@ -117,7 +116,7 @@ void main() {
117116

118117

119118
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
120-
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
119+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
121120
#if BLOCK_SIZE > 1
122121
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
123122
uint ib = coord / BLOCK_SIZE;
@@ -149,7 +148,7 @@ void main() {
149148
}
150149
}
151150

152-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
151+
if (p.mask != 0) {
153152

154153
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
155154
uint32_t c = (idx + tid) % Bc;
@@ -196,14 +195,14 @@ void main() {
196195
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
197196
}
198197

199-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
198+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
200199
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
201200
Of[r][d] = eMf[r] * Of[r][d];
202201
}
203202
}
204203

205204
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
206-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
205+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
207206
#if BLOCK_SIZE > 1
208207
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
209208
uint ib = coord / BLOCK_SIZE;
@@ -260,7 +259,7 @@ void main() {
260259
Lf[r] = tmpsh[d_tid];
261260
barrier();
262261

263-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
262+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
264263

265264
Of[r][d] = eMf * Of[r][d];
266265
tmpshv4[tid] = Of[r][d];
@@ -282,19 +281,19 @@ void main() {
282281
// If there is split_k, then the split_k resolve shader does the final
283282
// division by L. Store the intermediate O value and per-row m and L values.
284283
if (p.k_num > 1) {
285-
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
284+
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
286285

287286
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
288287
if (r < N) {
289-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
288+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
290289
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
291290
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
292291
}
293292
}
294293
}
295294
}
296295

297-
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
296+
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
298297
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
299298
if (r < N) {
300299
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -310,18 +309,18 @@ void main() {
310309
Lfrcp[r] = 1.0 / Lf[r];
311310
}
312311

313-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
312+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
314313
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
315314
Of[r][d] *= Lfrcp[r];
316315
}
317316
}
318317

319-
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
318+
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
320319

321320
if (p.gqa_ratio > 1) {
322321
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
323322
if (r < N) {
324-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
323+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
325324
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
326325
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
327326
}
@@ -331,9 +330,9 @@ void main() {
331330
} else {
332331
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
333332
if (i * Br + r < N) {
334-
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
333+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
335334
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
336-
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
335+
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
337336
}
338337
}
339338
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
44
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
55
layout (constant_id = 1) const uint32_t Br = 1;
66
layout (constant_id = 2) const uint32_t Bc = 32;
7-
layout (constant_id = 3) const uint32_t HSK = 32;
8-
layout (constant_id = 4) const uint32_t HSV = 32;
9-
layout (constant_id = 5) const uint32_t Clamp = 0;
10-
layout (constant_id = 6) const uint32_t D_split = 16;
7+
layout (constant_id = 3) const uint32_t D = 32;
8+
layout (constant_id = 4) const uint32_t Clamp = 0;
9+
layout (constant_id = 5) const uint32_t D_split = 16;
10+
1111

1212
layout (push_constant) uniform parameter {
1313
uint32_t N;
@@ -25,7 +25,6 @@ layout (push_constant) uniform parameter {
2525
uint32_t nev3;
2626
uint32_t nem1;
2727
uint32_t nem2;
28-
uint32_t nem3;
2928

3029
uint32_t nb01;
3130
uint32_t nb02;
@@ -41,7 +40,8 @@ layout (push_constant) uniform parameter {
4140
float max_bias;
4241
float logit_softcap;
4342

44-
uint32_t mask_n_head_log2;
43+
uint32_t mask;
44+
uint32_t n_head_log2;
4545
float m0;
4646
float m1;
4747

@@ -50,9 +50,6 @@ layout (push_constant) uniform parameter {
5050
uint32_t k_num;
5151
} p;
5252

53-
#define MASK_ENABLE_BIT (1<<16)
54-
#define N_LOG2_MASK 0xFFFF
55-
5653
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
5754

5855
#if defined(A_TYPE_PACKED16)
@@ -103,10 +100,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
103100
{
104101
const uint32_t h = iq2 + (r % p.gqa_ratio);
105102

106-
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
107-
108-
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
109-
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
103+
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104+
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
110105

111106
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
112107
}

0 commit comments

Comments
 (0)