@@ -277,6 +277,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
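+// Dot product of one q6_0-quantized K cache row with the q8_1-quantized Q
+// vector; mirrors the q5_0/q5_1 helpers above.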
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q6_0 * K_q6_0 = (const block_q6_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI6_0; // 0...3
+        const int shift = k_KQ & (QI8_1/2);
+
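+        // q6_0 packing: each quant keeps its low 4 bits in qs and its high
+        // 2 bits in qh; the per-value -32 offset is applied through the Q row
+        // sums (the (32/QI8_1) term below).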
+        const int vh = (get_int_b2(K_q6_0[ib].qh, iqs4%2) >> (4*(iqs4/2) + shift/2)) & 0x03030303;
+        const int vl = (get_int_b2(K_q6_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v  = vl | (vh << 4);
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
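+        // dp4a: dot product over the 4 packed int8 lanes of v and u.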
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q6_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(4.0f)) /* *32/QI8_1 == 4 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q6_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (32/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -510,6 +553,30 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     return __low2float(dm)*((float) q) + __high2float(dm);
 }
 
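+// Dequantize a single value of a q6_0 block: (6-bit quant - 32) * d.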
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q6_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q6_0 * x = (const block_q6_0 *) vx;
+
+    const int64_t ib    = i   /  QK6_0;
+    const int     idq   = i   %  QK6_0;
+    const int     iqs   = i   % (QK6_0/2);
+    const int     shift = idq / (QK6_0/2);
+    // const int  shift = (i % QK6_0) / (QK6_0/2);
+
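+    // qs packs two low nibbles per byte (selected by shift); qh packs the
+    // high 2 bits of four quants per byte.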
+    const T   d  = x[ib].d;
+    const int ql = x[ib].qs[iqs] >> 4*shift;
+    const int qh = x[ib].qh[idq%(QK6_0/4)] >> (4*((idq/(QK6_0/4))%2) + 2*shift);
+    const int q  = ((ql & 0x0f) | ((qh & 0x03) << 4)) - 32;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
 template <typename T>
 static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
@@ -543,6 +610,7 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
            type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
            type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<half, D>   :
            type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<half, D>   :
+           type_K == GGML_TYPE_Q6_0   ? vec_dot_fattn_vec_KQ_q6_0<half, D>   :
            type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<half, D>   :
            type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<half, D>    :
            nullptr;
@@ -555,6 +623,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
            type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
            type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<float, D>   :
            type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<float, D>   :
+           type_K == GGML_TYPE_Q6_0   ? vec_dot_fattn_vec_KQ_q6_0<float, D>   :
            type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<float, D>   :
            type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<float, D>    :
            nullptr;
@@ -565,6 +634,7 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
            type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<half>   :
            type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<half>   :
            type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<half>   :
+           type_V == GGML_TYPE_Q6_0   ? dequantize_1_q6_0<half>   :
            type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<half>   :
            type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
            type_V == GGML_TYPE_F16    ? dequantize_1_f16<half>    :
@@ -576,6 +646,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
            type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<float>   :
            type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<float>   :
            type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<float>   :
+           type_V == GGML_TYPE_Q6_0   ? dequantize_1_q6_0<float>   :
            type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<float>   :
            type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
            type_V == GGML_TYPE_F16    ? dequantize_1_f16<float>    :
@@ -635,10 +706,13 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0,   V == q4_0,   4.50 BPV\n");
-        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0,   V == q8_0,   8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,    V == f16,   16.00 BPV\n");
+        fprintf(stderr, "  - K == q4_0,   V == q4_0,   4.5 BPV\n");
+        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl, 4.5 BPV\n");
+        fprintf(stderr, "  - K == q6_0,   V == q5_0,   6.0 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == iq4_nl, 6.5 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q6_0,   7.5 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q8_0,   8.5 BPV\n");
+        fprintf(stderr, "  - K == f16,    V == f16,   16.0 BPV\n");
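+        // BPV here is the mean bits per value over the K and V caches,
+        // e.g. q6_0 K (6.5 bpw) with q5_0 V (5.5 bpw) averages 6.0 BPV.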
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {