Skip to content

Commit

Permalink
metal : add abort callback (#905)
Browse files Browse the repository at this point in the history
  • Loading branch information
conradev authored Aug 7, 2024
1 parent 18703ad commit 1f2b80a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/ggml-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void

GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);

GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);

// helper to check if the device supports a specific family
Expand Down
41 changes: 38 additions & 3 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@
bool support_simdgroup_mm;

bool should_capture_next_compute;

// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;
};

// MSL code
Expand Down Expand Up @@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffer_builder[cb_idx] = command_buffer;

// enqueue the command buffers in order to specify their execution order
[command_buffer enqueue];
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer enqueue];
}
}

const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
Expand Down Expand Up @@ -2829,7 +2836,9 @@ static enum ggml_status ggml_metal_graph_compute(

[encoder endEncoding];

[command_buffer commit];
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer commit];
}
});

// Wait for completion and check status of each command buffer
Expand All @@ -2849,6 +2858,23 @@ static enum ggml_status ggml_metal_graph_compute(

return GGML_STATUS_FAILED;
}

id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
if (!next_buffer) {
continue;
}

bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
if (next_queued) {
continue;
}

if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
return GGML_STATUS_ABORTED;
}

[next_buffer commit];
}

if (should_capture) {
Expand Down Expand Up @@ -3244,6 +3270,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}

// Install (or clear) the abort callback on a Metal backend.
//
// backend        - must be a Metal backend; this is asserted.
// abort_callback - stored on the backend's Metal context; graph compute aborts
//                  when the callback returns true. NULL means no abort check.
// user_data      - opaque pointer stored alongside the callback and handed
//                  back to it on every invocation.
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *) backend->context;

    // the two stores are independent; keep the callback and its payload together
    metal_ctx->abort_callback_data = user_data;
    metal_ctx->abort_callback      = abort_callback;
}

bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
GGML_ASSERT(ggml_backend_is_metal(backend));

Expand Down

0 comments on commit 1f2b80a

Please sign in to comment.