Metal implementation for all k_quants (ggerganov#1807)

* metal : improve q4_K

28.3 -> 26.0 ms/token by avoiding a branch in the
calculation of the scales.

* metal : small improvement for Q4_K

* metal : still optimizing Q4_K

This commit pushes it down to 25.3 ms / token.

The crazy idea of using 6 bits for the scales is really costly on
Metal: if I remove the bit fiddling needed to reconstruct the block
scales, the time drops almost to the 23 ms/token of Q4_0.

Before pushing the k-quants upstream I had a Q4_K variant that used
8-bit scales. It wasn't more accurate, used 0.125 bits more per weight,
ran slightly slower on the CPU (due to the larger model size, with the
CPU being memory bound there), and the difference was entirely
negligible under CUDA. So I decided to publish the version with 6-bit
scales. Perhaps I should reconsider and change to 8-bit scales?
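
For context, here is what those 6-bit scales look like. Each Q4_K
super-block of 256 weights stores 8 sub-block scales and 8 sub-block
mins, 6 bits each, packed into 12 bytes (96 bits); with 8-bit scales
and mins this would be 16 bytes, i.e. 32 extra bits per 256 weights,
the 0.125 bits/weight mentioned above. A sketch of the unpacking,
modeled on the CPU dequantization helper in k_quants.c (illustrative
sketch only, not the Metal kernel):

    #include <stdint.h>

    // Extract the 6-bit scale (*d) and min (*m) of sub-block j (0..7)
    // from the 12 packed bytes q of a Q4_K super-block. The branch on
    // j is the one the Metal kernel tries to avoid.
    //   q[0..3]  : low 6 bits of scales 0..3; the top 2 bits hold the
    //              high 2 bits of scales 4..7
    //   q[4..7]  : low 6 bits of mins 0..3; the top 2 bits hold the
    //              high 2 bits of mins 4..7
    //   q[8..11] : low nibble = low 4 bits of scales 4..7,
    //              high nibble = low 4 bits of mins 4..7
    static inline void get_scale_min_k4(int j, const uint8_t * q,
                                        uint8_t * d, uint8_t * m) {
        if (j < 4) {
            *d = q[j    ] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
        }
    }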

* metal : some more optimizations

Q2_K: 25.4 ms/token
Q6_K: 27.3 ms/token
Q4_0: 22.8 ms/token
Q4_1: 23.1 ms/token

* metal : Q3_K support

Something is not quite right yet.

* metal : Q5_K support

Initial version achieves 31.2 ms/token, 210 GB/s

* metal : still not able to figure out why q3_K does not work

* Minor

* metal : yet another failed attempt to make q3_K work

* metal : optimize Q5_K

31.2 ms -> 27.8 ms.
250 GB/s.

* metal : q3_K still not working

Adding a heavily commented q3_K metal kernel to explain
my obviously faulty logic. Perhaps someone could spot the issue?

* metal : q3_K finally working

Not optimized at all.

What was the issue? The scales are not 4-byte aligned,
and I was accessing them through a uint32_t pointer.
When I tried that on CUDA, I got an error (illegal memory access)
and added a memcpy to a local array of 3 uint32_t's.
But on Metal it told me there is no memcpy, so I tried
accessing directly. There is no error, just garbage results.
At some point I did try accessing the scales with a uint16_t
pointer (the scales are for sure 2-byte aligned), but was
still getting garbage. I guess there must have been another bug.

Now the scales are accessed via a uint16_t pointer and, after
starting over from the C dequantize function, it finally works.
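
For the record, the access pattern that works is roughly the
following (an illustrative sketch with made-up names, not the actual
kernel code): read the 12 packed scale bytes through a uint16_t
pointer and assemble the three 32-bit helper words from 16-bit
halves instead of casting to uint32_t.

    #include <stdint.h>

    // The 12 scale bytes of a q3_K block are 2-byte aligned but not
    // 4-byte aligned, so they must not be read through a uint32_t
    // pointer (illegal access on CUDA, silent garbage on Metal).
    // Building the 32-bit words from uint16_t loads is safe.
    static void load_q3k_scales(const uint8_t * scales, uint32_t aux[3]) {
        const uint16_t * s16 = (const uint16_t *) scales;
        aux[0] = (uint32_t) s16[0] | ((uint32_t) s16[1] << 16);
        aux[1] = (uint32_t) s16[2] | ((uint32_t) s16[3] << 16);
        aux[2] = (uint32_t) s16[4] | ((uint32_t) s16[5] << 16);
        // Broken variant (what the failed attempts amounted to):
        //     const uint32_t * s32 = (const uint32_t *) scales;
        //     aux[0] = s32[0]; aux[1] = s32[1]; aux[2] = s32[2];
    }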

* metal : Q3_K 1st optimization pass

* metal : Q3_K second optimization pass - 29.6 ms/token

* metal : Q3_K cleanup

* metal : fixed accidentally broken Q2_K

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
ikawrakow and Kawrakow authored Jun 12, 2023
1 parent e4caa8d commit 74a6d92
Showing 3 changed files with 463 additions and 135 deletions.
41 changes: 34 additions & 7 deletions ggml-metal.m
@@ -52,14 +52,18 @@
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
GGML_METAL_DECL_KERNEL(get_rows_q3_k);
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
GGML_METAL_DECL_KERNEL(get_rows_q5_k);
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -153,14 +157,18 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
GGML_METAL_ADD_KERNEL(get_rows_q3_k);
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
GGML_METAL_ADD_KERNEL(get_rows_q5_k);
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -575,6 +583,15 @@ void ggml_metal_graph_compute(
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
} break;
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);

nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(ne02 == 1);
@@ -584,6 +601,15 @@ void ggml_metal_graph_compute(
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
} break;
case GGML_TYPE_Q5_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);

nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
} break;
case GGML_TYPE_Q6_K:
{
GGML_ASSERT(ne02 == 1);
@@ -620,15 +646,14 @@ void ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else if (src0t == GGML_TYPE_Q2_K) {
+ }
+ else if (src0t == GGML_TYPE_Q2_K ||
+          src0t == GGML_TYPE_Q3_K ||
+          src0t == GGML_TYPE_Q4_K ||
+          src0t == GGML_TYPE_Q5_K ||
+          src0t == GGML_TYPE_Q6_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else if (src0t == GGML_TYPE_Q4_K) {
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else if (src0t == GGML_TYPE_Q6_K) {
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -646,7 +671,9 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
default: GGML_ASSERT(false && "not implemented");
}