
ggml : fix FA mask dim 2 and 3 #14505


Merged: 3 commits, Jul 3, 2025
ggml/include/ggml.h (7 additions, 6 deletions)

@@ -1980,15 +1980,16 @@ extern "C" {

#define GGML_KQ_MASK_PAD 64

-// q:    [n_embd_k, n_batch,     n_head,    ne3]
-// k:    [n_embd_k, n_kv,        n_head_kv, ne3]
-// v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
-// mask: [n_kv,     n_batch_pad, ne32,      1]   !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-// res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+// q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+// k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+// v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+// mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+// res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
//
// broadcast:
// n_head % n_head_kv == 0
-//   ne3    % ne32      == 0
+//   n_head % ne32      == 0
+//   ne3    % ne33      == 0
//
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
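The shape and broadcast rules above are easier to see in a call. Below is a minimal sketch (not from the PR) that builds q/k/v and a broadcastable mask with the documented dimensions; the concrete sizes are illustrative assumptions, and it presumes the signature ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap).

```c
#include <math.h>
#include "ggml.h"

// Sketch only: shapes follow the comment above; the concrete sizes are made up.
static struct ggml_tensor * build_flash_attn(struct ggml_context * ctx) {
    const int64_t n_embd_k = 128, n_embd_v  = 128;
    const int64_t n_batch  = 32,  n_kv      = 256;
    const int64_t n_head   = 16,  n_head_kv = 4;  // n_head % n_head_kv == 0 (GQA broadcast)
    const int64_t ne3      = 2;                   // dim 3 of q/k/v
    const int64_t ne32     = 1,   ne33 = 1;       // mask dims 2/3: n_head % ne32 == 0, ne3 % ne33 == 0

    const int64_t n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD);

    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch,     n_head,    ne3);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,        n_head_kv, ne3);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,        n_head_kv, ne3);
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv,     n_batch_pad, ne32,      ne33);

    // result is [n_embd_v, n_head, n_batch, ne3] (permuted), per the comment above
    return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) n_embd_k), 0.0f, 0.0f);
}
```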
ggml/src/ggml-cpu/ops.cpp (1 addition, 1 deletion)

@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
memset(VKQ32, 0, DV*sizeof(float));
}

-const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;

// k indices
const int ik3 = iq3 / rk3;
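Read in isolation, the changed line now selects the mask row with a modulo over both dim 2 (per head, iq2) and dim 3 (iq3), so a mask with ne[2] == 1 or ne[3] == 1 is broadcast. A standalone sketch of that pointer arithmetic, as a hypothetical helper mirroring the changed line:

```c
#include <stdint.h>
#include "ggml.h"

// Hypothetical helper: pick the mask row for query row iq1, head iq2, and
// dim-3 index iq3, broadcasting mask dims 2 and 3 via modulo.
static const ggml_fp16_t * fa_mask_row(const struct ggml_tensor * mask,
                                       int64_t iq1, int64_t iq2, int64_t iq3) {
    if (mask == NULL) {
        return NULL;
    }
    return (const ggml_fp16_t *)((const char *) mask->data
        + iq1*mask->nb[1]
        + (iq2 % mask->ne[2])*mask->nb[2]   // broadcast over heads when ne[2] == 1
        + (iq3 % mask->ne[3])*mask->nb[3]); // broadcast over dim 3 when ne[3] == 1
}
```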
ggml/src/ggml-cuda/ggml-cuda.cu (2 additions, 1 deletion)

@@ -3377,7 +3377,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
// TODO: support broadcast
-// ref: https://github.com/ggml-org/llama.cpp/pull/14435
+// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+//       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) {
return false;
}
ggml/src/ggml-metal/ggml-metal-impl.h (2 additions)

@@ -230,8 +230,10 @@ typedef struct {
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
+int32_t ne33;
uint64_t nb31;
uint64_t nb32;
+uint64_t nb33;
int32_t ne1;
int32_t ne2;
float scale;
ggml/src/ggml-metal/ggml-metal.m (2 additions)

@@ -4989,8 +4989,10 @@ static bool ggml_metal_encode_node(
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne32 =*/ ne32,
+/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
+/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
ggml/src/ggml-metal/ggml-metal.metal (2 additions, 2 deletions)

@@ -3784,7 +3784,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
-device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
+device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

const float m = pm[ic + tiisg];

@@ -4270,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q;

// pointer to the mask
-device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
+device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

float slope = 1.0f;

ggml/src/ggml-vulkan/ggml-vulkan.cpp (6 additions)

@@ -10265,6 +10265,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
+// TODO: support broadcast
+// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+//       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+    return false;
+}
Review thread on lines +10268 to +10273:

ggerganov (Member, Author) commented:
@JohannesGaessler Just checking if you have started implementing this? I am curious to get #14363 running with CUDA and this is the missing piece. No rush though, just checking the status.

JohannesGaessler (Collaborator) commented on Jul 10, 2025:

I'm working on it right now and unless something else comes up I'll make a PR this week. In fact, I have a specific question: is the mask being broadcast across dimension 2 needed in practice? I wrote the FA kernel for Turing or newer with a GQA-specific optimization that reduces I/O by using data from K/V and the mask for multiple Q values. However, this only works if all attention heads use the same mask.

ggerganov (Member, Author) replied on Jul 10, 2025:

No, it's not needed. I was looking at this the other day and it does seem like unnecessary complication.

We will simply remove the test-backend-ops tests that currently use mask->ne[2] > 1 and update the ggml.h comment to state that mask->ne[2] == 1.

// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
ggml/src/ggml.c (2 additions, 3 deletions)

@@ -3666,7 +3666,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
-GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4696,12 +4695,12 @@ struct ggml_tensor * ggml_flash_attn_ext(

if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
-GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));

-GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
}

if (max_bias > 0.0f) {
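The new asserts boil down to two divisibility constraints: mask dim 2 must divide the number of query heads, and mask dim 3 must divide q dim 3. A hedged sketch of how a backend-side check could mirror them (hypothetical helper name, not from the PR):

```c
#include <stdbool.h>
#include "ggml.h"

// Hypothetical check mirroring the new asserts in ggml_flash_attn_ext():
// the mask is broadcast over heads (dim 2) and over dim 3.
static bool fa_mask_shape_ok(const struct ggml_tensor * q, const struct ggml_tensor * mask) {
    if (mask == NULL) {
        return true;
    }
    return q->ne[2] % mask->ne[2] == 0 &&
           q->ne[3] % mask->ne[3] == 0;
}
```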
tests/test-backend-ops.cpp (2 additions, 2 deletions)

@@ -3607,7 +3607,7 @@ struct test_flash_attn_ext : public test_case {

ggml_tensor * m = nullptr;
if (mask) {
-m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
+m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
ggml_set_name(m, "m");
}

@@ -4720,7 +4720,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));

if (ne0 <= 32 && ne1 <= 32) {
-test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
}
}