fix naive q5_0; debugging simd q5_0
AyiStar committed Jul 16, 2024
1 parent 0d76687 commit 32507ab
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions src/lamm_kernel_q5_0.hpp
@@ -25,19 +25,21 @@ LA_INLINE void lamm_naive_kernel(const block_q5_0 *a, const block_q8_0 *b,
   constexpr int Q = ggml_type_trait<GGML_TYPE_Q5_0>::super_block_size;
   float sum = 0.0;
   for (int k = 0; k < K; k++) {
-    uint32_t qh;
-    memcpy(&qh, a[i].qh, sizeof(qh));
+    uint32_t qah;
+    memcpy(&qah, a[i].qh, sizeof(qah));
     const auto *aik = a + (i * lda + k);
     const auto *bjk = b + (j * ldb + k);
     int sumi = 0;
     for (int h = 0; h < Q / 2; h++) {
-      const uint8_t xh_0 = ((qh & (1u << (h + 0))) >> (h + 0)) << 4;
-      const uint8_t xh_1 = ((qh & (1u << (h + 16))) >> (h + 12));
-      const int32_t x0 = ((a[i].qs[j] & 0x0F) |
-                          (~xh_0 & 0xF0)); // ((a[i].qs[j] & 0x0F) | xh_0) - 16;
-      const int32_t x1 = ((a[i].qs[j] >> 4) |
-                          (~xh_1 & 0xF0)); // ((a[i].qs[j] >> 4) | xh_1) - 16;
-      sumi += (x0 * b[i].qs[j]) + (x1 * b[i].qs[j + Q / 2]);
+      uint8_t qah_0 = ((qah & (1u << (h + 0))) >> (h + 0));
+      uint8_t qah_1 = ((qah & (1u << (h + 16))) >> (h + 16));
+      qah_0 = -qah_0;
+      qah_1 = -qah_1;
+      int32_t qa_0 = aik->qs[h] & 0x0F;
+      int32_t qa_1 = aik->qs[h] >> 4;
+      qa_0 = qa_0 | ((~qah_0) & (~0x0F));
+      qa_1 = qa_1 | ((~qah_1) & (~0x0F));
+      sumi += (qa_0 * bjk->qs[h]) + (qa_1 * bjk->qs[h + Q / 2]);
     }
     sum += (GGML_FP16_TO_FP32(aik->d) * GGML_FP16_TO_FP32(bjk->d)) * sumi;
   }
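
For reference, the decode the rewritten naive loop performs: Q5_0 stores 32 weights per block as 4-bit nibbles in qs, and each weight's fifth (high) bit is packed into the 32-bit qh field; the dequantized integer is the 5-bit quant minus 16, a value in [-16, 15]. Below is a minimal scalar sketch of that decode, assuming the standard ggml block_q5_0 layout (the helper name q5_0_decode is ours, for illustration only):

#include <cstdint>
#include <cstring>

// Decode weight w (0..31) of one Q5_0 block to a signed integer in [-16, 15].
// qs holds the nibbles (weights 0..15 in the low halves, 16..31 in the high
// halves); qh packs the 32 fifth bits little-endian, one bit per weight.
static int32_t q5_0_decode(const uint8_t qs[16], const uint8_t qh_bytes[4], int w) {
  uint32_t qh;
  std::memcpy(&qh, qh_bytes, sizeof(qh));
  const uint8_t nib = (w < 16) ? (qs[w] & 0x0F) : (qs[w - 16] >> 4);
  const uint32_t hbit = (qh >> w) & 1u;     // fifth bit of weight w
  return (int32_t)(nib | (hbit << 4)) - 16; // 5-bit unsigned value, offset by -16
}

The kernel's branch-free form reaches the same values by turning the fifth bit into an all-ones/all-zeros mask and OR-ing the complement of 0x0F into the nibble when the bit is clear, which sign-extends it to nib - 16; note that for the bit-set case to come out as the plain nibble, the select mask has to be all-ones at the full width of the OR operands.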
@@ -47,22 +49,22 @@ LA_INLINE void lamm_naive_kernel(const block_q5_0 *a, const block_q8_0 *b,
 LA_INLINE void lamm_simd_kernel(const block_q5_0 *a, const block_q8_0 *b,
                                 float *c, int64_t lda, int64_t ldb, int64_t ldc,
                                 int i, int j, int K) {
-  // simd::vreg_t acc = {0};
-  // const auto *ai = a + (i * lda);
-  // const auto *bj = b + (j * ldb);
-  // for (int k = 0; k < K; k++, ai++, bj++) {
-  //   const simd::vreg_t ad = simd::vset(GGML_FP16_TO_FP32(ai->d));
-  //   const simd::vreg_t bd = simd::vset(GGML_FP16_TO_FP32(bj->d));
-  //   const __m256 adbd = simd::mul(ad, bd);
-  //   simd::ivreg_t va_qs = simd::load_quants(ai);
-  //   simd::ivreg_t xh = simd::spread_bits(a[i].qh);
-  //   xh = simd::andnot(xh, simd::ivset((char)0xF0));
-  //   va_qs = simd::_or(va_qs, xh);
-  //   simd::ivreg_t vb_qs = simd::load_quants(bj);
-  //   const simd::vreg_t xy = simd::mul_sum_us8_pairs_float(va_qs, vb_qs);
-  //   acc = simd::madd(adbd, xy, acc);
-  // }
-  // c[j * ldc + i] = simd::reduce_sum(acc);
+  simd::vreg_t acc = {0};
+  const auto *ai = a + (i * lda);
+  const auto *bj = b + (j * ldb);
+  for (int k = 0; k < K; k++, ai++, bj++) {
+    const simd::vreg_t ad = simd::vset(GGML_FP16_TO_FP32(ai->d));
+    const simd::vreg_t bd = simd::vset(GGML_FP16_TO_FP32(bj->d));
+    const simd::vreg_t adbd = simd::mul(ad, bd);
+    simd::ivreg_t va_qs = simd::load_quants(ai);
+    simd::ivreg_t xh = simd::spread_bits(a[i].qh);
+    xh = simd::andnot(xh, simd::ivset((char)0xF0));
+    va_qs = simd::_or(va_qs, xh);
+    simd::ivreg_t vb_qs = simd::load_quants(bj);
+    const simd::vreg_t xy = simd::mul_sum_us8_pairs_float(va_qs, vb_qs);
+    acc = simd::madd(adbd, xy, acc);
+  }
+  c[j * ldc + i] = simd::reduce_sum(acc);
 }
 
 template <int B0, int B1>
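
The now-enabled SIMD path is the same decode vectorized: spread_bits expands the 32 bits of qh to one byte lane each, andnot against 0xF0 turns that into the per-lane high-bit mask, and _or merges the mask into the loaded nibbles. As a hedged illustration of what such wrappers conventionally lower to, here is the classic AVX2 version of the bit-spreading trick, mirroring ggml's AVX2 Q5_0 path (the simd:: wrappers in this repo may target a different ISA, so take this as an analogy, not this project's implementation):

#include <immintrin.h>
#include <cstdint>
#include <cstring>

// Expand the 32 packed high bits to one byte per lane: 0xFF if set, 0x00 if clear.
static inline __m256i spread_bits_avx2(const uint8_t *qh) {
  uint32_t x32;
  std::memcpy(&x32, qh, sizeof(x32));
  const __m256i shuf = _mm256_set_epi64x(
      0x0303030303030303, 0x0202020202020202,
      0x0101010101010101, 0x0000000000000000);
  __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32((int)x32), shuf);
  const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
  bytes = _mm256_or_si256(bytes, bit_mask); // lane i is all-ones iff bit i of qh is set
  return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
}

// Merge the high bits into the low nibbles: where the fifth bit is clear, OR in
// 0xF0 so the byte reads as nibble - 16 when treated as signed; where it is set,
// the nibble is kept as-is.
static inline __m256i q5_0_combine_avx2(__m256i nibbles, const uint8_t *qh) {
  __m256i hi = spread_bits_avx2(qh);
  hi = _mm256_andnot_si256(hi, _mm256_set1_epi8((char)0xF0)); // (~hi) & 0xF0
  return _mm256_or_si256(nibbles, hi);
}

Because the merged Q5_0 bytes are signed (in [-16, 15]), the final multiply-accumulate needs a signed-by-signed pairing; ggml's AVX2 path uses mul_sum_i8_pairs_float at this step, whereas the loop above calls mul_sum_us8_pairs_float (unsigned-by-signed), which is plausibly the "debugging simd q5_0" half of the commit title.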
