metal : the Q3_K and Q4_K kernels with LLAMA_QKK_64=1 are broken #3276

Closed
ggerganov opened this issue Sep 20, 2023 · 0 comments · Fixed by #4705
Labels
bug Something isn't working

@ggerganov (Owner)

The following commands fail to generate coherent text:

LLAMA_QKK_64=1 make -j && ./main -m tmp/mnt/models/open-llama/3B-v2/ggml-model-q4_k.gguf -p "I believe the meaning of life is" -t 8 -ngl 1

LLAMA_QKK_64=1 make -j && ./main -m tmp/mnt/models/open-llama/3B-v2/ggml-model-q3_k.gguf -p "I believe the meaning of life is" -t 8 -ngl 1

Generation works correctly on the CPU (Arm and x86).
It also works with the following patch applied:

diff --git a/ggml-metal.m b/ggml-metal.m
index 1139ee3..ed9857f 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -889,7 +889,7 @@ void ggml_metal_graph_compute(
                                 src1t == GGML_TYPE_F32 &&
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
-                                ne11 > 1) {
+                                ne11 >= 1) {
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;

With this change, single-row multiplications (ne11 == 1) also go through the mul_mm matrix-matrix kernels, bypassing the dedicated mat-vec kernels. So it seems the issue is in the kernel_mul_mat_q4_K_f32 kernel in the QK_K == 64 branch:

llama.cpp/ggml-metal.metal, lines 1576 to 1663 at a40f2b6:
kernel void kernel_mul_mat_q4_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int ix = tiisg/4;  // 0...7
    const int it = tiisg%4;  // 0...3

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int r2 = tgpig.z;
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
    float yl[8];
    float yh[8];
    float sumf[N_DST]={0.f}, all_sum;

    const int step = sizeof(block_q4_K) * nb / 2;

    device const float * y4 = y + ix * QK_K + 8 * it;

    uint16_t sc16[4];

    for (int ib = ix; ib < nb; ib += 8) {
        float2 sumy = {0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
            yh[i] = y4[i+32]; sumy[1] += yh[i];
        }

        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
        device const half     * dh = x[ib].d;

        for (int row = 0; row < N_DST; row++) {
            sc16[0] = sc[0] & 0x000f;
            sc16[1] = sc[0] & 0x0f00;
            sc16[2] = sc[0] & 0x00f0;
            sc16[3] = sc[0] & 0xf000;

            float2 acc1 = {0.f, 0.f};
            float2 acc2 = {0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);

            qs += step;
            sc += step;
            dh += step;
        }

        y4 += 8 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
#endif
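For reference, the CPU path (which works) decodes a QK_K == 64 q4_K block as d*scale*q - dmin*min per 32-value half-block; the kernel's 1.f/256.f, 1.f/4096.f and 1.f/16.f factors are there to compensate for using the unshifted nibble masks in sc16 and qs. Below is a minimal C sketch of that decode for comparison; the block layout, field names, and the fp16 helper are assumptions based on my reading of the GGML_QKK_64 build, not authoritative definitions.

#include <stdint.h>
#include <string.h>

#define QK_K 64

// Assumed layout of block_q4_K when QK_K == 64 (hypothetical mirror of the
// LLAMA_QKK_64=1 build; treat as an assumption, not the real definition).
typedef struct {
    uint16_t d[2];       // raw fp16 bits: d[0] = super-block scale, d[1] = super-block min
    uint8_t  scales[2];  // packed 4-bit (scale, min) pairs, one per 32-value half-block
    uint8_t  qs[QK_K/2]; // 64 4-bit quants, two per byte
} block_q4_K_64;

// Minimal fp16 -> fp32 conversion for this sketch (ignores subnormals/NaN);
// ggml uses its own conversion here.
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t man  =  h        & 0x3FF;
    const uint32_t bits = (exp == 0) ? sign : (sign | ((exp + 112) << 23) | (man << 13));
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// Sketch: decode one QK_K == 64 block into 64 floats.
static void dequantize_block_q4_K_64(const block_q4_K_64 * x, float * y) {
    const float dall = fp16_to_fp32(x->d[0]);
    const float dmin = fp16_to_fp32(x->d[1]);

    const uint8_t * q = x->qs;

    // first 32 values use scales[0], last 32 use scales[1];
    // low nibble holds the scale, high nibble holds the min
    const float d1 = dall * (x->scales[0] & 0xF);
    const float m1 = dmin * (x->scales[0] >>  4);
    const float d2 = dall * (x->scales[1] & 0xF);
    const float m2 = dmin * (x->scales[1] >>  4);

    for (int l = 0; l < 32; ++l) {
        y[l +  0] = d1 * (q[l] & 0xF) - m1; // low nibbles -> first half-block
        y[l + 32] = d2 * (q[l] >>  4) - m2; // high nibbles -> second half-block
    }
}

A correct Metal kernel should produce, for each block, the same contribution as the dot product of the corresponding 64 src1 values with the output of this decode.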

Might have been broken by #2615, but I haven't tested this yet.
