Skip to content

Commit ec4bc75

Browse files
author
Iwan Kawrakow
committed
Revert the commented-out section in iqk_mul_mat.cpp
It does have some benefit at long contexts.
1 parent d12f4a1 commit ec4bc75

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17103,25 +17103,23 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
1710317103
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
1710417104
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
1710517105

17106-
// Not sure if this actually helps.
17107-
// So, let's reduce compilation time by commenting it out for now.
17108-
//if (nk1 >= 256) { //4096) {
17109-
// if (nq1 >= 64) {
17110-
// FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
17111-
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17112-
// return;
17113-
// }
17114-
// if (nq1 >= 32) {
17115-
// FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
17116-
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17117-
// return;
17118-
// }
17119-
// if (nq1 >= 16) {
17120-
// FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
17121-
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17122-
// return;
17123-
// }
17124-
//}
17106+
if (nk1 >= 256) { //4096) {
17107+
if (nq1 >= 64) {
17108+
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
17109+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17110+
return;
17111+
}
17112+
if (nq1 >= 32) {
17113+
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
17114+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17115+
return;
17116+
}
17117+
if (nq1 >= 16) {
17118+
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
17119+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17120+
return;
17121+
}
17122+
}
1712517123
if (nq1 >= 8) {
1712617124
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
1712717125
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);

0 commit comments

Comments
 (0)