@@ -498,6 +498,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,             cos,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,             neg,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,        sum_rows,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,            mean,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,          argmax,          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
 
+                nth = MIN(nth, ne00);
 
                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
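For context, here is a minimal sketch of what a device-side mean kernel matching the host-side setup above could look like: one threadgroup per row, nth threads striding over ne00, a simdgroup reduction, and the 32 floats of threadgroup memory allocated via setThreadgroupMemoryLength. This is an illustrative assumption, not the actual kernel_mean added to ggml-metal.metal; the names kernel_mean_sketch and kargs_sum_rows_sketch are hypothetical, and the kargs struct is abbreviated to just the fields the sketch uses.

#include <metal_stdlib>
using namespace metal;

// abbreviated, hypothetical view of ggml_metal_kargs_sum_rows:
// only the row length and the strides needed to address src0 and dst
typedef struct {
    int64_t  ne00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} kargs_sum_rows_sketch;

// one threadgroup per (i1, i2, i3) row, as in the MTLSizeMake(ne01, ne02, ne03) dispatch
kernel void kernel_mean_sketch(
        constant kargs_sum_rows_sketch & args  [[buffer(0)]],
        device const char              * src0  [[buffer(1)]],
        device       char              * dst   [[buffer(2)]],
        threadgroup  float             * shmem [[threadgroup(0)]], // 32 floats from the host
        uint3  tgpig [[threadgroup_position_in_grid]],
        ushort tpitg [[thread_position_in_threadgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]],
        ushort tiisg [[thread_index_in_simdgroup]],
        ushort ntg   [[threads_per_threadgroup]]) {
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    device const float * row = (device const float *)(src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       float * out = (device       float *)(dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // each thread accumulates a strided partial sum over the row
    float sumf = 0.0f;
    for (int64_t i0 = tpitg; i0 < args.ne00; i0 += ntg) {
        sumf += row[i0];
    }

    // reduce within each simdgroup, then stage one partial per simdgroup
    sumf = simd_sum(sumf);
    if (tiisg == 0) {
        shmem[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // thread 0 combines the per-simdgroup partials and writes the result
    if (tpitg == 0) {
        const ushort nsg = (ntg + 31)/32;

        float total = 0.0f;
        for (ushort i = 0; i < nsg; ++i) {
            total += shmem[i];
        }

        // GGML_OP_SUM_ROWS would store `total` directly; GGML_OP_MEAN divides by ne00
        out[0] = total / (float) args.ne00;
    }
}

The 32-float threadgroup buffer is enough for one partial per simdgroup at the Metal maximum of 1024 threads per threadgroup (32 simdgroups of 32 threads), which is why the host side allocates 32*sizeof(float) regardless of nth.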