@@ -751,6 +751,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
751751 const int ncols_interleaved = 4 ;
752752 const int blocklen = 8 ;
753753
754+ assert (nr == 1 );
754755 assert (n % qk == 0 );
755756 assert (nc % ncols_interleaved == 0 );
756757
@@ -1328,38 +1329,36 @@ void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
13281329 assert (nr % 4 == 0 );
13291330 assert (nc % ncols_interleaved == 0 );
13301331
1331- {
1332- float sumf[4 ][4 ];
1333- int sumi;
1332+ float sumf[4 ][4 ];
1333+ int sumi;
13341334
1335- for (int y = 0 ; y < nr / 4 ; y++) {
1336- const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1337- for (int x = 0 ; x < nc / ncols_interleaved; x++) {
1338- const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1339- for (int m = 0 ; m < 4 ; m++) {
1340- for (int j = 0 ; j < ncols_interleaved; j++) {
1341- sumf[m][j] = 0.0 ;
1342- }
1335+ for (int y = 0 ; y < nr / 4 ; y++) {
1336+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1337+ for (int x = 0 ; x < nc / ncols_interleaved; x++) {
1338+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1339+ for (int m = 0 ; m < 4 ; m++) {
1340+ for (int j = 0 ; j < ncols_interleaved; j++) {
1341+ sumf[m][j] = 0.0 ;
13431342 }
1344- for (int l = 0 ; l < nb; l++) {
1345- for (int k = 0 ; k < (qk / blocklen); k++) {
1346- for (int m = 0 ; m < 4 ; m++) {
1347- for (int j = 0 ; j < ncols_interleaved; j++) {
1348- sumi = 0 ;
1349- for (int i = 0 ; i < blocklen; ++i) {
1350- const int v0 = b_ptr[l].qs [k * ncols_interleaved * blocklen + j * blocklen + i];
1351- sumi += v0 * a_ptr[l].qs [k * 4 * blocklen + m * blocklen + i];
1352- }
1353- sumf[m][j] +=
1354- sumi * GGML_CPU_FP16_TO_FP32 (b_ptr[l].d [j]) * GGML_CPU_FP16_TO_FP32 (a_ptr[l].d [m]);
1343+ }
1344+ for (int l = 0 ; l < nb; l++) {
1345+ for (int k = 0 ; k < (qk / blocklen); k++) {
1346+ for (int m = 0 ; m < 4 ; m++) {
1347+ for (int j = 0 ; j < ncols_interleaved; j++) {
1348+ sumi = 0 ;
1349+ for (int i = 0 ; i < blocklen; ++i) {
1350+ const int v0 = b_ptr[l].qs [k * ncols_interleaved * blocklen + j * blocklen + i];
1351+ sumi += v0 * a_ptr[l].qs [k * 4 * blocklen + m * blocklen + i];
13551352 }
1353+ sumf[m][j] +=
1354+ sumi * GGML_CPU_FP16_TO_FP32 (b_ptr[l].d [j]) * GGML_CPU_FP16_TO_FP32 (a_ptr[l].d [m]);
13561355 }
13571356 }
13581357 }
1359- for ( int m = 0 ; m < 4 ; m++) {
1360- for (int j = 0 ; j < ncols_interleaved; j ++) {
1361- s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1362- }
1358+ }
1359+ for (int m = 0 ; m < 4 ; m ++) {
1360+ for ( int j = 0 ; j < ncols_interleaved; j++) {
1361+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
13631362 }
13641363 }
13651364 }
0 commit comments