
Commit 126ce2c

fix compatibility with other q4_k repacking models
1 parent da606bd commit 126ce2c

File tree

3 files changed: +134, -125 lines changed

ggml/src/ggml-cpu/arch/arm/repack.cpp

Lines changed: 117 additions & 110 deletions
@@ -210,138 +210,145 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
 #endif
 }
 
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
 #if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    const int blck_size_interleave = 8;
-    float32x4_t srcv[4][64]; // 64 = QK_K/4
-    float iscale[4];
-
-    for (int i = 0; i < nb; i++) {
-        float32x4_t asrcv[64];
-        float32x4_t amaxv[64];
-
-        // d:
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
-            for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
-
-            for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
-            for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
-            for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
-            for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
-            for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
-            for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
+    if (nc % 8 == 0) {
+        UNUSED(nb);
+        UNUSED(y);
+        ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
+    } else if (nc % 4 == 0) {
+        const int blck_size_interleave = 8;
+        float32x4_t srcv[4][64]; // 64 = QK_K/4
+        float iscale[4];
+
+        for (int i = 0; i < nb; i++) {
+            float32x4_t asrcv[64];
+            float32x4_t amaxv[64];
+
+            // d:
+            for (int row_iter = 0; row_iter < 4; row_iter++) {
+                for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
+                for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+                for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+                for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+                for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+                for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
+                for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
+                for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
+
+                const float amax = vmaxvq_f32(amaxv[0]);
+
+                // Check if exists: orig == amax
+                float32x4_t amax_vec = vdupq_n_f32(amax);
+                uint32x4_t mask_all = vdupq_n_u32(0);
+                for (int j = 0; j < 64; j++) {
+                    uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
+                    mask_all = vorrq_u32(mask_all, mask_curr);
+                }
 
-            const float amax = vmaxvq_f32(amaxv[0]);
+                // Assume that none == amax, then check mask_all to reverse
+                iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
+                uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
+                if (vmaxvq_u32(cmp) != 0) {
+                    iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
+                }
 
-            // Check if exists: orig == amax
-            float32x4_t amax_vec = vdupq_n_f32(amax);
-            uint32x4_t mask_all = vdupq_n_u32(0);
-            for (int j = 0; j < 64; j++) {
-                uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
-                mask_all = vorrq_u32(mask_all, mask_curr);
+                y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
             }
 
-            // Assume that none == amax, then check mask_all to reverse
-            iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
-            uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
-            if (vmaxvq_u32(cmp) != 0) {
-                iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
+            // qs: 8 byte interleave over 4 rows, loop = QK_K/8
+            // bsums: simply generated one by one, row_i is calculated before row_i+1
+            // loops = 16
+            for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
+                // Process row0 and row1
+                float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
+                float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
+                float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
+                float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
+                int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
+                int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
+                int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
+                int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
+                int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
+                float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
+                float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
+                float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
+                int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
+                int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
+                int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
+                int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                // Calculate and store qs
+                int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
+                // Calculate and store bsum
+                int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
+
+                // Process row2 and row3
+                f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
+                f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
+                f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
+                f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
+                i0_0_3 = vcvtnq_s32_f32(f0_0_3);
+                i0_4_7 = vcvtnq_s32_f32(f0_4_7);
+                i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                i0_8_11 = vcvtnq_s32_f32(f0_8_11);
+                i0_12_15 = vcvtnq_s32_f32(f0_12_15);
+                i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
+                f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
+                f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
+                f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
+                i1_0_3 = vcvtnq_s32_f32(f1_0_3);
+                i1_4_7 = vcvtnq_s32_f32(f1_4_7);
+                i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                i1_8_11 = vcvtnq_s32_f32(f1_8_11);
+                i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                // Calculate and store qs
+                i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
+                // Calculate and store bsum
+                i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
             }
 
-        const float amax = vmaxvq_f32(amaxv[0]);
-
-        // Check if exists: orig == amax
-        float32x4_t amax_vec = vdupq_n_f32(amax);
-        uint32x4_t mask_all = vdupq_n_u32(0);
-        for (int j = 0; j < 64; j++) {
-            uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
-            mask_all = vorrq_u32(mask_all, mask_curr);
-        }
-
-        // Assume that none == amax, then check mask_all to reverse
-        iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
-        uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
-        if (vmaxvq_u32(cmp) != 0) {
-            iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
-        }
-
-            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
-        }
-
-        // qs: 8 byte interleave over 4 rows, loop = QK_K/8
-        // bsums: simply generated one by one, row_i is calculated before row_i+1
-        // loops = 16
-        for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
-            // Process row0 and row1
-            float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
-            float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
-            float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
-            float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
-            int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
-            int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
-            int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
-            int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
-            int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
-            float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
-            float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
-            float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
-            int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
-            int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
-            int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
-            int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
-            int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            // Calculate and store qs
-            int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
-            vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
-            // Calculate and store bsum
-            int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            y[i].bsums[j] = vaddlvq_s8(i0_0_15);
-            y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
-
-            // Process row2 and row3
-            f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
-            f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
-            f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
-            f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
-            i0_0_3 = vcvtnq_s32_f32(f0_0_3);
-            i0_4_7 = vcvtnq_s32_f32(f0_4_7);
-            i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            i0_8_11 = vcvtnq_s32_f32(f0_8_11);
-            i0_12_15 = vcvtnq_s32_f32(f0_12_15);
-            i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
-            f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
-            f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
-            f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
-            i1_0_3 = vcvtnq_s32_f32(f1_0_3);
-            i1_4_7 = vcvtnq_s32_f32(f1_4_7);
-            i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            i1_8_11 = vcvtnq_s32_f32(f1_8_11);
-            i1_12_15 = vcvtnq_s32_f32(f1_12_15);
-            i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
-
-            // Calculate and store qs
-            i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
-            vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
-            // Calculate and store bsum
-            i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
-            y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
         }
     }
+    return;
+#endif
 
-#else
     // NOTE:
     // Current C impl of Q8_K quantization is originally designed to work with block_q4_Kx8 in x86 AVX design, and differs from
    // the above Q8_K quantization logic in AArch64 NEON design, which is designed to work with block_q4_Kx4. The main difference is in
    // the process of their "[bsums] layout". However, we can still reuse the x86 C impl for AArch64, as long as we access the
    // "[bsums] layout" correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
     UNUSED(nb);
     UNUSED(y);
-    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
-#endif
+    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
 }
 
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
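
Note on the hunk above: ggml_quantize_mat_q8_K_4x8 now receives nc (fed with ne01 at the call site in repack.cpp, the same dimension that ggml_repack_get_optimal_repack_type tests), because the activations' Q8_K layout must match however the q4_K weights were repacked. The NEON path stores bsums row-contiguously (row r's 16 block sums land at bsums[16 * r + j]), which is what the 4x8 GEMM expects; weights repacked as block_q4_Kx8 instead need the layout produced by the generic, x86-designed implementation. A minimal sketch of the resulting dispatch shape, with hypothetical stub names standing in for the real functions:

    #include <cstdint>

    // Hypothetical stand-ins for ggml_quantize_mat_q8_K_4x8_generic and the
    // NEON block_q4_Kx4 path above; bodies elided, signatures simplified.
    static void quantize_generic_for_x8(const float * x, void * vy, int64_t k, int64_t nc) { (void)x; (void)vy; (void)k; (void)nc; }
    static void quantize_neon_for_x4(const float * x, void * vy, int64_t k) { (void)x; (void)vy; (void)k; }

    // Sketch of the dispatch added by this commit: tensors whose nc is a
    // multiple of 8 were repacked as block_q4_Kx8 and must get the generic
    // bsums layout; tensors divisible only by 4 keep the NEON 4x8 layout.
    static void quantize_mat_q8_K_4x8(const float * x, void * vy, int64_t k, int64_t nc) {
        if (nc % 8 == 0) {
            quantize_generic_for_x8(x, vy, k, nc); // matches q4_K_8x8 repacking
        } else if (nc % 4 == 0) {
            quantize_neon_for_x4(x, vy, k);        // matches q4_K_4x8 repacking
        }
    }

The nc % 8 test mirrors the cur->ne[1] % 8 test in ggml_repack_get_optimal_repack_type (second file below), so quantization and repacking stay in agreement for a given tensor.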

ggml/src/ggml-cpu/repack.cpp

Lines changed: 15 additions & 13 deletions
@@ -176,9 +176,10 @@ void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GG
     }
 }
 
-void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
@@ -230,30 +231,33 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
 } // extern "C"
 
 template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
-void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
+void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols);
 
-template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
+    UNUSED(ncols);
     ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
 }
 
-template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
+    UNUSED(ncols);
     ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
 }
 
-template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
+    UNUSED(ncols);
     ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
 }
 
-template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
-    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
+    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row, ncols);
 }
 
 extern "C" {
@@ -2502,7 +2506,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
 
     for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
         ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
-                                                    (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+                                                    (void *) (wdata_ptr + i11 * nbw1), 4, ne10, ne01);
     }
 
     const int64_t i11_processed = ne11 - ne11 % 4;
@@ -2775,15 +2779,13 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
             return &q4_K_8x8_q8_K;
         }
     }
-    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { // new for ARM N2
-        if (cur->ne[1] % 4 == 0) {
-            return &q4_K_4x8_q8_K;
-        }
-    }
     if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
         if (cur->ne[1] % 8 == 0) {
            return &q4_K_8x8_q8_K;
        }
+        if (cur->ne[1] % 4 == 0) {
+            return &q4_K_4x8_q8_K;
+        }
     }
     if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         if (cur->ne[1] % 8 == 0) {
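
This reordering is the compatibility fix named in the commit message: any ne[1] divisible by 8 is also divisible by 4, so the old standalone "% 4" block (commented "new for ARM N2") ran first and captured tensors that the rest of the q4_K code expects to be repacked as 8x8. Folding the "% 4" case into the same NEON+i8mm branch, after the "% 8" check, restores the 8x8 preference. A condensed sketch of the corrected selection, with simplified stand-ins for the real tensor_traits objects and CPU-feature checks:

    #include <cstdint>

    // Simplified stand-ins for the q4_K_8x8_q8_K / q4_K_4x8_q8_K tensor_traits;
    // names and signature are illustrative, not the real ggml internals.
    struct tensor_traits {};
    static tensor_traits q4_K_8x8_q8_K, q4_K_4x8_q8_K;

    static const tensor_traits * pick_q4_K_repack(int64_t ne1, bool has_neon, bool has_i8mm) {
        if (has_neon && has_i8mm) {
            // Test the stricter divisor first: ne1 % 8 == 0 implies ne1 % 4 == 0,
            // so checking % 4 first (as before this commit) shadowed the 8x8 case.
            if (ne1 % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
            if (ne1 % 4 == 0) {
                return &q4_K_4x8_q8_K;
            }
        }
        return nullptr; // fall through to the dotprod and remaining checks
    }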
