Skip to content

Commit df36aa8

Browse files
committed
[CANN] Optimization of unnecessary repeat in the FA operator
Signed-off-by: noemotiovon <757486878@qq.com>
1 parent 74f52f7 commit df36aa8

File tree

1 file changed

+31
-56
lines changed

1 file changed

+31
-56
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 31 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,66 +3251,41 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
32513251

32523252
// Step 3: create the PSEShift tensor if needed
32533253
// this tensor is considered as mask (f16) in the llama.cpp
3254-
32553254
aclTensor* bcast_pse_tensor = nullptr;
3256-
int64_t bcast_pse_ne[GGML_MAX_DIMS];
3257-
size_t bcast_pse_nb[GGML_MAX_DIMS];
3258-
ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
3259-
void* bcast_pse_buffer = nullptr;
3260-
32613255
if(src3 != nullptr){
3262-
bcast_pse_buffer = bcast_pse_allocator.alloc(
3263-
ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
3264-
3265-
if(src0->ne[1] > 1){
3266-
// Case 1: broadcast pse for prefill stage with multiple head
3267-
aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
3268-
bcast_pse_ne[0] = src3->ne[0];
3269-
bcast_pse_ne[1] = src3->ne[1];
3270-
bcast_pse_ne[2] = src0->ne[2];
3271-
bcast_pse_ne[3] = src3->ne[3];
3272-
3273-
bcast_pse_nb[0] = sizeof(uint16_t);
3274-
for(int i = 1; i < GGML_MAX_DIMS; ++i){
3275-
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
3276-
}
3277-
3278-
bcast_pse_tensor = ggml_cann_create_tensor(
3279-
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
3280-
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
3281-
3282-
int64_t repeats[] = {1, src0->ne[2], 1, 1};
3283-
aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
3284-
3285-
ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
3286-
}else{
3287-
// Case 2: trunc the first row and broadcast pse for decode stage with multiple head
3288-
int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
3289-
size_t* trunc_pse_nb = src3->nb;
3290-
3291-
aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
3292-
src3->data, ACL_FLOAT16, sizeof(uint16_t),
3293-
trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
3294-
3295-
bcast_pse_ne[0] = src3->ne[0];
3296-
bcast_pse_ne[1] = src0->ne[1];
3297-
bcast_pse_ne[2] = src0->ne[2];
3298-
bcast_pse_ne[3] = src3->ne[3];
3299-
3300-
bcast_pse_nb[0] = sizeof(uint16_t);
3301-
for(int i = 1; i < GGML_MAX_DIMS; ++i){
3302-
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
3303-
}
3304-
3305-
bcast_pse_tensor = ggml_cann_create_tensor(
3306-
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
3307-
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
3256+
// Construct the truncated pse tensor (common for prefill/decode)
3257+
int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
3258+
src3->ne[0], // D
3259+
src0->ne[1], // S (number of Q tokens)
3260+
src3->ne[2], // mask N
3261+
src3->ne[3] // B
3262+
};
3263+
size_t* trunc_pse_nb = src3->nb;
3264+
3265+
aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
3266+
src3->data, ACL_FLOAT16, sizeof(uint16_t),
3267+
trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
3268+
);
33083269

3309-
int64_t repeats[] = {1, src0->ne[2], 1, 1};
3310-
aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
3270+
// Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
3271+
int64_t bcast_pse_ne[GGML_MAX_DIMS];
3272+
size_t bcast_pse_nb[GGML_MAX_DIMS];
3273+
bcast_pse_ne[0] = src3->ne[0]; // D
3274+
bcast_pse_ne[1] = src0->ne[1]; // S
3275+
bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
3276+
bcast_pse_ne[3] = src3->ne[3]; // B
3277+
3278+
bcast_pse_nb[0] = sizeof(uint16_t);
3279+
bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
3280+
bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
3281+
bcast_pse_nb[3] = bcast_pse_nb[1] * bcast_pse_ne[1];
3282+
3283+
bcast_pse_tensor = ggml_cann_create_tensor(
3284+
src3->data, ACL_FLOAT16, sizeof(uint16_t),
3285+
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
3286+
);
33113287

3312-
ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
3313-
}
3288+
ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
33143289

33153290
// Compute the slope if needed. Derived from ggml_cann_softmax().
33163291
if(maxBias != 0.0f){

0 commit comments

Comments
 (0)