Commit 09f0444

CANN(flash-attn): refactor mask handling and improve performance

1. Refactored the mask computation in Flash Attention, unifying the logic so prefill and decode no longer take separate paths.
2. Optimized performance in non-ALiBi scenarios by eliminating one repeat operation.
3. Updated operator support checks to explicitly mark unsupported cases on 310P devices and when the head dimension is not divisible by 16.

Signed-off-by: noemotiovon <757486878@qq.com>
1 parent 74f52f7 commit 09f0444
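
For intuition on item 1 above: the mask tensor (src3) arrives padded, and the attention kernel only needs its first src0->ne[1] rows. A single strided view that shrinks ne[1] to the live number of Q tokens therefore serves prefill (many rows) and decode (one row) through the same code path, which is what allows the separate prefill/decode branches to be dropped below. A minimal host-side sketch of such a truncated view (hypothetical sizes, plain C++, not ggml code):

#include <cstdint>
#include <cstdio>

// Read element (d, s) of a 2-D strided view: base + s * nb1 + d * nb0.
static int at(const uint8_t* base, size_t nb0, size_t nb1, int64_t d, int64_t s) {
    return (int)*(const uint16_t*)(base + s * nb1 + d * nb0);
}

int main() {
    const int64_t D = 8, S_pad = 4;              // hypothetical padded mask extents
    uint16_t mask[S_pad][D];                     // stand-in for the f16 mask payload
    for (int64_t s = 0; s < S_pad; s++)
        for (int64_t d = 0; d < D; d++)
            mask[s][d] = (uint16_t)(100 * s + d);

    const size_t nb0 = sizeof(uint16_t), nb1 = D * nb0;

    // "Truncating" is just shrinking ne[1]; the data pointer and strides are
    // untouched, so S = 4 (prefill) and S = 1 (decode) build the same view.
    for (int64_t S : {4, 1}) {
        printf("S=%lld: corner elements %d ... %d\n", (long long)S,
               at((const uint8_t*)mask, nb0, nb1, 0, 0),
               at((const uint8_t*)mask, nb0, nb1, D - 1, S - 1));
    }
}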

2 files changed, 53 insertions(+), 52 deletions(-)
ggml/src/ggml-cann/aclnn_ops.cpp (45 additions, 52 deletions)
@@ -1427,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
     float m, int64_t size, float start, float stop, float step){
     int64_t ne[] = {size};
-    size_t nb[] = {sizeof(float)};
+    size_t nb[] = {sizeof(uint16_t)};
 
-    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
     void* arange_buffer = arange_allocator.get();
 
     aclTensor* arange_tensor = ggml_cann_create_tensor(
-        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
     aclnn_arange(ctx, arange_tensor, start, stop, step, size);
 
     aclTensor* slope_tensor = ggml_cann_create_tensor(
-        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
 
     aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
 
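With this change aclnn_get_slope_inner builds the slope vector directly in FP16, matching the f16 PSE tensor it is later multiplied into, rather than producing FP32 that would need a cast. The slope values follow ggml's usual ALiBi scheme; as a host-side reference, a sketch of that formula (my paraphrase of the logic ggml uses elsewhere, not the CANN kernel itself):

#include <cmath>
#include <cstdio>
#include <vector>

// ggml-style ALiBi slopes: heads split into two geometric series based on
// the largest power of two <= n_head.
static std::vector<float> alibi_slopes(int n_head, float max_bias) {
    const int n_head_log2 = 1 << (int)floorf(log2f((float)n_head));
    const float m0 = powf(2.0f, -max_bias / n_head_log2);
    const float m1 = powf(2.0f, -max_bias / 2.0f / n_head_log2);

    std::vector<float> slopes(n_head);
    for (int h = 0; h < n_head; h++) {
        slopes[h] = h < n_head_log2 ? powf(m0, (float)(h + 1))
                                    : powf(m1, (float)(2 * (h - n_head_log2) + 1));
    }
    return slopes;
}

int main() {
    for (float s : alibi_slopes(/*n_head=*/8, /*max_bias=*/8.0f)) printf("%g ", s);
    printf("\n");   // 0.5 0.25 0.125 ... for this configuration
}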

@@ -3251,88 +3251,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     // Step 3: create the PSEShift tensor if needed
     // this tensor is considered as mask (f16) in the llama.cpp
-
     aclTensor* bcast_pse_tensor = nullptr;
-    int64_t bcast_pse_ne[GGML_MAX_DIMS];
-    size_t bcast_pse_nb[GGML_MAX_DIMS];
     ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
-    void* bcast_pse_buffer = nullptr;
-
     if(src3 != nullptr){
-        bcast_pse_buffer = bcast_pse_allocator.alloc(
-            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-        if(src0->ne[1] > 1){
-            // Case 1: broadcast pse for prefill stage with multiple head
-            aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src3->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+        // Construct the truncated pse tensor (common for prefill/decode)
+        int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+            src3->ne[0], // D
+            src0->ne[1], // S (number of Q tokens)
+            src3->ne[2], // mask N
+            src3->ne[3]  // B
+        };
+        size_t* trunc_pse_nb = src3->nb;
+
+        aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+            src3->data, ACL_FLOAT16, sizeof(uint16_t),
+            trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+        );
 
+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        bcast_pse_ne[0] = src3->ne[0]; // D
+        bcast_pse_ne[1] = src0->ne[1]; // S
+        bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
+        bcast_pse_ne[3] = src3->ne[3]; // B
+        if (maxBias == 0.0f) {
+            // When maxBias == 0.0f, set nb[2] = 0 to skip one repeat operation (Qwen2)
+            // Construct the bcast tensor (simulate repeat on the head dimension using stride = 0)
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for(int i = 1; i < GGML_MAX_DIMS; ++i){
-                bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
-            }
+            bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+            bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
+            bcast_pse_nb[3] = src3->nb[3];
 
             bcast_pse_tensor = ggml_cann_create_tensor(
-                bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
-            int64_t repeats[] = {1, src0->ne[2], 1, 1};
-            aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
-            ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
-        }else{
-            // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
-            int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
-            size_t* trunc_pse_nb = src3->nb;
-
-            aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
                 src3->data, ACL_FLOAT16, sizeof(uint16_t),
-                trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src0->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );
 
+            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+        } else {
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
             }
 
+            void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+            );
+
             bcast_pse_tensor = ggml_cann_create_tensor(
                 bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );
 
             int64_t repeats[] = {1, src0->ne[2], 1, 1};
             aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
 
-            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
-        }
-
-        // Compute the slope if needed. Derived from ggml_cann_softmax().
-        if(maxBias != 0.0f){
             // alibi
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
             const int64_t n_heads = src0->ne[2];
-            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
             void* slope_buffer = slope_allocator.get();
             aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
 
             int64_t slope_ne[] = {1, 1, n_heads, 1};
             size_t slope_nb[GGML_MAX_DIMS];
-            slope_nb[0] = sizeof(float);
+            slope_nb[0] = sizeof(uint16_t);
             for(int i = 1;i<GGML_MAX_DIMS;i++) {
                 slope_nb[i] = slope_nb[i-1] * slope_ne[0];
             }
 
             aclTensor* slope_tensor = ggml_cann_create_tensor(
-                slope_buffer, ACL_FLOAT, sizeof(float),
+                slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
                 slope_ne, slope_nb, GGML_MAX_DIMS);
             GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
 
-            ggml_cann_release_resources(ctx, slope_tensor);
+            ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
         }
     }
 
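The bcast_pse_nb[2] = 0 trick above is what removes the repeat: a view whose stride along the head dimension is zero makes every head address the same underlying mask row, so the mask is never materialized num_heads times. It is only safe when maxBias == 0.0f, because the ALiBi branch multiplies per-head slopes into the PSE in place, and an in-place write requires each head to own distinct storage. A plain-C++ sketch of why a zero stride behaves like a free repeat (illustration only, not ACL code):

#include <cstdint>
#include <cstdio>

// Address element (i0, i2) of a strided view: base + i2 * nb2 + i0 * nb0.
static float at(const float* base, size_t nb0, size_t nb2, int64_t i0, int64_t i2) {
    return *(const float*)((const uint8_t*)base + i2 * nb2 + i0 * nb0);
}

int main() {
    const int64_t D = 4, N_HEADS = 3;            // hypothetical sizes
    float mask[D] = {0.0f, -1.0f, -2.0f, -3.0f}; // a single mask row

    // nb2 = 0: advancing the head index never moves the pointer, so all
    // heads read the same row -- a virtual repeat with zero copies.
    const size_t nb0 = sizeof(float), nb2 = 0;
    for (int64_t h = 0; h < N_HEADS; h++) {
        for (int64_t d = 0; d < D; d++)
            printf("%g ", at(mask, nb0, nb2, d, h));
        printf("\n");                            // every head prints 0 -1 -2 -3
    }
}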

ggml/src/ggml-cann/ggml-cann.cpp (8 additions, 0 deletions)
@@ -2505,6 +2505,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         case GGML_OP_FLASH_ATTN_EXT:{
+#ifdef ASCEND_310P
+            // FA is not supported on 310P devices
+            return false;
+#endif
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
                 return false;
@@ -2530,6 +2534,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            if (op->src[0]->ne[0] % 16 != 0) {
+                // TODO: support this case via padding
+                return false;
+            }
             float logitSoftcap = 0.0f;
             memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
             if(logitSoftcap != 0.0f) {
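
These guards surface unsupported cases to the scheduler, so ggml falls back to another backend (typically the CPU) instead of failing inside the kernel. Pulled out as a standalone predicate, the checks added in this file look roughly like the sketch below (an assumed helper for illustration; the real logic lives inline in ggml_backend_cann_supports_op and checks more than this):

#include "ggml.h"

// Sketch: only the flash-attention shape guards touched by this commit.
static bool cann_fa_supported_sketch(const ggml_tensor* op) {
#ifdef ASCEND_310P
    return false;                        // FA is disabled outright on 310P
#else
    if (op->src[1]->type != GGML_TYPE_F16 ||
        op->src[2]->type != GGML_TYPE_F16) {
        return false;                    // K and V must be f16
    }
    if (op->src[0]->ne[0] % 16 != 0) {
        return false;                    // head dim must be a multiple of 16 (until padding lands)
    }
    return true;
#endif
}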
