@@ -17103,25 +17103,23 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
1710317103inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
1710417104 const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
1710517105
17106- // Not sure if this actually helps.
17107- // So, let's reduce compilation time by commenting it out for now.
17108- //if (nk1 >= 256) { //4096) {
17109- // if (nq1 >= 64) {
17110- // FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
17111- // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17112- // return;
17113- // }
17114- // if (nq1 >= 32) {
17115- // FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
17116- // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17117- // return;
17118- // }
17119- // if (nq1 >= 16) {
17120- // FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
17121- // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17122- // return;
17123- // }
17124- //}
17106+ if (nk1 >= 256) { //4096) {
17107+ if (nq1 >= 64) {
17108+ FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
17109+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17110+ return;
17111+ }
17112+ if (nq1 >= 32) {
17113+ FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
17114+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17115+ return;
17116+ }
17117+ if (nq1 >= 16) {
17118+ FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
17119+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17120+ return;
17121+ }
17122+ }
1712517123 if (nq1 >= 8) {
1712617124 FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
1712717125 fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
0 commit comments