File tree Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Original file line number Diff line number Diff line change @@ -4420,6 +4420,7 @@ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_E
44204420
44214421constant int32_t  FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24 )]];
44224422
4423+ //  pad the last chunk of C elements of k and v into a an extra pad buffer
44234424kernel void  kernel_flash_attn_ext_pad (
44244425        constant ggml_metal_kargs_flash_attn_ext_pad & args,
44254426        device const  char  * k,
@@ -4450,6 +4451,7 @@ kernel void kernel_flash_attn_ext_pad(
44504451        device char  * v_dst = v_pad + args.nb21 *i1 + args.nb21 *C*i2 + args.nb21 *C*args.ne_12_2 *i3;
44514452
44524453        if  (i1 >= icp) {
4454+             //  here it is not important the exact value that will be used as we rely on masking out the scores in the attention
44534455            for  (uint64_t  i = tiitg; i < args.nb11 ; i += ntg.x ) {
44544456                k_dst[i] = 0 ;
44554457            }
@@ -4663,6 +4665,7 @@ void kernel_flash_attn_ext_impl(
46634665        for  (int  ic0 = 0 ; ic0 < args.ne11 ; ic0 += C) {
46644666            int  ic = ic0;
46654667
4668+             //  the last partial chunk uses the pad buffer as source
46664669            if  (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11 ) {
46674670                k    = pad;
46684671                v    = k + args.nb11 *C*args.ne_12_2 *args.ne_12_3 ;
@@ -5390,6 +5393,7 @@ void kernel_flash_attn_ext_vec_impl(
53905393                break ;
53915394            }
53925395
5396+             //  the last partial chunk uses the pad buffer as source
53935397            if  (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11 ) {
53945398                k    = pad;
53955399                v    = k + args.nb11 *C*args.ne_12_2 *args.ne_12_3 ;
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments