@@ -225,16 +225,21 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
225225    memcpy (dsti->qh , &qh, sizeof (qh));
226226}
227227
228- 
229- static  __device__  __forceinline__  int  best_index_int8 (int  n, const  int8_t  * val, float  x) {
230-     if  (x <= val[0 ]) return  0 ;
231-     if  (x >= val[n-1 ]) return  n-1 ;
232-     int  ml = 0 , mu = n-1 ;
233-     while  (mu-ml > 1 ) {
234-         int  mav = (ml+mu)/2 ;
235-         if  (x < val[mav]) mu = mav; else  ml = mav;
236-     }
237-     return  x - val[mu-1 ] < val[mu] - x ? mu-1  : mu;
228+ static  __device__  const  int8_t  iq4nl_index[241 ] = {
229+      0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 ,  0 , 16 , 16 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,
230+      1 , 17 , 17 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 ,  2 , 18 ,  3 ,  3 ,  3 ,  3 ,  3 ,  3 ,  3 ,  3 ,  3 ,  3 ,
231+      3 ,  3 ,  3 ,  3 ,  3 ,  3 , 19 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 ,  4 , 20 ,  5 ,  5 ,  5 ,  5 ,  5 ,  5 ,  5 ,  5 ,  5 ,  5 ,
232+      5 ,  5 , 21 , 21 ,  6 ,  6 ,  6 ,  6 ,  6 ,  6 ,  6 ,  6 ,  6 ,  6 ,  6 , 22 ,  7 ,  7 ,  7 ,  7 ,  7 ,  7 ,  7 ,  7 ,  7 ,  7 , 23 , 23 ,  8 ,  8 ,  8 ,  8 ,
233+      8 ,  8 ,  8 ,  8 ,  8 ,  8 , 24 ,  9 ,  9 ,  9 ,  9 ,  9 ,  9 ,  9 ,  9 ,  9 ,  9 ,  9 , 25 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 26 , 26 ,
234+     11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 27 , 27 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 28 , 13 , 13 , 13 ,
235+     13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 29 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 ,
236+     14 , 14 , 14 , 14 , 30 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 
237+ };
238+ static  __device__  __forceinline__  int  best_index_iq4nl (const  int8_t  * values, float  x) {
239+     int  ix = (int )x - values[0 ];
240+     if  (ix < 0  || ix >= 241 ) return  ix < 0  ? 0  : 15 ;
241+     ix = iq4nl_index[ix];
242+     return  ix < 16  ? ix : x - values[ix-16 ] < values[ix-15 ] - x ? ix-16  : ix-15 ;
238243}
239244
240245static  __device__  void  cpy_blck_f32_iq4_nl (const  char  * cxi, char  * cdsti) {
@@ -255,12 +260,14 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
255260    float  d = vmax / kvalues_iq4nl[0 ];
256261    const  float  id = d ? 1 .0f /d : 0 .0f ;
257262
263+     // dsti->d = d;
264+ 
258265    float  sumqx = 0 , sumq2 = 0 ;
259266    for  (int  j = 0 ; j < QK4_NL/2 ; ++j) {
260267        const  float  x0 = xi[0         + j]*id;
261268        const  float  x1 = xi[QK4_NL/2  + j]*id;
262-         const  uint8_t  xi0 = best_index_int8 ( 16 ,  kvalues_iq4nl, x0);
263-         const  uint8_t  xi1 = best_index_int8 ( 16 ,  kvalues_iq4nl, x1);
269+         const  uint8_t  xi0 = best_index_iq4nl ( kvalues_iq4nl, x0);
270+         const  uint8_t  xi1 = best_index_iq4nl ( kvalues_iq4nl, x1);
264271        dsti->qs [j] = xi0 | (xi1 << 4 );
265272        const  float  v0 = kvalues_iq4nl[xi0];
266273        const  float  v1 = kvalues_iq4nl[xi1];
0 commit comments