Commit faffbec (1 parent: 907616d)

cont : refactor ggml-metal.m

ggml-ci

1 file changed: ggml/src/ggml-metal/ggml-metal.m (71 additions, 81 deletions)

@@ -811,6 +811,9 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
 
     // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
     struct ggml_metal_mem_pool * mem_pool;
+
+    // used to enable concurrent execution of ops in the command buffers
+    struct ggml_mem_ranges * mem_ranges;
 };
 
 struct ggml_backend_metal_context {
@@ -1127,6 +1130,10 @@ @implementation GGMLMetalClass
 
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
+
+        if (ctx_dev->use_concurrency) {
+            ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
+        }
     }
 
     ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
@@ -1737,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         }
 
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+
+        if (ctx->cmd_bufs[i].mem_ranges) {
+            ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
+        }
     }
 
     [ctx->cmd_bufs_ext removeAllObjects];
@@ -2103,22 +2114,34 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     struct ggml_mem_ranges * mem_ranges;
 };
 
-static bool ggml_metal_encode_mem_ranges_reset(struct ggml_metal_encode_context * ctx) {
+static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
     [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
 
     ggml_mem_ranges_reset(ctx->mem_ranges);
 
     return true;
 }
 
-static bool ggml_metal_encode_mem_ranges_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
-    return ggml_mem_ranges_add(ctx->mem_ranges, node);
-}
+static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return false;
+    }
 
-static bool ggml_metal_encode_mem_ranges_check(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
     return ggml_mem_ranges_check(ctx->mem_ranges, node);
 }
 
+static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
 static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
     ggml_backend_t backend = ctx_enc->backend;
 
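Note on the refactor above: the three ggml_metal_encode_concurrency_* wrappers fold the old ctx_dev->use_concurrency checks into a single NULL test on ctx->mem_ranges, so call sites can invoke them unconditionally. Below is a minimal, self-contained C sketch of that contract and of the per-node check/add/reset pattern used in the hunks further down; the trimmed-down structs and the ranges_* stubs are illustrative stand-ins, not the real ggml types or API.

// Illustrative sketch only: not the ggml API. It mirrors the contract of the
// three concurrency wrappers above. With mem_ranges == NULL (concurrency
// disabled), "check" reports "not concurrent", while "add" and "reset"
// degrade to no-ops, so call sites need no separate use_concurrency flag.
#include <stdbool.h>
#include <stdio.h>

struct mem_ranges { int n_tracked; };      // stand-in for ggml_mem_ranges
struct tensor     { const char * name; };  // stand-in for ggml_tensor

// stand-ins for ggml_mem_ranges_check/add/reset
static bool ranges_check(struct mem_ranges * r, const struct tensor * t) { (void) r; (void) t; return true; }
static bool ranges_add  (struct mem_ranges * r, const struct tensor * t) { (void) t; return ++r->n_tracked < 16; }
static void ranges_reset(struct mem_ranges * r) { r->n_tracked = 0; }

struct encode_ctx {
    struct mem_ranges * mem_ranges; // NULL when concurrency is disabled
};

static bool concurrency_check(struct encode_ctx * ctx, const struct tensor * t) {
    if (!ctx->mem_ranges) {
        return false; // never concurrent when tracking is disabled
    }
    return ranges_check(ctx->mem_ranges, t);
}

static bool concurrency_add(struct encode_ctx * ctx, const struct tensor * t) {
    if (!ctx->mem_ranges) {
        return true; // nothing to track
    }
    return ranges_add(ctx->mem_ranges, t);
}

static bool concurrency_reset(struct encode_ctx * ctx) {
    if (!ctx->mem_ranges) {
        return true;
    }
    // the real wrapper also encodes [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]
    ranges_reset(ctx->mem_ranges);
    return true;
}

// per-node pattern distilled from the encode hunks further down
static void encode_node(struct encode_ctx * ctx, const struct tensor * node) {
    if (!concurrency_check(ctx, node)) {
        concurrency_reset(ctx); // barrier + clear tracked ranges
    }
    // ... encode the kernel dispatch for `node` here ...
    if (!concurrency_add(ctx, node)) {
        concurrency_reset(ctx); // range cache full -> barrier
    }
}

int main(void) {
    struct mem_ranges ranges = { 0 };
    struct tensor     node   = { "mul_mat" }; // hypothetical node name

    struct encode_ctx with_tracking    = { &ranges };
    struct encode_ctx without_tracking = { NULL    };

    encode_node(&with_tracking,    &node);
    encode_node(&without_tracking, &node);

    printf("tracked ranges: %d\n", ranges.n_tracked);
    return 0;
}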
@@ -2240,22 +2263,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     int n_fuse = 1;
 
-    if (ctx_dev->debug_graph > 0) {
-        GGML_LOG_DEBUG("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-        if (src0) {
-            GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                ggml_is_contiguous(src0), src0->name);
-        }
-        if (src1) {
-            GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                ggml_is_contiguous(src1), src1->name);
-        }
-        if (dst) {
-            GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                dst->name);
-        }
-    }
-
     // check if the current node can run concurrently with other nodes before it
     // the condition is that:
     // - the current node cannot write to any previous src or dst ranges
@@ -2264,17 +2271,29 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     // if the condition is not satisfied, we put a memory barrier and clear all ranges
     // otherwise, we add the new ranges to the encoding context and process the node concurrently
     //
-    if (ctx_dev->use_concurrency) {
-        bool is_concurrent = true;
-
-        if (!ggml_metal_encode_mem_ranges_check(ctx_enc, node)) {
-            ggml_metal_encode_mem_ranges_reset(ctx_enc);
+    {
+        const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
 
-            is_concurrent = false;
+        if (!is_concurrent) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
         }
 
         if (ctx_dev->debug_graph > 0) {
-            GGML_LOG_DEBUG("%s: concurrent = %d\n", __func__, is_concurrent);
+            GGML_LOG_DEBUG("%s: op - %-12s %s\n", __func__, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
+        }
+        if (ctx_dev->debug_graph > 1) {
+            if (src0) {
+                GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                    ggml_is_contiguous(src0), src0->name);
+            }
+            if (src1) {
+                GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                    ggml_is_contiguous(src1), src1->name);
+            }
+            if (dst) {
+                GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                    dst->name);
+            }
         }
     }
 
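The comments above state the concurrency condition in terms of memory ranges. The sketch below is an illustrative model of such an overlap test, not ggml's actual ggml_mem_ranges implementation: each already-encoded node contributes read ranges (its srcs) and a write range (its dst), and a new node counts as concurrent only if its writes overlap nothing recorded and, by the usual read/write hazard rule (assumed here, since the middle of the comment falls outside this hunk), its reads overlap no recorded write.

// Illustrative model only: not ggml's ggml_mem_ranges. Dependency tracking by
// address-range overlap. Writes must not touch anything already recorded, and
// (assumed, per standard read/write hazard rules) reads must not touch a
// recorded write.
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define MAX_RANGES 128

struct range  { uintptr_t beg, end; bool is_write; };
struct ranges { struct range r[MAX_RANGES]; int n; };

static bool overlap(struct range a, struct range b) {
    return a.beg < b.end && b.beg < a.end;
}

// true if accessing [beg, end) does not conflict with the recorded ranges
static bool ranges_check(const struct ranges * rs, uintptr_t beg, uintptr_t end, bool is_write) {
    const struct range cur = { beg, end, is_write };
    for (int i = 0; i < rs->n; ++i) {
        if ((is_write || rs->r[i].is_write) && overlap(cur, rs->r[i])) {
            return false; // conflict -> the caller encodes a barrier and resets
        }
    }
    return true;
}

// false when the cache is full -> the caller encodes a barrier and resets
static bool ranges_add(struct ranges * rs, uintptr_t beg, uintptr_t end, bool is_write) {
    if (rs->n == MAX_RANGES) {
        return false;
    }
    rs->r[rs->n++] = (struct range) { beg, end, is_write };
    return true;
}

int main(void) {
    struct ranges rs = { .n = 0 };

    ranges_add(&rs, 0, 256, /*is_write =*/ true);       // node A writes [0, 256)

    printf("%d\n", ranges_check(&rs,   0, 256, false)); // 0: reads A's output -> barrier
    printf("%d\n", ranges_check(&rs, 512, 768, true));  // 1: disjoint -> concurrent

    return 0;
}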
@@ -2475,17 +2494,13 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
-                }
-
-                if (ctx_dev->use_concurrency && n_fuse > 1) {
-                    bool is_concurrent = true;
 
                     for (int i = 1; i < n_fuse; ++i) {
-                        is_concurrent = is_concurrent && ggml_metal_encode_mem_ranges_check(ctx_enc, nodes[i]);
-                    }
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
 
-                    if (!is_concurrent) {
-                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                            break;
+                        }
                     }
                 }
 
@@ -2632,9 +2647,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 
-            if (ctx_dev->use_concurrency) {
-                ggml_metal_encode_mem_ranges_reset(ctx_enc);
-            }
+            ggml_metal_encode_concurrency_reset(ctx_enc);
         }
 
         const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -4103,9 +4116,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
             // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
             // so we add this extra barrier to prevent the race.
             // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
-            if (ctx_dev->use_concurrency) {
-                ggml_metal_encode_mem_ranges_reset(ctx_enc);
-            }
+            ggml_metal_encode_concurrency_reset(ctx_enc);
 
             // tokens per expert
             const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
@@ -4168,9 +4179,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
             }
 
             // this barrier is always needed because the next kernel has to wait for the id maps to be computed
-            if (ctx_dev->use_concurrency) {
-                ggml_metal_encode_mem_ranges_reset(ctx_enc);
-            }
+            ggml_metal_encode_concurrency_reset(ctx_enc);
 
             {
                 id<MTLComputePipelineState> pipeline = nil;
@@ -4640,17 +4649,13 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
-                }
-
-                if (ctx_dev->use_concurrency) {
-                    bool is_concurrent = true;
 
                     for (int i = 1; i < n_fuse; ++i) {
-                        is_concurrent = is_concurrent && ggml_metal_encode_mem_ranges_check(ctx_enc, nodes[i]);
-                    }
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
 
-                    if (!is_concurrent) {
-                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                            break;
+                        }
                     }
                 }
 
@@ -5555,9 +5560,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
             // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
             // still, we assume that concurrent FA won't happen before we do the refactor
-            //if (ctx_dev->use_concurrency) {
-            //    ggml_metal_encode_mem_ranges_reset(ctx_enc);
-            //}
+            //ggml_metal_encode_concurrency_reset(ctx_enc);
 
             const int32_t nrows = ne1*ne2*ne3;
 
@@ -5579,9 +5582,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
-                if (ctx_dev->use_concurrency) {
-                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
-                }
+                ggml_metal_encode_concurrency_reset(ctx_enc);
 
                 // reduce the results from the workgroups
                 {
@@ -5852,19 +5853,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     }
 
     // update the mem ranges in the encoding context
-    if (ctx_dev->use_concurrency) {
-        bool ok = true;
-
-        for (int i = 0; i < n_fuse; ++i) {
-            ok = ok && ggml_metal_encode_mem_ranges_add(ctx_enc, nodes[i]);
-        }
-
-        if (!ok) {
-            if (ctx_dev->debug_graph > 2) {
-                GGML_LOG_DEBUG("%s: the range cache is full -> reset and put a barrier\n", __func__);
-            }
-
-            ggml_metal_encode_mem_ranges_reset(ctx_enc);
+    for (int i = 0; i < n_fuse; ++i) {
+        if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
         }
     }
 
@@ -6745,11 +6736,16 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
-        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        id<MTLCommandBuffer>         cmd_buf    = ctx->cmd_bufs[cb_idx].obj;
+        struct ggml_metal_mem_pool * mem_pool   = ctx->cmd_bufs[cb_idx].mem_pool;
+        struct ggml_mem_ranges     * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
 
        ggml_metal_mem_pool_reset(mem_pool);
 
+        if (mem_ranges) {
+            ggml_mem_ranges_reset(mem_ranges);
+        }
+
        id<MTLComputeCommandEncoder> encoder;
 
        struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
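With this hunk the tracker lives in the per-command-buffer slot and is only reset before each graph evaluation, instead of being created and destroyed inside the encode callback (the two hunks below drop that init/free). Together with the ggml_metal_init and ggml_metal_free hunks near the top, the intended lifecycle is roughly the following sketch; cmd_buf_slot and the ranges_* helpers are hypothetical stand-ins, not the ggml API.

// Lifecycle sketch only; cmd_buf_slot and ranges_* are hypothetical stand-ins.
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

struct ranges { int n; }; // stand-in for ggml_mem_ranges

static struct ranges * ranges_init (void)              { return calloc(1, sizeof(struct ranges)); }
static void            ranges_reset(struct ranges * r) { r->n = 0; }
static void            ranges_free (struct ranges * r) { free(r); }

struct cmd_buf_slot {
    struct ranges * mem_ranges; // NULL when concurrency is disabled
};

int main(void) {
    const bool use_concurrency = true;

    // backend init: one tracker per command buffer (cf. the ggml_metal_init hunk)
    struct cmd_buf_slot slot = { NULL };
    if (use_concurrency) {
        slot.mem_ranges = ranges_init();
    }

    // per graph evaluation: reset instead of re-allocating (this hunk)
    for (int i_graph = 0; i_graph < 3; ++i_graph) {
        if (slot.mem_ranges) {
            ranges_reset(slot.mem_ranges);
        }
        // ... create the encoder, encode the graph nodes, end encoding ...
    }

    // backend teardown (cf. the ggml_metal_free hunk)
    if (slot.mem_ranges) {
        ranges_free(slot.mem_ranges);
    }

    printf("done\n");
    return 0;
}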
@@ -6774,13 +6770,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
            /*.backend =*/ backend,
            /*.encoder =*/ encoder,
            /*.mem_pool =*/ mem_pool,
-            /*.mem_ranges =*/ NULL,
+            /*.mem_ranges =*/ mem_ranges,
        };
 
-        if (ctx_dev->use_concurrency) {
-            ctx_enc.mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
-        }
-
        for (int idx = node_start; idx < node_end;) {
            if (should_capture) {
                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
@@ -6805,8 +6797,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
        [encoder endEncoding];
 
-        ggml_mem_ranges_free(ctx_enc.mem_ranges);
-
        if (cb_idx < 2 || ctx->abort_callback == NULL) {
            [cmd_buf commit];
        }
