@@ -152,7 +152,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
152152 } \
153153
154154static void ggml_cuda_flash_attn_ext_vec_f16 (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
155- ggml_tensor * Q = dst->src [0 ];
155+ ggml_tensor * Q = dst->src [1 ];
156156 ggml_tensor * K = dst->src [1 ];
157157 ggml_tensor * V = dst->src [2 ];
158158
@@ -162,57 +162,69 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
162162 FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q5_0)
163163 FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
164164 FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
165- FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16 )
165+ FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
166+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
167+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
166168
167169 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
168170 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
169171 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
170172 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
171173 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
172174 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_0)
175+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
176+
177+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL)
178+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL)
179+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
180+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL)
181+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
182+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
183+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
173184
174185 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
175186 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
176187 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
177188 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
178189 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
179190 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_1)
191+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1)
180192
181193 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
182194 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
183195 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
184196 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
185197 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
186198 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_0)
199+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0)
187200
188201 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
189202 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
190203 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
191204 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
192205 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
193206 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
207+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1)
194208
195209 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
196210 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
197211 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
198212 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
199213 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
200214 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
215+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0)
201216
202217 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_F16)
203218 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_F16)
204219 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_F16)
205220 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_F16)
206221 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_F16)
207222 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
223+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
208224
209225 FATTN_VEC_F16_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
210226
211- // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
212- // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
213227 // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
214- // FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
215- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
216228 // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
217229#else
218230 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -241,7 +253,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
241253 } \
242254
243255static void ggml_cuda_flash_attn_ext_vec_f32 (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
244- ggml_tensor * Q = dst->src [0 ];
256+ ggml_tensor * Q = dst->src [1 ];
245257 ggml_tensor * K = dst->src [1 ];
246258 ggml_tensor * V = dst->src [2 ];
247259
@@ -252,50 +264,69 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
252264 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
253265 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
254266 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
267+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
268+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
255269
256270 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
257271 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
258272 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
259273 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
260274 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
261275 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_0)
276+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0)
277+
278+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL)
279+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL)
280+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
281+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL)
282+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
283+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
284+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
262285
263286 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
264287 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
265288 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
266289 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
267290 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
268291 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q4_1)
292+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1)
269293
270294 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
271295 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
272296 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
273297 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
274298 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
275299 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_0)
300+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0)
276301
277302 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
278303 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
279304 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
280305 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
281306 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
282307 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q5_1)
308+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1)
283309
284310 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
285311 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
286312 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
287313 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
288314 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
289315 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_Q8_0)
316+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0)
290317
291318 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_F16)
292319 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_1, GGML_TYPE_F16)
293320 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_F16)
294321 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_F16)
295322 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_F16)
296323 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
324+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
297325
298326 FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
327+
328+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
329+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
299330#else
300331 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
301332
@@ -304,6 +335,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
304335 FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
305336 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
306337 FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
338+
339+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
340+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
341+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
342+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
343+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
344+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
307345#endif // GGML_CUDA_FA_ALL_QUANTS
308346
309347 on_no_fattn_vec_case (Q->ne [0 ]);
0 commit comments