@@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
22612261                    }
22622262
22632263                    simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2264- 
2265-                     const  short  tx = tiisg%4 ;
2266-                     const  short  ty = tiisg/4 ;
2267- 
2268-                     //  mqk = mqk*scale
2269-                     ss[8 *cc + ty*TF + 2 *tx + 0 ] *= scale;
2270-                     ss[8 *cc + ty*TF + 2 *tx + 1 ] *= scale;
2271- 
2272-                     if  (logit_softcap != 0 .0f ) {
2273-                         ss[8 *cc + ty*TF + 2 *tx + 0 ] = logit_softcap*precise::tanh (ss[8 *cc + ty*TF + 2 *tx + 0 ]);
2274-                         ss[8 *cc + ty*TF + 2 *tx + 1 ] = logit_softcap*precise::tanh (ss[8 *cc + ty*TF + 2 *tx + 1 ]);
2275-                     }
2276- 
2277-                     if  (mask != q) {
2278-                         //  mqk = mqk + mask*slope
2279-                         ss[8 *cc + ty*TF + 2 *tx + 0 ] += slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 0 ];
2280-                         ss[8 *cc + ty*TF + 2 *tx + 1 ] += slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 1 ];
2281-                     }
22822264                }
22832265            }
22842266
@@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
22902272                float  ms[Q];
22912273
22922274                for  (short  j = 0 ; j < Q; ++j) {
2293-                     const  short  p = tiisg;
2294- 
22952275                    const  float  m = M[j];
2296-                     const  float  s = ss[j*TF + p];
2276+ 
2277+                     //  scale and apply the logitcap / mask
2278+                     float  s = ss[j*TF + tiisg]*scale;
2279+ 
2280+                     if  (logit_softcap != 0 .0f ) {
2281+                         s = logit_softcap*precise::tanh (s);
2282+                     }
2283+ 
2284+                     if  (mask != q) {
2285+                         //  mqk = mqk + mask*slope
2286+                         s += slope*mp[ic + j*nb31/sizeof (half) + tiisg];
2287+                     }
22972288
22982289                    smax = simd_max (max (smax, s));
22992290                    M[j] = simd_max (max (M[j], s));
@@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
23042295                    S[j] = S[j]*ms[j] + simd_sum (vs);
23052296
23062297                    //  the P matrix from the paper (Q rows, C columns)
2307-                     ss[j*TF + p ] = vs;
2298+                     ss[j*TF + tiisg ] = vs;
23082299                }
23092300
23102301                //  create a QxQ diagonal matrix for rescaling the output
0 commit comments