@@ -134,21 +134,33 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
134134 FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q5_0)
135135 FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
136136 FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
137- FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16 )
137+ FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
138+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
139+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
138140
139141 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
140142 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
141143 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
142144 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
143145 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
144146 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_0)
147+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
148+
149+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL)
150+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL)
151+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
152+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL)
153+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
154+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
155+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
145156
146157 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
147158 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
148159 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
149160 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
150161 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
151162 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_1)
163+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1)
152164
153165 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
154166 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
@@ -159,40 +171,39 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
159171
160172 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
161173 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
162-
163174 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
164175 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
165176 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
166177 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_0)
178+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0)
167179
168180 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
169181 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
170182 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
171183 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
172- // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
173- // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
184+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
185+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
186+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1)
174187
175188 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
176189 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
177190 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
178191 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
179192 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
180193 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
194+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0)
181195
182196 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_F16)
183197 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_F16)
184198 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_F16)
185199 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_F16)
186200 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_F16)
187201 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
202+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
188203
189204 FATTN_VEC_F16_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
190205
191- // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
192- // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
193206 // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
194- // FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
195- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
196207 // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
197208#else
198209 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -232,20 +243,32 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
232243 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
233244 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
234245 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
246+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
247+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
235248
236249 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
237250 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
238251 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
239252 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
240253 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
241254 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_0)
255+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0)
256+
257+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL)
258+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL)
259+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
260+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL)
261+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
262+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
263+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
242264
243265 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
244266 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
245267 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
246268 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
247269 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
248270 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_1)
271+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1)
249272
250273 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
251274 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
@@ -256,34 +279,40 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
256279
257280 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
258281 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
259-
260282 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
261283 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
262284 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
263285 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_0)
286+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0)
264287
265288 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
266289 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
267290 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
268291 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
269- // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
270- // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
292+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
293+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
294+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1)
271295
272296 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
273297 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
274298 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
275299 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
276300 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
277301 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
302+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0)
278303
279304 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_F16)
280305 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_F16)
281306 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_F16)
282307 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_F16)
283308 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_F16)
284309 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
310+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
285311
286312 FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
313+
314+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
315+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
287316#else
288317 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
289318
@@ -292,6 +321,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
292321 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
293322 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
294323 FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
324+
325+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
326+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
327+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
328+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
329+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
330+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
295331#endif // GGML_CUDA_FA_ALL_QUANTS
296332
297333 GGML_ABORT (" fatal error" );
0 commit comments