@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
543543 return _mm256_cvtepi32_ps (summed_pairs );
544544}
545545
546- // multiply int8_t, add results pairwise twice and return as float vector
547- static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
548- // Get absolute values of x vectors
549- const __m256i ax = _mm256_sign_epi8 (x , x );
550- // Sign the values of the y vectors
551- const __m256i sy = _mm256_sign_epi8 (y , x );
546+ static inline __m256 mul_sum_us8_pairs_float (const __m256i ax , const __m256i sy ) {
552547#if __AVXVNNI__
553548 const __m256i zero = _mm256_setzero_si256 ();
554549 const __m256i summed_pairs = _mm256_dpbusd_epi32 (zero , ax , sy );
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
560555#endif
561556}
562557
558+ // multiply int8_t, add results pairwise twice and return as float vector
559+ static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
560+ #if __AVXVNNIINT8__
561+ const __m256i zero = _mm256_setzero_si256 ();
562+ const __m256i summed_pairs = _mm256_dpbssd_epi32 (zero , x , y );
563+ return _mm256_cvtepi32_ps (summed_pairs );
564+ #else
565+ // Get absolute values of x vectors
566+ const __m256i ax = _mm256_sign_epi8 (x , x );
567+ // Sign the values of the y vectors
568+ const __m256i sy = _mm256_sign_epi8 (y , x );
569+ return mul_sum_us8_pairs_float (ax , sy );
570+ #endif
571+ }
572+
563573static inline __m128i packNibbles ( __m256i bytes )
564574{
565575 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
619629 return _mm256_cvtepi32_ps (summed_pairs );
620630}
621631
632+ static inline __m256 mul_sum_us8_pairs_float (const __m256i ax , const __m256i sy ) {
633+ const __m128i axl = _mm256_castsi256_si128 (ax );
634+ const __m128i axh = _mm256_extractf128_si256 (ax , 1 );
635+ const __m128i syl = _mm256_castsi256_si128 (sy );
636+ const __m128i syh = _mm256_extractf128_si256 (sy , 1 );
637+ // Perform multiplication and create 16-bit values
638+ const __m128i dotl = _mm_maddubs_epi16 (axl , syl );
639+ const __m128i doth = _mm_maddubs_epi16 (axh , syh );
640+ return sum_i16_pairs_float (doth , dotl );
641+ }
642+
622643// multiply int8_t, add results pairwise twice and return as float vector
623644static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
624645 const __m128i xl = _mm256_castsi256_si128 (x );
@@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
24342455 const __m256i bx = bytes_from_nibbles_32 (x [i ].qs );
24352456 const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
24362457
2437- const __m256 xy = mul_sum_i8_pairs_float (bx , by );
2458+ const __m256 xy = mul_sum_us8_pairs_float (bx , by );
24382459
24392460 // Accumulate d0*d1*x*y
24402461#if defined(__AVX2__ )
@@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29062927 const __m256 dy = _mm256_broadcast_ss (& y [i ].d );
29072928 const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
29082929
2909- const __m256 q = mul_sum_i8_pairs_float (bx , by );
2930+ const __m256 q = mul_sum_us8_pairs_float (bx , by );
29102931
29112932 acc = _mm256_fmadd_ps (q , _mm256_mul_ps (dx , dy ), acc );
29122933 }
@@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29402961 const __m256 dy = _mm256_broadcast_ss (& y [i ].d );
29412962 const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
29422963
2943- const __m256 q = mul_sum_i8_pairs_float (bx , by );
2964+ const __m256 q = mul_sum_us8_pairs_float (bx , by );
29442965
29452966 acc = _mm256_add_ps (_mm256_mul_ps (q , _mm256_mul_ps (dx , dy )), acc );
29462967 }
0 commit comments