Skip to content

Commit

Permalink
Fix q2_k simd block bug
Browse files Browse the repository at this point in the history
  • Loading branch information
AyiStar committed Jul 27, 2024
1 parent 77f1467 commit e92ad41
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 29 deletions.
8 changes: 4 additions & 4 deletions src/la-benchmark-matmult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ int main(int argc, char **argv) {
#ifdef LAMM_DEBUG
printf("Debugging the correctness\n");
// check the correctness
const int sizey = 32;
const int sizex = 512;
const int sizez = 32;
const int sizey = 33;
const int sizex = 4096;
const int sizez = 18;
#else
const int sizey = 4096;
const int sizex = 11008;
Expand Down Expand Up @@ -371,7 +371,7 @@ void do_benchmark(int sizex, int sizey, int sizez, struct ggml_cgraph *g1,
// quantization will be slightly different
float sum_of_result = tensor_sum_elements(g1->nodes[0]);
float delta = std::abs(sum_of_result - correct) / std::abs(correct);
float allowed_delta = 1e-2; // Let's accept an epsilon of 10^-3
float allowed_delta = 1e-2;

if (delta > allowed_delta) {
printf("\nABORT - ERROR in Matrix Multiplication result - expected "
Expand Down
15 changes: 11 additions & 4 deletions src/lamm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ template <ggml_type GGMLType> class LAMMImpl {
}
int M = C.row, N = C.col, K = A.col;
assert(M == A.row && N == B.col && K == B.row);
assert(K % simd::kF32PerVec == 0);
assert(A.type != GGML_TYPE_F32 || K % simd::kF32PerVec == 0);
assert(nth > 0);
// split thread-local job by M
int job_size = M / nth;
Expand Down Expand Up @@ -97,8 +97,12 @@ template <ggml_type GGMLType> class LAMMImpl {
}
}
int M = C.row, N = C.col, K = A.col;
assert(M == A.row && N == B.col && K == B.row);
assert(nth > 0);
if (!(M == A.row && N == B.col && K == B.row)) {
std::cout << "Assertion error" << std::endl;
std::abort();
}
// assert(M == A.row && N == B.col && K == B.row);
// assert(nth > 0);
// split thread-local job by M
int job_size = M / nth;
int job_start = ith * job_size;
Expand All @@ -116,7 +120,10 @@ template <ggml_type GGMLType> class LAMMImpl {
int jj = (L1 / kBlockSize * kBlockSize);
int64_t lda{A.ld}, ldb{B.ld}, ldc{C.ld};

assert((K % simd::kF32PerVec) == 0);
if (A.type == GGML_TYPE_F32 && (K % simd::kF32PerVec) != 0) {
std::cout << "K= " << K << std::endl;
std::abort();
}
dtype *a = (dtype *)(A.data);
vec_dot_dtype *b = (vec_dot_dtype *)(B.data);
float *c = (float *)(C.data);
Expand Down
41 changes: 21 additions & 20 deletions src/lamm_kernel_q2_k.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ LA_INLINE void lamm_naive_kernel(const block_q2_K *a, const block_q8_K *b,
c[j * ldc + i] = sumf;
}

static LA_INLINE __m256i get_scale_shuffle_q3k(int i) {
static LA_INLINE simd::ivreg_t get_scale_shuffle_q3k(int i) {
static const uint8_t k_shuffle[128] = {
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
Expand Down Expand Up @@ -126,7 +126,7 @@ LA_INLINE void lamm_simd_kernel(const block_q2_K *a, const block_q8_K *b,
ivreg_t sumi = {0};
const float d = bjk->d * GGML_FP16_TO_FP32(aik->d);

for (int j = 0; j < 2; ++j) {
for (int l = 0; l < 2; ++l) {

// for a
const ivreg_t q2bits = load(q2);
Expand All @@ -137,15 +137,15 @@ LA_INLINE void lamm_simd_kernel(const block_q2_K *a, const block_q8_K *b,
const ivreg_t q2_3 = _and(logic_shift_right(q2bits, 6), m3);

// for a and b
ivreg_t p0 = mul_ubs(q2_0, q8_0[j]);
ivreg_t p1 = mul_ubs(q2_1, q8_1[j]);
ivreg_t p2 = mul_ubs(q2_2, q8_2[j]);
ivreg_t p3 = mul_ubs(q2_3, q8_3[j]);
ivreg_t p0 = mul_ubs(q2_0, q8_0[l]);
ivreg_t p1 = mul_ubs(q2_1, q8_1[l]);
ivreg_t p2 = mul_ubs(q2_2, q8_2[l]);
ivreg_t p3 = mul_ubs(q2_3, q8_3[l]);

p0 = mul(shuffle(scales[j], get_scale_shuffle_q3k(0)), p0);
p1 = mul(shuffle(scales[j], get_scale_shuffle_q3k(1)), p1);
p2 = mul(shuffle(scales[j], get_scale_shuffle_q3k(2)), p2);
p3 = mul(shuffle(scales[j], get_scale_shuffle_q3k(3)), p3);
p0 = mul(shuffle(scales[l], get_scale_shuffle_q3k(0)), p0);
p1 = mul(shuffle(scales[l], get_scale_shuffle_q3k(1)), p1);
p2 = mul(shuffle(scales[l], get_scale_shuffle_q3k(2)), p2);
p3 = mul(shuffle(scales[l], get_scale_shuffle_q3k(3)), p3);

p0 = add(p0, p1);
p2 = add(p2, p3);
Expand All @@ -167,7 +167,7 @@ LA_INLINE void lamm_simd_block_kernel(const block_q2_K *a, const block_q8_K *b,

static_assert(B0 > 0 && B0 <= 4);
static_assert(B1 > 0 && B1 <= 4);

// std::cout << "Q2_K SIMD block called with B0=" << B0 << ", B1=" << B1 << std::endl;
using namespace simd;

const ivreg_t m3 = ivset(3);
Expand All @@ -185,15 +185,6 @@ LA_INLINE void lamm_simd_block_kernel(const block_q2_K *a, const block_q8_K *b,
[[maybe_unused]] ivreg_t bsum0 = {0}, bsum1 = {0}, bsum2 = {0}, bsum3 = {0};
[[maybe_unused]] vreg_t bd0 = {0}, bd1 = {0}, bd2 = {0}, bd3 = {0};

[[maybe_unused]] ivreg_t sumi00 = {0}, sumi01 = {0}, sumi02 = {0},
sumi03 = {0};
[[maybe_unused]] ivreg_t sumi10 = {0}, sumi11 = {0}, sumi12 = {0},
sumi13 = {0};
[[maybe_unused]] ivreg_t sumi20 = {0}, sumi21 = {0}, sumi22 = {0},
sumi23 = {0};
[[maybe_unused]] ivreg_t sumi30 = {0}, sumi31 = {0}, sumi32 = {0},
sumi33 = {0};

[[maybe_unused]] vreg_t acc00 = {0}, acc01 = {0}, acc02 = {0}, acc03 = {0};
[[maybe_unused]] vreg_t acc10 = {0}, acc11 = {0}, acc12 = {0}, acc13 = {0};
[[maybe_unused]] vreg_t acc20 = {0}, acc21 = {0}, acc22 = {0}, acc23 = {0};
Expand Down Expand Up @@ -289,6 +280,16 @@ LA_INLINE void lamm_simd_block_kernel(const block_q2_K *a, const block_q8_K *b,
LOOP(OUTER_FN, 4)
#undef INNER_FN
#undef OUTER_FN

#define FN(N) \
if constexpr (B0 > N) { \
ai##N++; \
} \
if constexpr (B1 > N) { \
bj##N++; \
}
LOOP(FN, 4)
#undef FN
} // loop `k`

#define INNER_FN(N0, N1) \
Expand Down
3 changes: 2 additions & 1 deletion src/loongarch_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ void lamm_mul_mat(const struct ggml_compute_params *params,
C.col = ne11;
C.ld = nb1 / ggml_type_size(dst->type);

decltype(LAMMImpl<GGML_TYPE_F32>::matmul) *mm_func = nullptr;
using MatMulFuncPtr = void(*)(const Matrix &A, const Matrix &B, const Matrix &C, int ith, int nth);
MatMulFuncPtr mm_func = nullptr;
switch (A.type) {
case GGML_TYPE_F32:
mm_func = LAMMImpl<GGML_TYPE_F32>::matmul;
Expand Down

0 comments on commit e92ad41

Please sign in to comment.