@@ -3251,66 +3251,41 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
32513251
32523252        //  Step 3: create the PSEShift tensor if needed
32533253        //          this tensor is considered as mask (f16) in the llama.cpp
3254- 
32553254        aclTensor* bcast_pse_tensor = nullptr ;
3256-         int64_t  bcast_pse_ne[GGML_MAX_DIMS];
3257-         size_t  bcast_pse_nb[GGML_MAX_DIMS];
3258-         ggml_cann_pool_alloc bcast_pse_allocator (ctx.pool ());
3259-         void * bcast_pse_buffer = nullptr ;
3260- 
32613255        if (src3 != nullptr ){
3262-             bcast_pse_buffer = bcast_pse_allocator.alloc (
3263-                             ggml_nelements (src3) * src0->ne [2 ] * sizeof (uint16_t ));
3264- 
3265-             if (src0->ne [1 ] > 1 ){
3266-                 //  Case 1: broadcast pse for prefill stage with multiple head
3267-                 aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor (src3);
3268-                 bcast_pse_ne[0 ] = src3->ne [0 ];
3269-                 bcast_pse_ne[1 ] = src3->ne [1 ];
3270-                 bcast_pse_ne[2 ] = src0->ne [2 ];
3271-                 bcast_pse_ne[3 ] = src3->ne [3 ];
3272- 
3273-                 bcast_pse_nb[0 ] = sizeof (uint16_t );
3274-                 for (int  i = 1 ; i < GGML_MAX_DIMS; ++i){
3275-                     bcast_pse_nb[i] = bcast_pse_nb[i - 1 ] * bcast_pse_ne[i - 1 ];
3276-                 }
3277- 
3278-                 bcast_pse_tensor = ggml_cann_create_tensor (
3279-                     bcast_pse_buffer, ACL_FLOAT16, sizeof (uint16_t ),
3280-                     bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
3281- 
3282-                 int64_t  repeats[] = {1 , src0->ne [2 ], 1 , 1 };
3283-                 aclnn_repeat (ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
3284- 
3285-                 ggml_cann_release_resources (ctx, acl_mask_f16_tensor);
3286-             }else {
3287-                 //  Case 2: trunc the first row and broadcast pse for decode stage with multiple head
3288-                 int64_t  trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne [0 ], src0->ne [1 ], src3->ne [2 ], src3->ne [3 ]};
3289-                 size_t * trunc_pse_nb = src3->nb ;
3290- 
3291-                 aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor (
3292-                     src3->data , ACL_FLOAT16, sizeof (uint16_t ),
3293-                     trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
3294- 
3295-                 bcast_pse_ne[0 ] = src3->ne [0 ];
3296-                 bcast_pse_ne[1 ] = src0->ne [1 ];
3297-                 bcast_pse_ne[2 ] = src0->ne [2 ];
3298-                 bcast_pse_ne[3 ] = src3->ne [3 ];
3299- 
3300-                 bcast_pse_nb[0 ] = sizeof (uint16_t );
3301-                 for (int  i = 1 ; i < GGML_MAX_DIMS; ++i){
3302-                     bcast_pse_nb[i] = bcast_pse_nb[i - 1 ] * bcast_pse_ne[i - 1 ];
3303-                 }
3304- 
3305-                 bcast_pse_tensor = ggml_cann_create_tensor (
3306-                     bcast_pse_buffer, ACL_FLOAT16, sizeof (uint16_t ),
3307-                     bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
3256+             //  Construct the truncated pse tensor (common for prefill/decode)
3257+             int64_t  trunc_pse_ne[GGML_MAX_DIMS] = {
3258+                 src3->ne [0 ],        //  D
3259+                 src0->ne [1 ],        //  S (number of Q tokens)
3260+                 src3->ne [2 ],        //  mask N
3261+                 src3->ne [3 ]         //  B
3262+             };
3263+             size_t * trunc_pse_nb = src3->nb ;
3264+ 
3265+             aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor (
3266+                 src3->data , ACL_FLOAT16, sizeof (uint16_t ),
3267+                 trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
3268+             );
33083269
3309-                 int64_t  repeats[] = {1 , src0->ne [2 ], 1 , 1 };
3310-                 aclnn_repeat (ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
3270+             //  Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
3271+             int64_t  bcast_pse_ne[GGML_MAX_DIMS];
3272+             size_t  bcast_pse_nb[GGML_MAX_DIMS];
3273+             bcast_pse_ne[0 ] = src3->ne [0 ];      //  D
3274+             bcast_pse_ne[1 ] = src0->ne [1 ];      //  S
3275+             bcast_pse_ne[2 ] = src0->ne [2 ];      //  N (num_heads)
3276+             bcast_pse_ne[3 ] = src3->ne [3 ];      //  B
3277+ 
3278+             bcast_pse_nb[0 ] = sizeof (uint16_t );
3279+             bcast_pse_nb[1 ] = bcast_pse_nb[0 ] * bcast_pse_ne[0 ];
3280+             bcast_pse_nb[2 ] = 0 ;                //  <---- the head dimension shares the same data
3281+             bcast_pse_nb[3 ] = bcast_pse_nb[1 ] * bcast_pse_ne[1 ];
3282+ 
3283+             bcast_pse_tensor = ggml_cann_create_tensor (
3284+                 src3->data , ACL_FLOAT16, sizeof (uint16_t ),
3285+                 bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
3286+             );
33113287
3312-                 ggml_cann_release_resources (ctx, acl_mask_f16_trunc_tensor);
3313-             }
3288+             ggml_cann_release_resources (ctx, acl_mask_f16_trunc_tensor);
33143289
33153290            //  Compute the slope if needed. Derived from ggml_cann_softmax().
33163291            if (maxBias != 0 .0f ){
0 commit comments