Skip to content

Commit aa4b7d2

Browse files
committed
metal: improvement for Q4_K driver
1 parent 804c78d commit aa4b7d2

File tree

1 file changed

+48
-20
lines changed

1 file changed

+48
-20
lines changed

ggml-metal.metal

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -972,25 +972,42 @@ template <typename addr_uint16_p,typename addr_block_q_p, typename type4x4>
972972
class q4_K_driver {
973973
public:
974974
uint16_t d_mask1, d_mask2, m_mask1, mask1, mask2;
975-
float coef1, coef2, sumy;
975+
float coef1, coef2, sumy1, sumy2;
976976
uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset;
977977

978978
// Precomputes, from the index `il`, member state that other methods of this
// driver read later: `q_offset` (added to the `xb->qs` pointer in
// inner_product) and the mask/index pairs (`d_mask*`, `d_loc*`, `m_mask1`,
// `m_loc*`) that get_scales applies to `xb->scales`.
//
// This text is a diff view: a code line preceded by an `NNN-` marker belongs
// to the removed implementation, one preceded by `NNN+` to the added one.
void init(int il) {
979-
d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192;
980-
d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4;
981-
m_mask1 = il < 8 ? 63 : 0xF0;
982-
m_loc1 = il/2 + 4; m_loc2 = il/2;
983-
mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8;
984-
coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f;
985-
#if QK_K == 256
986-
q_offset = (il/4) * 16 + 8 * (il&1);
979+
// Added version: q_offset advances by 4 per step of il%4 and by 16 per step
// of il/4. It is used as an offset on an addr_uint16_p pointer, so the unit
// is that pointer's element type (presumably uint16_t -- confirm).
q_offset = (il/4) * 16 + 4 * (il%4);
980+
// The added masks repeat the same byte pattern in both halves of a 16-bit
// word (0x3F3F, 0x0F0F, 0xC0C0, 0xF0F0); get_scales applies them to
// `xb->scales` read through an addr_uint16_p pointer.
d_mask1 = il < 8 ? 0x3F3F : 0x0F0F; d_mask2 = il < 8 ? 0x0000 : 0xC0C0;
981+
// d_loc*/m_loc* are indices into that 16-bit view of `xb->scales`.
d_loc1 = il < 8 ? il/4 : il/4 + 2; d_loc2 = il < 8 ? il/4 : il/4 - 2;
982+
m_mask1 = il < 8 ? 0x3F3F : 0xF0F0;
983+
m_loc1 = il/4 + 2; m_loc2 = il/4;
984+
// Unlike the removed lines, the added lines do not assign mask1, mask2,
// coef1 or coef2, and q_offset is no longer inside `#if QK_K == 256`.
}
985+
986+
// Writes two scale/min pairs for block `xb` into the output references:
// (dl1, ml1) and (dl2, ml2).
//
// The QK_K == 256 branch reads the members d_loc1, d_loc2, d_mask1, d_mask2,
// m_loc1, m_loc2 and m_mask1, which init() assigns.
// NOTE(review): presumably init(il) must have been called with the same `il`
// beforehand -- the call site is not visible here; confirm.
void get_scales(addr_block_q_p xb, int il, thread float & dl1, thread float & ml1, thread float & dl2, thread float & ml2) {
987+
#if QK_K == 256
988+
const float d = (float)(xb->d);
989+
const float min = (float)(xb->dmin);
990+
// View the scales array through a 16-bit pointer so that each masked read
// below handles two adjacent entries at once.
addr_uint16_p sc = (addr_uint16_p)xb->scales;
991+
// Combine the bits selected by d_mask1 with the bits selected by d_mask2
// moved down by two positions. init() sets d_mask2 to 0 when il < 8, in
// which case the second term contributes nothing.
uint16_t d_int = (sc[d_loc1] & d_mask1) | ((sc[d_loc2] & d_mask2) >> 2);
992+
// For il >= 8 init() sets m_mask1 to 0xF0F0, so the selected upper nibbles
// are shifted down by 4; for il < 8 the masked value is used as is.
uint16_t m_int = il < 8 ? (sc[m_loc1] & m_mask1) : ((sc[m_loc1] & m_mask1) >> 4);
993+
m_int = m_int | ((sc[m_loc2] & d_mask2) >> 2);
994+
// Reinterpret each 16-bit word as two bytes: component [0] feeds the first
// pair, component [1] the second; each is multiplied by d or min.
dl1 = as_type<uchar2>(d_int)[0] * d, ml1 = as_type<uchar2>(m_int)[0] * min;
995+
dl2 = as_type<uchar2>(d_int)[1] * d, ml2 = as_type<uchar2>(m_int)[1] * min;
987996
#else
988-
// (The line after `988-` is removed text; it appears to come from the
// previous init() body, which shared this #else/#endif.)
q_offset = 8 * (il&1);
997+
// Other QK_K values: dl uses the low four bits of scales[0] / scales[1]
// times d[0]; ml uses scales[k] >> 4 times d[1].
dl1 = (float)(xb->d[0]) * (xb->scales[0]&0xF); dl2 = (float)(xb->d[0]) * (xb->scales[1]&0xF);
998+
ml1 = (float)(xb->d[1]) * (xb->scales[0]>>4); ml2 = (float)(xb->d[1]) * (xb->scales[1]>>4);
989999
#endif
9901000
}
9911001

992-
void get_scales(addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
1002+
void get_scales2(addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
1003+
q_offset = (il/4) * 16 + 8 * (il&1);
1004+
mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8;
1005+
coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f;
9931006
#if QK_K == 256
1007+
d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192;
1008+
d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4;
1009+
m_mask1 = il < 8 ? 63 : 0xF0;
1010+
m_loc1 = il/2 + 4; m_loc2 = il/2;
9941011
const float d = (float)(xb->d);
9951012
const float min = (float)(xb->dmin);
9961013
uint16_t d_int = (xb->scales[d_loc1] & d_mask1) | ((xb->scales[d_loc2] & d_mask2) >> 2);
@@ -1004,23 +1021,34 @@ class q4_K_driver {
10041021
}
10051022

10061023
// Prepares `yl` and the members sumy1/sumy2 for a following inner_product().
//
// Added version: sumy1 becomes the sum of the eight elements yl[0][0..3] and
// yl[1][0..3]; sumy2 the sum of yl[2][0..3] and yl[3][0..3]. `yl` is then
// rescaled in place. `il` is not read by the added body.
void inner_product_pre(int il, thread float4x4 & yl){
1007-
fix_y_v2(coef1, coef2, sumy, yl);
1024+
sumy1 = 0.f; sumy2 = 0.f;
1025+
// i takes 0, 2, 4, 6, so i/4 is 0 or 1 and i%4 is 0 or 2; each iteration
// handles the element pair (i%4, i%4+1).
for (int i = 0; i < 8; i += 2) {
1026+
sumy1 += yl[i/4 ][i%4]; sumy1 += yl[i/4 ][i%4+1];
1027+
sumy2 += yl[2+i/4][i%4]; sumy2 += yl[2+i/4][i%4+1];
1028+
// NOTE(review): this assigns the element to itself (scale factor 1); it
// has no effect and seems to be kept only for symmetry with the lines below.
yl[i/4 ][i%4 ] = yl[i/4][i%4];
1029+
// The factors 1/256, 1/16 and 1/4096 are 2^-8, 2^-4 and 2^-12. They match
// the bit positions of the masks 0x0F00, 0x00F0 and 0xF000 that
// inner_product applies without shifting, so the shift is folded into yl.
yl[i/4 ][i%4+1] = 1/256.f * yl[i/4][i%4+1];
1030+
yl[i/4+2][i%4 ] = 1/16.f * yl[2+i/4][i%4];
1031+
yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1];
1032+
}
10081033
}
10091034

10101035
// Computes a scaled dot product between `yl` and packed 4-bit values read
// from `xb->qs`, and stores the result in `sum`.
//
// Reads the members q_offset, sumy1 and sumy2, which init() and
// inner_product_pre() assign.
// NOTE(review): the loop accumulates into the incoming `sum`, and the final
// statement multiplies that accumulated value by dl1, so any non-zero
// incoming `sum` is scaled too. Presumably callers pass 0 -- confirm.
void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){
1011-
float dl, ml;
1012-
get_scales(xb, il, dl, ml);
1036+
float dl1, ml1, dl2, ml2;
1037+
float sum2 = 0.f;
1038+
get_scales(xb, il, dl1, ml1, dl2, ml2);
10131039
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
1014-
for (int i = 0; i < 16; i += 2) {
1015-
sum += yl[i/4][i%4] * (q[i/2] & mask1);
1016-
sum += yl[i/4][i%4+1] * (q[i/2] & mask2);
1040+
// Added version: four 16-bit words q[0..3] are read. From each word the
// nibbles under 0x000F and 0x0F00 go to `sum` (paired with yl[0..1]), and
// the nibbles under 0x00F0 and 0xF000 go to `sum2` (paired with yl[2..3]).
// The masked values are not shifted down here.
for (int i = 0; i < 8; i += 2) {
1041+
sum += yl[i/4 ][i%4 ] * ((q[i/2]&0x000F));
1042+
sum += yl[i/4 ][i%4+1] * ((q[i/2]&0x0F00));
1043+
sum2 += yl[i/4+2][i%4 ] * ((q[i/2]&0x00F0));
1044+
sum2 += yl[i/4+2][i%4+1] * ((q[i/2]&0xF000));
10171045
}
1018-
sum = dl * sum - ml * sumy;
1046+
// Apply each half's scale and subtract its min term times the matching
// sum of yl elements.
sum = dl1 * sum - ml1 * sumy1 + dl2 * sum2 - ml2 * sumy2;
10191047
}
10201048

10211049
void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) {
10221050
float dl, ml;
1023-
get_scales(xb, il, dl, ml);
1051+
get_scales2(xb, il, dl, ml);
10241052
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
10251053
for (int i = 0; i < 16; i += 2) {
10261054
reg[i/4][i%4] = coef1 * dl * (q[i/2] & mask1) - ml;
@@ -1465,7 +1493,7 @@ template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv<b
14651493
// Explicit instantiations of kernel_mat_mv, one per block type, each given a
// host-visible name through [[host_name(...)]].
// In this diff view the q4_K line appears twice: the removed form (after
// `1468-`) and the added form (after `1496+`). They differ only in the fifth
// template argument, 8 versus 32. The meaning of that argument is defined by
// kernel_mat_mv, which is not visible here.
template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q8_0, N_DST, N_SIMDGROUP, 2, 8, q8_0_driver>;
14661494
template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8, q2_K_driver>;
14671495
template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8, q3_K_driver>;
1468-
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 8, q4_K_driver>;
1496+
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32, q4_K_driver>;
14691497
template [[host_name("kernel_mul_mv_q5_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q5_K, N_DST, N_SIMDGROUP, QK_NL, 8, q5_K_driver>;
14701498
#if QK_K == 256
14711499
template [[host_name("kernel_mul_mv_q6_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q6_K, N_DST, N_SIMDGROUP, QK_NL, 64, q6_K_driver>;

0 commit comments

Comments
 (0)