Skip to content

Commit 6c758bb

Browse files
committed
Address q2k comments
1 parent 6f99895 commit 6c758bb

File tree

1 file changed

+2
-121
lines changed

1 file changed

+2
-121
lines changed

ggml/src/ggml-cpu/arch/x86/repack.cpp

Lines changed: 2 additions & 121 deletions
Original file line number | Diff line number | Diff line change
@@ -1298,61 +1298,8 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
12981298
}
12991299
#else
13001300

1301-
float sumf[8];
1302-
float sum_minf[8];
1303-
int sumi1,sumi2,sumi3,sumi4;
1304-
int sumi;
1301+
ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
13051302

1306-
const block_q8_K * a_ptr = (const block_q8_K *)vy;
1307-
for(int x = 0; x < nc / ncols_interleaved; x++) {
1308-
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
1309-
for (int j = 0; j < ncols_interleaved; j++) {
1310-
sumf[j] = 0.0;
1311-
sum_minf[j] = 0.0;
1312-
}
1313-
for (int l = 0; l < nb; l++) {
1314-
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
1315-
uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
1316-
uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
1317-
uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
1318-
uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
1319-
for (int j = 0; j < ncols_interleaved; j++) {
1320-
sumi1 = 0;
1321-
sumi2 = 0;
1322-
sumi3 = 0;
1323-
sumi4 = 0;
1324-
sumi = 0;
1325-
int offset = ((k / 2) % 2) + j * 2;
1326-
for (int i = 0; i < blocklen; ++i){
1327-
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
1328-
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
1329-
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
1330-
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
1331-
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
1332-
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
1333-
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
1334-
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
1335-
1336-
sumi1 = sumi1 * (scales_0[offset] & 0xF);
1337-
sumi2 = sumi2 * (scales_1[offset] & 0xF);
1338-
sumi3 = sumi3 * (scales_2[offset] & 0xF);
1339-
sumi4 = sumi4 * (scales_3[offset] & 0xF);
1340-
sumi += sumi1 + sumi2 + sumi3 + sumi4;
1341-
}
1342-
sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
1343-
}
1344-
}
1345-
for(int sb = 0; sb < 8; sb++) {
1346-
uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
1347-
for(int j = 0; j < ncols_interleaved; j++){
1348-
sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
1349-
}
1350-
}
1351-
}
1352-
for (int j = 0; j < ncols_interleaved; j++) {
1353-
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
1354-
}
1355-
}
13561303
#endif
13571304
}
13581305

@@ -6527,74 +6474,8 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
65276474
}
65286475
#else
65296476

6530-
float sumf[4][8];
6531-
float sum_minf[4][8];
6532-
int sumi1, sumi2, sumi3, sumi4;
6533-
int sumi;
6477+
ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
65346478

6535-
for (int y = 0; y < nr / 4; y++) {
6536-
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
6537-
for (int x = 0; x < nc / ncols_interleaved; x++) {
6538-
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
6539-
for (int m = 0; m < 4; m++) {
6540-
for (int j = 0; j < ncols_interleaved; j++) {
6541-
sumf[m][j] = 0.0;
6542-
sum_minf[m][j] = 0.0;
6543-
}
6544-
}
6545-
for (int l = 0; l < nb; l++) {
6546-
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
6547-
6548-
uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
6549-
uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
6550-
uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
6551-
uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
6552-
for (int m = 0; m < 4; m++) {
6553-
for (int j = 0; j < ncols_interleaved; j++) {
6554-
sumi1 = 0;
6555-
sumi2 = 0;
6556-
sumi3 = 0;
6557-
sumi4 = 0;
6558-
sumi = 0;
6559-
int offset = ((k / 2) % 2) + j * 2;
6560-
for (int i = 0; i < blocklen; ++i){
6561-
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x03);
6562-
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 0x03);
6563-
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 0x03);
6564-
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 0x03);
6565-
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
6566-
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
6567-
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
6568-
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
6569-
sumi1 = sumi1 * (scales_0[offset] & 0xF);
6570-
sumi2 = sumi2 * (scales_1[offset] & 0xF);
6571-
sumi3 = sumi3 * (scales_2[offset] & 0xF);
6572-
sumi4 = sumi4 * (scales_3[offset] & 0xF);
6573-
sumi += sumi1 + sumi2 + sumi3 + sumi4;
6574-
}
6575-
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
6576-
}
6577-
}
6578-
}
6579-
for(int sb = 0; sb < 8; sb++) {
6580-
uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
6581-
for(int m = 0; m < 4; m++) {
6582-
const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
6583-
for(int j = 0; j < ncols_interleaved; j++) {
6584-
int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
6585-
sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
6586-
}
6587-
}
6588-
}
6589-
}
6590-
6591-
for (int m = 0; m < 4; m++) {
6592-
for (int j = 0; j < ncols_interleaved; j++) {
6593-
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
6594-
}
6595-
}
6596-
}
6597-
}
65986479

65996480
#endif
66006481
}

0 commit comments

Comments
 (0)