76
76
GGML_METAL_DECL_KERNEL (rms_norm);
77
77
GGML_METAL_DECL_KERNEL (norm);
78
78
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32);
79
+ GGML_METAL_DECL_KERNEL (mul_mat_f16_f32_1row);
79
80
GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
80
81
GGML_METAL_DECL_KERNEL (mul_mat_q4_1_f32);
81
82
GGML_METAL_DECL_KERNEL (mul_mat_q8_0_f32);
@@ -219,6 +220,7 @@ @implementation GGMLMetalClass
219
220
GGML_METAL_ADD_KERNEL (rms_norm);
220
221
GGML_METAL_ADD_KERNEL (norm);
221
222
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
223
+ GGML_METAL_ADD_KERNEL (mul_mat_f16_f32_1row);
222
224
GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
223
225
GGML_METAL_ADD_KERNEL (mul_mat_q4_1_f32);
224
226
GGML_METAL_ADD_KERNEL (mul_mat_q8_0_f32);
@@ -284,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
284
286
GGML_METAL_DEL_KERNEL (rms_norm);
285
287
GGML_METAL_DEL_KERNEL (norm);
286
288
GGML_METAL_DEL_KERNEL (mul_mat_f16_f32);
289
+ GGML_METAL_DEL_KERNEL (mul_mat_f16_f32_1row);
287
290
GGML_METAL_DEL_KERNEL (mul_mat_q4_0_f32);
288
291
GGML_METAL_DEL_KERNEL (mul_mat_q4_1_f32);
289
292
GGML_METAL_DEL_KERNEL (mul_mat_q8_0_f32);
@@ -868,7 +871,11 @@ void ggml_metal_graph_compute(
868
871
{
869
872
nth0 = 32 ;
870
873
nth1 = 1 ;
871
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
874
+ if (ne11 * ne12 < 4 ) {
875
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32_1row];
876
+ } else {
877
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
878
+ }
872
879
} break ;
873
880
case GGML_TYPE_Q4_0:
874
881
{
@@ -920,8 +927,8 @@ void ggml_metal_graph_compute(
920
927
GGML_ASSERT (ne02 == 1 );
921
928
GGML_ASSERT (ne12 == 1 );
922
929
923
- nth0 = 2 ;
924
- nth1 = 32 ;
930
+ nth0 = 4 ; // 1 ;
931
+ nth1 = 8 ; // 32;
925
932
[encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_K_f32];
926
933
} break ;
927
934
case GGML_TYPE_Q5_K:
@@ -969,9 +976,12 @@ void ggml_metal_graph_compute(
969
976
[encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
970
977
971
978
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
972
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
979
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
973
980
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
974
981
}
982
+ else if (src0t == GGML_TYPE_Q4_K) {
983
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
984
+ }
975
985
else if (src0t == GGML_TYPE_Q3_K) {
976
986
#ifdef GGML_QKK_64
977
987
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
@@ -985,8 +995,8 @@ void ggml_metal_graph_compute(
985
995
else if (src0t == GGML_TYPE_Q6_K) {
986
996
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
987
997
} else {
988
- [encoder setThreadgroupMemoryLength: nth0* sizeof ( float ) atIndex: 0 ] ;
989
- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11 , ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
998
+ int64_t ny = (ne11 + 3 )/ 4 ;
999
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ny , ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
990
1000
}
991
1001
}
992
1002
} break ;
0 commit comments