@@ -811,6 +811,9 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
 
     // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
     struct ggml_metal_mem_pool * mem_pool;
+
+    // used to enable concurrent execution of ops in the command buffers
+    struct ggml_mem_ranges * mem_ranges;
 };
 
 struct ggml_backend_metal_context {
@@ -1127,6 +1130,10 @@ @implementation GGMLMetalClass
 
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
+
+        if (ctx_dev->use_concurrency) {
+            ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
+        }
     }
 
     ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
@@ -1737,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         }
 
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+
+        if (ctx->cmd_bufs[i].mem_ranges) {
+            ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
+        }
     }
 
     [ctx->cmd_bufs_ext removeAllObjects];
@@ -2103,22 +2114,34 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     struct ggml_mem_ranges * mem_ranges;
 };
 
-static bool ggml_metal_encode_mem_ranges_reset(struct ggml_metal_encode_context * ctx) {
+static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
     [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
 
     ggml_mem_ranges_reset(ctx->mem_ranges);
 
     return true;
 }
 
-static bool ggml_metal_encode_mem_ranges_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
-    return ggml_mem_ranges_add(ctx->mem_ranges, node);
-}
+static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return false;
+    }
 
-static bool ggml_metal_encode_mem_ranges_check(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
     return ggml_mem_ranges_check(ctx->mem_ranges, node);
 }
 
+static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
 static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
     ggml_backend_t backend = ctx_enc->backend;
 
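The three `ggml_metal_encode_concurrency_*` wrappers above let every call site below stay unconditional: when `mem_ranges` is NULL (concurrency disabled), `reset` and `add` succeed as no-ops while `check` reports "not concurrent", so the op simply takes the serial path. A minimal sketch of the per-op pattern the rest of this diff follows; `encode_one_op` is an illustrative wrapper, not a function from this patch, and it assumes the declarations from ggml-metal.m quoted above:

    // sketch: how one op is expected to drive the helpers
    static void encode_one_op(struct ggml_metal_encode_context * ctx_enc,
                              const struct ggml_tensor * node) {
        // can this op overlap with the ops encoded since the last barrier?
        if (!ggml_metal_encode_concurrency_check(ctx_enc, node)) {
            // no: emit a buffer-scope memory barrier and drop the tracked ranges
            ggml_metal_encode_concurrency_reset(ctx_enc);
        }

        // ... encode the kernel dispatch for `node` here ...

        // record this op's src/dst ranges; if the range cache cannot take them,
        // fall back to a barrier so later ops cannot race with this one
        if (!ggml_metal_encode_concurrency_add(ctx_enc, node)) {
            ggml_metal_encode_concurrency_reset(ctx_enc);
        }
    }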
@@ -2240,22 +2263,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     int n_fuse = 1;
 
-    if (ctx_dev->debug_graph > 0) {
-        GGML_LOG_DEBUG("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-        if (src0) {
-            GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                    ggml_is_contiguous(src0), src0->name);
-        }
-        if (src1) {
-            GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                    ggml_is_contiguous(src1), src1->name);
-        }
-        if (dst) {
-            GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                    dst->name);
-        }
-    }
-
     // check if the current node can run concurrently with other nodes before it
     // the condition is that:
     // - the current node cannot write to any previous src or dst ranges
@@ -2264,17 +2271,29 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     // if the condition is not satisfied, we put a memory barrier and clear all ranges
     // otherwise, we add the new ranges to the encoding context and process the node concurrently
     //
-    if (ctx_dev->use_concurrency) {
-        bool is_concurrent = true;
-
-        if (!ggml_metal_encode_mem_ranges_check(ctx_enc, node)) {
-            ggml_metal_encode_mem_ranges_reset(ctx_enc);
+    {
+        const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
 
-            is_concurrent = false;
+        if (!is_concurrent) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
         }
 
         if (ctx_dev->debug_graph > 0) {
-            GGML_LOG_DEBUG("%s: concurrent = %d\n", __func__, is_concurrent);
+            GGML_LOG_DEBUG("%s: op - %-12s %s\n", __func__, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
+        }
+        if (ctx_dev->debug_graph > 1) {
+            if (src0) {
+                GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                        ggml_is_contiguous(src0), src0->name);
+            }
+            if (src1) {
+                GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                        ggml_is_contiguous(src1), src1->name);
+            }
+            if (dst) {
+                GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                        dst->name);
+            }
         }
     }
 
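The comment above defines "can run concurrently" in terms of range overlaps: a node may not write into any src/dst range recorded since the last barrier, and may not read from any recorded dst range. The `ggml_mem_ranges` implementation is not part of this diff, so the following is only a conceptual sketch of that test, assuming ranges are tracked as [beg, end) intervals tagged as source or destination:

    #include <stdbool.h>
    #include <stdint.h>

    struct mem_range { uint64_t beg, end; bool is_dst; };   // hypothetical layout

    // true if a new access [beg, end) conflicts with the recorded ranges
    static bool ranges_conflict(const struct mem_range * prev, int n_prev,
                                uint64_t beg, uint64_t end, bool is_write) {
        for (int i = 0; i < n_prev; ++i) {
            const bool overlap = beg < prev[i].end && prev[i].beg < end;
            if (!overlap) {
                continue;
            }
            if (is_write) {
                return true;      // write-after-read / write-after-write hazard
            }
            if (prev[i].is_dst) {
                return true;      // read-after-write hazard
            }
        }
        return false;             // independent -> safe to encode concurrently
    }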
@@ -2475,17 +2494,13 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
-                }
-
-                if (ctx_dev->use_concurrency && n_fuse > 1) {
-                    bool is_concurrent = true;
 
                     for (int i = 1; i < n_fuse; ++i) {
-                        is_concurrent = is_concurrent && ggml_metal_encode_mem_ranges_check(ctx_enc, nodes[i]);
-                    }
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
 
-                    if (!is_concurrent) {
-                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                            break;
+                        }
                     }
                 }
 
@@ -2632,9 +2647,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 
-                    if (ctx_dev->use_concurrency) {
-                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
-                    }
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
                 }
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -4103,9 +4116,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
                 // so we add this extra barrier to prevent the race.
                 // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
-                if (ctx_dev->use_concurrency) {
-                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
-                }
+                ggml_metal_encode_concurrency_reset(ctx_enc);
 
                 // tokens per expert
                 const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
@@ -4168,9 +4179,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 }
 
                 // this barrier is always needed because the next kernel has to wait for the id maps to be computed
-                if (ctx_dev->use_concurrency) {
-                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
-                }
+                ggml_metal_encode_concurrency_reset(ctx_enc);
 
                 {
                     id<MTLComputePipelineState> pipeline = nil;
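Both MUL_MAT_ID barriers above now go through `ggml_metal_encode_concurrency_reset`, whose core is the `memoryBarrierWithScope:MTLBarrierScopeBuffers` call from the first hunk: within a single compute encoder it orders all buffer writes of previously encoded dispatches before any dispatch encoded later, which is what matters when dispatches are allowed to overlap. A minimal sketch of that Metal primitive outside of ggml; the pipelines, buffer, and grid sizes are assumed to exist and this is not code from the patch:

    // writerPSO fills `tmp`, readerPSO consumes it; both on the same encoder
    [encoder setComputePipelineState:writerPSO];
    [encoder setBuffer:tmp offset:0 atIndex:0];
    [encoder dispatchThreadgroups:tg threadsPerThreadgroup:tp];

    // order the writes above before any dispatch encoded after this point
    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];

    [encoder setComputePipelineState:readerPSO];
    [encoder setBuffer:tmp offset:0 atIndex:0];
    [encoder dispatchThreadgroups:tg threadsPerThreadgroup:tp];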
@@ -4640,17 +4649,13 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
-                }
-
-                if (ctx_dev->use_concurrency) {
-                    bool is_concurrent = true;
 
                     for (int i = 1; i < n_fuse; ++i) {
-                        is_concurrent = is_concurrent && ggml_metal_encode_mem_ranges_check(ctx_enc, nodes[i]);
-                    }
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
 
-                    if (!is_concurrent) {
-                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                            break;
+                        }
                     }
                 }
 
@@ -5555,9 +5560,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                 // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
                 // still, we assume that concurrent FA won't happen before we do the refactor
-                //if (ctx_dev->use_concurrency) {
-                //    ggml_metal_encode_mem_ranges_reset(ctx_enc);
-                //}
+                //ggml_metal_encode_concurrency_reset(ctx_enc);
 
                 const int32_t nrows = ne1*ne2*ne3;
 
@@ -5579,9 +5582,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
-                if (ctx_dev->use_concurrency) {
-                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
-                }
+                ggml_metal_encode_concurrency_reset(ctx_enc);
 
                 // reduce the results from the workgroups
                 {
@@ -5852,19 +5853,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     }
 
     // update the mem ranges in the encoding context
-    if (ctx_dev->use_concurrency) {
-        bool ok = true;
-
-        for (int i = 0; i < n_fuse; ++i) {
-            ok = ok && ggml_metal_encode_mem_ranges_add(ctx_enc, nodes[i]);
-        }
-
-        if (!ok) {
-            if (ctx_dev->debug_graph > 2) {
-                GGML_LOG_DEBUG("%s: the range cache is full -> reset and put a barrier\n", __func__);
-            }
-
-            ggml_metal_encode_mem_ranges_reset(ctx_enc);
+    for (int i = 0; i < n_fuse; ++i) {
+        if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
         }
     }
 
@@ -6745,11 +6736,16 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
-        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        id<MTLCommandBuffer> cmd_buf            = ctx->cmd_bufs[cb_idx].obj;
+        struct ggml_metal_mem_pool * mem_pool   = ctx->cmd_bufs[cb_idx].mem_pool;
+        struct ggml_mem_ranges     * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
 
         ggml_metal_mem_pool_reset(mem_pool);
 
+        if (mem_ranges) {
+            ggml_mem_ranges_reset(mem_ranges);
+        }
+
         id<MTLComputeCommandEncoder> encoder;
 
         struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
@@ -6774,13 +6770,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             /*.backend    =*/ backend,
             /*.encoder    =*/ encoder,
             /*.mem_pool   =*/ mem_pool,
-            /*.mem_ranges =*/ NULL,
+            /*.mem_ranges =*/ mem_ranges,
         };
 
-        if (ctx_dev->use_concurrency) {
-            ctx_enc.mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
-        }
-
         for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
@@ -6805,8 +6797,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         [encoder endEncoding];
 
-        ggml_mem_ranges_free(ctx_enc.mem_ranges);
-
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
             [cmd_buf commit];
 