@@ -153,6 +153,23 @@ static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, blo
153153 }
154154}
155155
156+ static __device__ const int8_t iq4nl_index[241 ] = {
157+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 16 , 16 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
158+ 1 , 17 , 17 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 18 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 ,
159+ 3 , 3 , 3 , 3 , 3 , 3 , 19 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 4 , 20 , 5 , 5 , 5 , 5 , 5 , 5 , 5 , 5 , 5 , 5 ,
160+ 5 , 5 , 21 , 21 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 22 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , 23 , 23 , 8 , 8 , 8 , 8 ,
161+ 8 , 8 , 8 , 8 , 8 , 8 , 24 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 25 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 10 , 26 , 26 ,
162+ 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 11 , 27 , 27 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 12 , 28 , 13 , 13 , 13 ,
163+ 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 13 , 29 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 14 ,
164+ 14 , 14 , 14 , 14 , 30 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15 , 15
165+ };
166+ static __device__ __forceinline__ int best_index_iq4nl (const int8_t * values, float x) {
167+ int ix = (int )x - values[0 ];
168+ if (ix < 0 || ix >= 241 ) return ix < 0 ? 0 : 15 ;
169+ ix = iq4nl_index[ix];
170+ return ix < 16 ? ix : x - values[ix-16 ] < values[ix-15 ] - x ? ix-16 : ix-15 ;
171+ }
172+
156173static __device__ void quantize_f32_iq4_nl_block (const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
157174 float amax = 0 .0f ;
158175 float vmax = 0 .0f ;
@@ -172,8 +189,8 @@ static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, b
172189 for (int j = 0 ; j < QK4_NL/2 ; ++j) {
173190 const float x0 = x[0 + j]*id;
174191 const float x1 = x[QK4_NL/2 + j]*id;
175- const uint8_t xi0 = best_index_int8 ( 16 , kvalues_iq4nl, x0);
176- const uint8_t xi1 = best_index_int8 ( 16 , kvalues_iq4nl, x1);
192+ const uint8_t xi0 = best_index_iq4nl ( kvalues_iq4nl, x0);
193+ const uint8_t xi1 = best_index_iq4nl ( kvalues_iq4nl, x1);
177194 y->qs [j] = xi0 | (xi1 << 4 );
178195 const float v0 = kvalues_iq4nl[xi0];
179196 const float v1 = kvalues_iq4nl[xi1];
0 commit comments