@@ -112,7 +112,31 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
112112 }
113113
114114#endif
115- ggml_gemv_q4_0_8x8_q8_0_generic (n, s, bs, vx, vy, nr, nc);
115+ {
116+ float sumf[8 ];
117+ int sumi;
118+
119+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
120+ for (int x = 0 ; x < nc / ncols_interleaved; x++) {
121+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
122+
123+ for (int j = 0 ; j < ncols_interleaved; j++) sumf[j] = 0.0 ;
124+ for (int l = 0 ; l < nb; l++) {
125+ for (int k = 0 ; k < (qk / (2 * blocklen)); k++) {
126+ for (int j = 0 ; j < ncols_interleaved; j++) {
127+ sumi = 0 ;
128+ for (int i = 0 ; i < blocklen; ++i) {
129+ const int v0 = (int8_t ) (b_ptr[l].qs [k * ncols_interleaved * blocklen + j * blocklen + i] << 4 );
130+ const int v1 = (int8_t ) (b_ptr[l].qs [k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0 );
131+ sumi += ((v0 * a_ptr[l].qs [k * blocklen + i]) + (v1 * a_ptr[l].qs [k * blocklen + i + qk / 2 ])) >> 4 ;
132+ }
133+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32 (b_ptr[l].d [j]) * GGML_CPU_FP16_TO_FP32 (a_ptr[l].d );
134+ }
135+ }
136+ }
137+ for (int j = 0 ; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
138+ }
139+ }
116140}
117141
118142void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -337,6 +361,37 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
337361 return ;
338362 }
339363
340- #endif
341- ggml_gemm_q4_0_8x8_q8_0_generic (n, s, bs, vx, vy, nr, nc);
364+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
365+ float sumf[4 ][8 ];
366+ int sumi;
367+
368+ for (int y = 0 ; y < nr / 4 ; y++) {
369+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
370+ for (int x = 0 ; x < nc / ncols_interleaved; x++) {
371+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
372+ for (int m = 0 ; m < 4 ; m++) {
373+ for (int j = 0 ; j < ncols_interleaved; j++) sumf[m][j] = 0.0 ;
374+ }
375+ for (int l = 0 ; l < nb; l++) {
376+ for (int k = 0 ; k < (qk / (2 * blocklen)); k++) {
377+ for (int m = 0 ; m < 4 ; m++) {
378+ for (int j = 0 ; j < ncols_interleaved; j++) {
379+ sumi = 0 ;
380+ for (int i = 0 ; i < blocklen; ++i) {
381+ const int v0 = (int8_t ) (b_ptr[l].qs [k * ncols_interleaved * blocklen + j * blocklen + i] << 4 );
382+ const int v1 = (int8_t ) (b_ptr[l].qs [k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0 );
383+ sumi += ((v0 * a_ptr[l].qs [k * 4 * blocklen + m * blocklen + i]) +
384+ (v1 * a_ptr[l].qs [k * 4 * blocklen + m * blocklen + i + qk / 2 * 4 ])) >> 4 ;
385+ }
386+ sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32 (b_ptr[l].d [j]) * GGML_CPU_FP16_TO_FP32 (a_ptr[l].d [m]);
387+ }
388+ }
389+ }
390+ }
391+ for (int m = 0 ; m < 4 ; m++) {
392+ for (int j = 0 ; j < ncols_interleaved; j++)
393+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
394+ }
395+ }
396+ }
342397}
0 commit comments