@@ -210,138 +210,145 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
 #endif
 }
 
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
 #if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    const int blck_size_interleave = 8;
-    float32x4_t srcv[4][64]; // 64 = QK_K/4
-    float iscale[4];
-
-    for (int i = 0; i < nb; i++) {
-        float32x4_t asrcv[64];
-        float32x4_t amaxv[64];
-
-        // d:
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
-            for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
-
-            for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
-            for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
-            for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
-            for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
-            for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
-            for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
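+    // nc selects the repack layout: multiples of 8 take the generic C path (block_q4_Kx8
+    // interleave, see the NOTE below), multiples of 4 take the NEON block_q4_Kx4 path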
+    if (nc % 8 == 0) {
+        UNUSED(nb);
+        UNUSED(y);
+        ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
+    } else if (nc % 4 == 0) {
+        const int blck_size_interleave = 8;
+        float32x4_t srcv[4][64]; // 64 = QK_K/4
+        float iscale[4];
+
+        for (int i = 0; i < nb; i++) {
+            float32x4_t asrcv[64];
+            float32x4_t amaxv[64];
+
+            // d:
+            for (int row_iter = 0; row_iter < 4; row_iter++) {
+                for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
+                for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+                for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+                for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+                for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+                for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
+                for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
+                for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
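+                // after 6 pairwise-max rounds, amaxv[0] holds the lane-wise max of all 64 vectors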
+
+                const float amax = vmaxvq_f32(amaxv[0]);
+
+                // Check whether any original value equals +amax
+                float32x4_t amax_vec = vdupq_n_f32(amax);
+                uint32x4_t mask_all = vdupq_n_u32(0);
+                for (int j = 0; j < 64; j++) {
+                    uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
+                    mask_all = vorrq_u32(mask_all, mask_curr);
+                }

-            const float amax = vmaxvq_f32(amaxv[0]);
+                // Assume no value equals +amax, then check mask_all and flip the scale sign if one does
+                iscale[row_iter] = (amax != 0.0f) ? 127.f / amax : 0.0f;
+                uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
+                if (vmaxvq_u32(cmp) != 0) {
+                    iscale[row_iter] = (amax != 0.0f) ? -127.f / amax : 0.0f;
+                }
 
-            // Check if exists: orig == amax
-            float32x4_t amax_vec = vdupq_n_f32(amax);
-            uint32x4_t mask_all = vdupq_n_u32(0);
-            for (int j = 0; j < 64; j++) {
-                uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
-                mask_all = vorrq_u32(mask_all, mask_curr);
+                y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
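+                // d is the dequantization scale (reciprocal of iscale); 0 for an all-zero row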
             }
 
-            // Assume that none == amax, then check mask_all to reverse
-            iscale[row_iter] = (amax != 0.0f) ? 127.f / amax : 0.0f;
-            uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
-            if (vmaxvq_u32(cmp) != 0) {
-                iscale[row_iter] = (amax != 0.0f) ? -127.f / amax : 0.0f;
+            // qs: 8-byte interleave over 4 rows, loop = QK_K/8
+            // bsums: generated one at a time; row i is computed before row i+1
+            // loops = 16
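+            // each iteration packs one 64-byte group of qs: 8 bytes per row for rows 0..3
+            // covering elements 16*j..16*j+7, then 8 bytes per row covering 16*j+8..16*j+15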
+            for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
+                // Process row0 and row1
+                float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
+                float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
+                float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
+                float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
+                int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
+                int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
+                int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
+                int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
+                int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
+                float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
+                float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
+                float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
+                int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
+                int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
+                int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
+                int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                // Calculate and store qs
+                int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
+                // Calculate and store bsum
+                int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
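+                // bsums holds one 16-element sum per entry; row r occupies bsums[16*r .. 16*r + 15]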
+
+                // Process row2 and row3
+                f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
+                f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
+                f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
+                f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
+                i0_0_3 = vcvtnq_s32_f32(f0_0_3);
+                i0_4_7 = vcvtnq_s32_f32(f0_4_7);
+                i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                i0_8_11 = vcvtnq_s32_f32(f0_8_11);
+                i0_12_15 = vcvtnq_s32_f32(f0_12_15);
+                i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
+                f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
+                f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
+                f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
+                i1_0_3 = vcvtnq_s32_f32(f1_0_3);
+                i1_4_7 = vcvtnq_s32_f32(f1_4_7);
+                i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                i1_8_11 = vcvtnq_s32_f32(f1_8_11);
+                i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                // Calculate and store qs
+                i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
+                // Calculate and store bsum
+                i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
             }
-
-            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
-        }
-
-        // qs: 8 byte interleave over 4 rows, loop = QK_K/8
-        // bsums: simply generated one by one, row_i is calculated before row_i+1
-        // loops = 16
-        for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
-            // Process row0 and row1
-            float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
-            float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
-            float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
-            float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
-            int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
-            int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
-            int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
-            int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
-            int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
-            float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
-            float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
-            float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
-            int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
-            int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
-            int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
-            int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
-            int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            // Calculate and store qs
-            int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
-            vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
-            // Calculate and store bsum
-            int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            y[i].bsums[j] = vaddlvq_s8(i0_0_15);
-            y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
-
-            // Process row2 and row3
-            f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
-            f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
-            f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
-            f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
-            i0_0_3 = vcvtnq_s32_f32(f0_0_3);
-            i0_4_7 = vcvtnq_s32_f32(f0_4_7);
-            i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            i0_8_11 = vcvtnq_s32_f32(f0_8_11);
-            i0_12_15 = vcvtnq_s32_f32(f0_12_15);
-            i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
-            f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
-            f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
-            f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
-            i1_0_3 = vcvtnq_s32_f32(f1_0_3);
-            i1_4_7 = vcvtnq_s32_f32(f1_4_7);
-            i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            i1_8_11 = vcvtnq_s32_f32(f1_8_11);
-            i1_12_15 = vcvtnq_s32_f32(f1_12_15);
-            i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            // Calculate and store qs
-            i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
-            vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
-            // Calculate and store bsum
-            i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
-            y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
         }
     }
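+    // NEON builds always return here; the generic call below is the non-NEON fallback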
+    return;
+#endif
 
-#else
     // NOTE:
     // The current C implementation of Q8_K quantization was originally designed for block_q4_Kx8 in the x86 AVX path and
     // differs from the AArch64 NEON Q8_K quantization logic above, which targets block_q4_Kx4. The main difference is the
     // [bsums] layout. However, we can still reuse the x86 C implementation on AArch64, as long as we access the [bsums]
     // layout correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
     UNUSED(nb);
     UNUSED(y);
-    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
-#endif
+    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
 }
 
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {