@@ -225,16 +225,21 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
225225 memcpy (dsti->qh , &qh, sizeof (qh));
226226}
227227
228-
229- static __device__ __forceinline__ int best_index_int8 (int n, const int8_t * val, float x) {
230- if (x <= val[0 ]) return 0 ;
231- if (x >= val[n-1 ]) return n-1 ;
232- int ml = 0 , mu = n-1 ;
233- while (mu-ml > 1 ) {
234- int mav = (ml+mu)/2 ;
235- if (x < val[mav]) mu = mav; else ml = mav;
236- }
237- return x - val[mu-1 ] < val[mu] - x ? mu-1 : mu;
228+ static __device__ const int8_t iq4nl_index[241 ] = {
229+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 16 , 16 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
230+ 1 , 17 , 17 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 18 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 ,
231+ 3 , 3 , 3 , 3 , 3 , 3 , 19 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 20 , 5 , 5 , 5 , 5 , 5 , 5 , 5 , 5 , 5 , 5 ,
232+ 5 , 5 , 21 , 21 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 22 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 23 , 23 , 8 , 8 , 8 , 8 ,
233+ 8 , 8 , 8 , 8 , 8 , 8 , 24 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 25 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 26 , 26 ,
234+ 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 27 , 27 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 28 , 13 , 13 , 13 ,
235+ 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 29 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 ,
236+ 14 , 14 , 14 , 14 , 30 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15
237+ };
238+ static __device__ __forceinline__ int best_index_iq4nl (const int8_t * values, float x) {
239+ int ix = (int )x - values[0 ];
240+ if (ix < 0 || ix >= 241 ) return ix < 0 ? 0 : 15 ;
241+ ix = iq4nl_index[ix];
242+ return ix < 16 ? ix : x - values[ix-16 ] < values[ix-15 ] - x ? ix-16 : ix-15 ;
238243}
239244
240245static __device__ void cpy_blck_f32_iq4_nl (const char * cxi, char * cdsti) {
@@ -255,12 +260,14 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
255260 float d = vmax / kvalues_iq4nl[0 ];
256261 const float id = d ? 1 .0f /d : 0 .0f ;
257262
263+ // dsti->d = d;
264+
258265 float sumqx = 0 , sumq2 = 0 ;
259266 for (int j = 0 ; j < QK4_NL/2 ; ++j) {
260267 const float x0 = xi[0 + j]*id;
261268 const float x1 = xi[QK4_NL/2 + j]*id;
262- const uint8_t xi0 = best_index_int8 ( 16 , kvalues_iq4nl, x0);
263- const uint8_t xi1 = best_index_int8 ( 16 , kvalues_iq4nl, x1);
269+ const uint8_t xi0 = best_index_iq4nl ( kvalues_iq4nl, x0);
270+ const uint8_t xi1 = best_index_iq4nl ( kvalues_iq4nl, x1);
264271 dsti->qs [j] = xi0 | (xi1 << 4 );
265272 const float v0 = kvalues_iq4nl[xi0];
266273 const float v1 = kvalues_iq4nl[xi1];
0 commit comments