@@ -972,25 +972,42 @@ template <typename addr_uint16_p,typename addr_block_q_p, typename type4x4>
972
972
class q4_K_driver {
973
973
public:
974
974
uint16_t d_mask1, d_mask2, m_mask1, mask1, mask2;
975
- float coef1, coef2, sumy ;
975
+ float coef1, coef2, sumy1, sumy2 ;
976
976
uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset;
977
977
978
978
void init (int il) {
979
- d_mask1 = il < 8 ? 63 : 0x0F ; d_mask2 = il < 8 ? 0 : 192 ;
980
- d_loc1 = il < 8 ? il/2 : 4 + il/2 ; d_loc2 = il < 8 ? il/2 : il/2 - 4 ;
981
- m_mask1 = il < 8 ? 63 : 0xF0 ;
982
- m_loc1 = il/2 + 4 ; m_loc2 = il/2 ;
983
- mask1 = (il%4 ) < 2 ? 0x000F : 0x00F0 ; mask2 = mask1 << 8 ;
984
- coef1 = (il%4 ) < 2 ? 1 .f : 1 /16 .f ; coef2 = coef1 / 256 .f ;
985
- #if QK_K == 256
986
- q_offset = (il/4 ) * 16 + 8 * (il&1 );
979
+ q_offset = (il/4 ) * 16 + 4 * (il%4 );
980
+ d_mask1 = il < 8 ? 0x3F3F : 0x0F0F ; d_mask2 = il < 8 ? 0x0000 : 0xC0C0 ;
981
+ d_loc1 = il < 8 ? il/4 : il/4 + 2 ; d_loc2 = il < 8 ? il/4 : il/4 - 2 ;
982
+ m_mask1 = il < 8 ? 0x3F3F : 0xF0F0 ;
983
+ m_loc1 = il/4 + 2 ; m_loc2 = il/4 ;
984
+ }
985
+
986
+ void get_scales (addr_block_q_p xb, int il, thread float & dl1, thread float & ml1, thread float & dl2, thread float & ml2) {
987
+ #if QK_K == 256
988
+ const float d = (float )(xb->d );
989
+ const float min = (float )(xb->dmin );
990
+ addr_uint16_p sc = (addr_uint16_p)xb->scales ;
991
+ uint16_t d_int = (sc[d_loc1] & d_mask1) | ((sc[d_loc2] & d_mask2) >> 2 );
992
+ uint16_t m_int = il < 8 ? (sc[m_loc1] & m_mask1) : ((sc[m_loc1] & m_mask1) >> 4 );
993
+ m_int = m_int | ((sc[m_loc2] & d_mask2) >> 2 );
994
+ dl1 = as_type<uchar2>(d_int)[0 ] * d, ml1 = as_type<uchar2>(m_int)[0 ] * min;
995
+ dl2 = as_type<uchar2>(d_int)[1 ] * d, ml2 = as_type<uchar2>(m_int)[1 ] * min;
987
996
#else
988
- q_offset = 8 * (il&1 );
997
+ dl1 = (float )(xb->d [0 ]) * (xb->scales [0 ]&0xF ); dl2 = (float )(xb->d [0 ]) * (xb->scales [1 ]&0xF );
998
+ ml1 = (float )(xb->d [1 ]) * (xb->scales [0 ]>>4 ); ml2 = (float )(xb->d [1 ]) * (xb->scales [1 ]>>4 );
989
999
#endif
990
1000
}
991
1001
992
- void get_scales (addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
1002
+ void get_scales2 (addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
1003
+ q_offset = (il/4 ) * 16 + 8 * (il&1 );
1004
+ mask1 = (il%4 ) < 2 ? 0x000F : 0x00F0 ; mask2 = mask1 << 8 ;
1005
+ coef1 = (il%4 ) < 2 ? 1 .f : 1 /16 .f ; coef2 = coef1 / 256 .f ;
993
1006
#if QK_K == 256
1007
+ d_mask1 = il < 8 ? 63 : 0x0F ; d_mask2 = il < 8 ? 0 : 192 ;
1008
+ d_loc1 = il < 8 ? il/2 : 4 + il/2 ; d_loc2 = il < 8 ? il/2 : il/2 - 4 ;
1009
+ m_mask1 = il < 8 ? 63 : 0xF0 ;
1010
+ m_loc1 = il/2 + 4 ; m_loc2 = il/2 ;
994
1011
const float d = (float )(xb->d );
995
1012
const float min = (float )(xb->dmin );
996
1013
uint16_t d_int = (xb->scales [d_loc1] & d_mask1) | ((xb->scales [d_loc2] & d_mask2) >> 2 );
@@ -1004,23 +1021,34 @@ class q4_K_driver {
1004
1021
}
1005
1022
1006
1023
void inner_product_pre (int il, thread float4x4 & yl){
1007
- fix_y_v2 (coef1, coef2, sumy, yl);
1024
+ sumy1 = 0 .f ; sumy2 = 0 .f ;
1025
+ for (int i = 0 ; i < 8 ; i += 2 ) {
1026
+ sumy1 += yl[i/4 ][i%4 ]; sumy1 += yl[i/4 ][i%4 +1 ];
1027
+ sumy2 += yl[2 +i/4 ][i%4 ]; sumy2 += yl[2 +i/4 ][i%4 +1 ];
1028
+ yl[i/4 ][i%4 ] = yl[i/4 ][i%4 ];
1029
+ yl[i/4 ][i%4 +1 ] = 1 /256 .f * yl[i/4 ][i%4 +1 ];
1030
+ yl[i/4 +2 ][i%4 ] = 1 /16 .f * yl[2 +i/4 ][i%4 ];
1031
+ yl[i/4 +2 ][i%4 +1 ] = 1 /4096 .f * yl[2 +i/4 ][i%4 +1 ];
1032
+ }
1008
1033
}
1009
1034
1010
1035
void inner_product (addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){
1011
- float dl, ml;
1012
- get_scales (xb, il, dl, ml);
1036
+ float dl1, ml1, dl2, ml2;
1037
+ float sum2 = 0 .f ;
1038
+ get_scales (xb, il, dl1, ml1, dl2, ml2);
1013
1039
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
1014
- for (int i = 0 ; i < 16 ; i += 2 ) {
1015
- sum += yl[i/4 ][i%4 ] * (q[i/2 ] & mask1);
1016
- sum += yl[i/4 ][i%4 +1 ] * (q[i/2 ] & mask2);
1040
+ for (int i = 0 ; i < 8 ; i += 2 ) {
1041
+ sum += yl[i/4 ][i%4 ] * ((q[i/2 ]&0x000F ));
1042
+ sum += yl[i/4 ][i%4 +1 ] * ((q[i/2 ]&0x0F00 ));
1043
+ sum2 += yl[i/4 +2 ][i%4 ] * ((q[i/2 ]&0x00F0 ));
1044
+ sum2 += yl[i/4 +2 ][i%4 +1 ] * ((q[i/2 ]&0xF000 ));
1017
1045
}
1018
- sum = dl * sum - ml * sumy ;
1046
+ sum = dl1 * sum - ml1 * sumy1 + dl2 * sum2 - ml2 * sumy2 ;
1019
1047
}
1020
1048
1021
1049
void dequantize (addr_block_q_p xb, int il, thread type4x4 & reg) {
1022
1050
float dl, ml;
1023
- get_scales (xb, il, dl, ml);
1051
+ get_scales2 (xb, il, dl, ml);
1024
1052
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
1025
1053
for (int i = 0 ; i < 16 ; i += 2 ) {
1026
1054
reg[i/4 ][i%4 ] = coef1 * dl * (q[i/2 ] & mask1) - ml;
@@ -1465,7 +1493,7 @@ template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv<b
1465
1493
template [[host_name(" kernel_mul_mv_q8_0_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q8_0, N_DST, N_SIMDGROUP, 2 , 8 , q8_0_driver>;
1466
1494
template [[host_name(" kernel_mul_mv_q2_K_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8 , q2_K_driver>;
1467
1495
template [[host_name(" kernel_mul_mv_q3_K_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8 , q3_K_driver>;
1468
- template [[host_name(" kernel_mul_mv_q4_K_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 8 , q4_K_driver>;
1496
+ template [[host_name(" kernel_mul_mv_q4_K_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32 , q4_K_driver>;
1469
1497
template [[host_name(" kernel_mul_mv_q5_K_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q5_K, N_DST, N_SIMDGROUP, QK_NL, 8 , q5_K_driver>;
1470
1498
#if QK_K == 256
1471
1499
template [[host_name(" kernel_mul_mv_q6_K_f32" )]] kernel mat_mv_t kernel_mat_mv<block_q6_K, N_DST, N_SIMDGROUP, QK_NL, 64 , q6_K_driver>;
0 commit comments