diff --git a/src/lamm_kernel_q5_0.hpp b/src/lamm_kernel_q5_0.hpp
index 735bb68..b79fa86 100644
--- a/src/lamm_kernel_q5_0.hpp
+++ b/src/lamm_kernel_q5_0.hpp
@@ -25,19 +25,21 @@ LA_INLINE void lamm_naive_kernel(const block_q5_0 *a, const block_q8_0 *b,
   constexpr int Q = ggml_type_trait::super_block_size;
   float sum = 0.0;
   for (int k = 0; k < K; k++) {
-    uint32_t qh;
-    memcpy(&qh, a[i].qh, sizeof(qh));
+    uint32_t qah;
+    memcpy(&qah, a[i * lda + k].qh, sizeof(qah)); // high bits of block (i, k)
     const auto *aik = a + (i * lda + k);
     const auto *bjk = b + (j * ldb + k);
     int sumi = 0;
     for (int h = 0; h < Q / 2; h++) {
-      const uint8_t xh_0 = ((qh & (1u << (h + 0))) >> (h + 0)) << 4;
-      const uint8_t xh_1 = ((qh & (1u << (h + 16))) >> (h + 12));
-      const int32_t x0 = ((a[i].qs[j] & 0x0F) |
-                          (~xh_0 & 0xF0)); // ((a[i].qs[j] & 0x0F) | xh_0) - 16;
-      const int32_t x1 = ((a[i].qs[j] >> 4) |
-                          (~xh_1 & 0xF0)); // ((a[i].qs[j] >> 4) | xh_1) - 16;
-      sumi += (x0 * b[i].qs[j]) + (x1 * b[i].qs[j + Q / 2]);
+      int32_t qah_0 = (qah & (1u << (h + 0))) >> (h + 0);
+      int32_t qah_1 = (qah & (1u << (h + 16))) >> (h + 16);
+      qah_0 = -qah_0; // high bit set -> all-ones mask, clear -> zero
+      qah_1 = -qah_1;
+      int32_t qa_0 = aik->qs[h] & 0x0F;
+      int32_t qa_1 = aik->qs[h] >> 4;
+      qa_0 = qa_0 | ((~qah_0) & (~0x0F)); // == ((qs & 0x0F) | (hbit << 4)) - 16
+      qa_1 = qa_1 | ((~qah_1) & (~0x0F));
+      sumi += (qa_0 * bjk->qs[h]) + (qa_1 * bjk->qs[h + Q / 2]);
     }
     sum += (GGML_FP16_TO_FP32(aik->d) * GGML_FP16_TO_FP32(bjk->d)) * sumi;
   }
@@ -47,22 +49,22 @@ LA_INLINE void lamm_naive_kernel(const block_q5_0 *a, const block_q8_0 *b,
 LA_INLINE void lamm_simd_kernel(const block_q5_0 *a, const block_q8_0 *b,
                                 float *c, int64_t lda, int64_t ldb, int64_t ldc,
                                 int i, int j, int K) {
-  // simd::vreg_t acc = {0};
-  // const auto *ai = a + (i * lda);
-  // const auto *bj = b + (j * ldb);
-  // for (int k = 0; k < K; k++, ai++, bj++) {
-  //   const simd::vreg_t ad = simd::vset(GGML_FP16_TO_FP32(ai->d));
-  //   const simd::vreg_t bd = simd::vset(GGML_FP16_TO_FP32(bj->d));
-  //   const __m256 adbd = simd::mul(ad, bd);
-  //   simd::ivreg_t va_qs = simd::load_quants(ai);
-  //   simd::ivreg_t xh = simd::spread_bits(a[i].qh);
-  //   xh = simd::andnot(xh, simd::ivset((char)0xF0));
-  //   va_qs = simd::_or(va_qs, xh);
-  //   simd::ivreg_t vb_qs = simd::load_quants(bj);
-  //   const simd::vreg_t xy = simd::mul_sum_us8_pairs_float(va_qs, vb_qs);
-  //   acc = simd::madd(adbd, xy, acc);
-  // }
-  // c[j * ldc + i] = simd::reduce_sum(acc);
+  simd::vreg_t acc = {0};
+  const auto *ai = a + (i * lda);
+  const auto *bj = b + (j * ldb);
+  for (int k = 0; k < K; k++, ai++, bj++) {
+    const simd::vreg_t ad = simd::vset(GGML_FP16_TO_FP32(ai->d));
+    const simd::vreg_t bd = simd::vset(GGML_FP16_TO_FP32(bj->d));
+    const simd::vreg_t adbd = simd::mul(ad, bd);
+    simd::ivreg_t va_qs = simd::load_quants(ai);
+    simd::ivreg_t xh = simd::spread_bits(ai->qh);
+    xh = simd::andnot(xh, simd::ivset((char)0xF0));
+    va_qs = simd::_or(va_qs, xh);
+    simd::ivreg_t vb_qs = simd::load_quants(bj);
+    const simd::vreg_t xy = simd::mul_sum_us8_pairs_float(va_qs, vb_qs);
+    acc = simd::madd(adbd, xy, acc);
+  }
+  c[j * ldc + i] = simd::reduce_sum(acc);
 }
 template
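
For reference, the reconstruction both kernels depend on is the standard q5_0 decode: each weight is a 4-bit nibble from `qs` plus one bit from `qh`, and the dequantized integer is `(nibble | (high_bit << 4)) - 16`, a value in [-16, 15]. The rewritten naive kernel avoids the subtraction by turning the high bit into an all-ones/all-zeros mask and OR-ing the upper bits of the two's-complement result in directly. The snippet below is not part of the patch; it is a minimal standalone check (the helper names are invented for illustration) that the mask form agrees with the subtract form for all 32 encodings, assuming the mask arithmetic is kept at 32-bit signed width.

```cpp
// Standalone check of the q5_0 sign-extension trick (illustration only; these
// helpers are not part of the patch).
#include <cassert>
#include <cstdint>
#include <cstdio>

// Reference form: rebuild the 5-bit quant, then subtract the offset.
static int32_t q5_value_ref(uint8_t nibble, uint32_t high_bit) {
  return (int32_t)(nibble | (high_bit << 4)) - 16;
}

// Branchless form: derive an all-ones/all-zeros mask from the high bit and OR
// the upper bits in, so no subtraction is needed.
static int32_t q5_value_mask(uint8_t nibble, uint32_t high_bit) {
  int32_t m = -(int32_t)high_bit;        // 1 -> 0xFFFFFFFF, 0 -> 0x00000000
  return (int32_t)nibble | (~m & ~0x0F); // bit set -> nibble, clear -> nibble - 16
}

int main() {
  for (uint32_t hb = 0; hb <= 1; ++hb) {
    for (uint8_t nib = 0; nib < 16; ++nib) {
      assert(q5_value_ref(nib, hb) == q5_value_mask(nib, hb));
    }
  }
  std::printf("mask form matches (q5 - 16) for all 32 encodings\n");
  return 0;
}
```

Doing the complement at full integer width is what makes the trick hold: with an 8-bit unsigned intermediate, `~` promotes its operand to `int` first, and the OR no longer produces the intended negative values.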