Skip to content

Commit 417df40

Browse files
committed
metal : fix race on mem pool buffers
ggml-ci
1 parent db1e3ce commit 417df40

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ @implementation ggml_metal_heap_ptr
640640
@end
641641

642642
//
643-
// ggml_metal_mem_pool
643+
// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
644644
//
645645

646646
struct ggml_metal_mem_pool {
@@ -4112,6 +4112,14 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41124112
default: break;
41134113
}
41144114

4115+
// TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
4116+
// reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
4117+
// so we add this extra barrier to prevent the race.
4118+
// the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
4119+
if (ctx_dev->use_concurrency) {
4120+
ggml_metal_encode_mem_ranges_reset(ctx_enc);
4121+
}
4122+
41154123
// tokens per expert
41164124
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
41174125
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -4172,6 +4180,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41724180
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
41734181
}
41744182

4183+
// this barrier is always needed because the next kernel has to wait for the id maps to be computed
41754184
if (ctx_dev->use_concurrency) {
41764185
ggml_metal_encode_mem_ranges_reset(ctx_enc);
41774186
}
@@ -5561,6 +5570,12 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55615570
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
55625571
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
55635572

5573+
// using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
5574+
// still, we assume that concurrent FA won't happen before we do the refactor
5575+
//if (ctx_dev->use_concurrency) {
5576+
// ggml_metal_encode_mem_ranges_reset(ctx_enc);
5577+
//}
5578+
55645579
const int32_t nrows = ne1*ne2*ne3;
55655580

55665581
// temp buffer for writing the results from each workgroup
@@ -5939,6 +5954,7 @@ static enum ggml_status ggml_metal_graph_compute(
59395954
// cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
59405955
// TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
59415956
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
5957+
// [TAG_MEM_POOL_REMOVE]
59425958
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
59435959
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
59445960
[cmd_buf retain];

0 commit comments

Comments
 (0)