
Commit ca82cf7

ikawrakow and ggerganov authored
metal : more optimizations (#2959)
* Very minor speedup via simd-group synchronization in f16 x f32
* Another very minor speedup on metal
* Quite significant PP speedup on metal
* Another attempt
* Minor
* Massive improvement for TG for fp16
* ~4-5% improvement for Q8_0 TG on metal

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 6a31a3b commit ca82cf7

File tree

2 files changed: +160 -82 lines changed

ggml-metal.m

Lines changed: 16 additions & 6 deletions
@@ -76,6 +76,7 @@
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -219,6 +220,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -284,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -868,7 +871,11 @@ void ggml_metal_graph_compute(
                         {
                             nth0 = 32;
                             nth1 = 1;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                            if (ne11 * ne12 < 4) {
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                            }
                         } break;
                     case GGML_TYPE_Q4_0:
                         {
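
The f16 x f32 case now chooses between two pipelines based on the shape of src1: when ne11 * ne12 is below 4 (ne11 and ne12 being the second and third src1 dimensions, i.e. very few src1 rows in total), the new 1-row kernel is used. Below is a minimal C sketch of that selection heuristic; the enum and function name are hypothetical and only paraphrase the Objective-C branch above.

    #include <stdint.h>

    /* Hypothetical enum standing in for the two Metal pipelines above. */
    enum f16_mul_mat_kernel {
        MUL_MAT_F16_F32_1ROW, /* new kernel, for very small src1 batches */
        MUL_MAT_F16_F32       /* existing kernel, for larger src1 shapes */
    };

    /* Mirrors the `ne11 * ne12 < 4` check in ggml_metal_graph_compute. */
    static enum f16_mul_mat_kernel pick_f16_kernel(int64_t ne11, int64_t ne12) {
        return (ne11 * ne12 < 4) ? MUL_MAT_F16_F32_1ROW : MUL_MAT_F16_F32;
    }

Per the commit message, this path is where the large fp16 TG (token generation) gain comes from, since generation usually multiplies against a single src1 row at a time.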
@@ -920,8 +927,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne02 == 1);
                             GGML_ASSERT(ne12 == 1);
 
-                            nth0 = 2;
-                            nth1 = 32;
+                            nth0 = 4; //1;
+                            nth1 = 8; //32;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                         } break;
                     case GGML_TYPE_Q5_K:
@@ -969,9 +976,12 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
                     if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                        src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+                        src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
                     else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
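
Together with the nth0 = 4, nth1 = 8 change in the previous hunk, Q4_K now uses a 4x8 = 32-thread threadgroup that covers 4 output rows, instead of sharing the 8-rows-per-threadgroup dispatch with the other quantized types, hence the new (ne01 + 3)/4 grid dimension. The C sketch below only makes the rounding arithmetic explicit; the helper name and example value are illustrative, not part of the commit.

    #include <stdint.h>

    /* Number of threadgroups along the row dimension when each threadgroup
     * handles `rows_per_tg` of the ne01 output rows (ceiling division). */
    static int64_t n_row_threadgroups(int64_t ne01, int64_t rows_per_tg) {
        return (ne01 + rows_per_tg - 1) / rows_per_tg;
    }

    /* Q4_0/Q4_1/Q8_0/Q2_K keep 8 rows per threadgroup: (ne01 + 7)/8.
     * Q4_K moves to 4 rows per threadgroup:            (ne01 + 3)/4.
     * For ne01 = 4096 that is 512 vs 1024 threadgroups, respectively. */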
@@ -985,8 +995,8 @@ void ggml_metal_graph_compute(
                     else if (src0t == GGML_TYPE_Q6_K) {
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
-                        [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        int64_t ny = (ne11 + 3)/4;
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                 }
             } break;
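
In the fallback branch (which covers the batched f16 x f32 kernel, among others), each threadgroup now processes up to four src1 rows (ny = (ne11 + 3)/4) and the explicit threadgroup-memory allocation is gone; the commit message's mention of simd-group synchronization suggests the reduction no longer goes through shared memory, though that side of the change lives in ggml-metal.metal and is not shown in this file. A minimal C sketch of the new grid computation, with illustrative names only:

    #include <stdint.h>

    /* Illustrative stand-in for the MTLSize grid passed to dispatchThreadgroups. */
    struct grid3d { int64_t x, y, z; };

    /* First dimension walks the ne01 rows of src0; the second groups the
     * ne11 rows of src1 four at a time, matching the rounding in the diff. */
    static struct grid3d fallback_grid(int64_t ne01, int64_t ne11, int64_t ne12) {
        const int64_t ny = (ne11 + 3) / 4;
        struct grid3d g = { ne01, ny, ne12 };
        return g;
    }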
