ggml : introduce ggml_status (ggml/750)
* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
  ggml_compute_exit_code -> ggml_status
  changed ggml_status from a bit-field type to simple codes
  ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
3 people committed Mar 4, 2024
1 parent fe52be1 commit 9fa2627
Showing 11 changed files with 88 additions and 63 deletions.
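The new status codes live in ggml.h, whose hunk did not load on this page. A sketch consistent with the commit message and with the values used in the diffs below — the exact numeric values are assumed, not shown here:

    // plain codes rather than a bit-field type, per the commit message
    enum ggml_status {
        GGML_STATUS_ALLOC_FAILED = -2, // graph allocation failed
        GGML_STATUS_FAILED       = -1, // computation failed
        GGML_STATUS_SUCCESS      =  0, // computation succeeded
        GGML_STATUS_ABORTED      =  1, // computation aborted (e.g. via abort_callback)
    };

    // the "ggml_status to string cast"; declared without GGML_CALL,
    // which the final revision removed from this declaration
    GGML_API const char * ggml_status_to_string(enum ggml_status status);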
7 changes: 4 additions & 3 deletions ggml-backend-impl.h
@@ -91,13 +91,14 @@ extern "C" {
         // (optional) complete all pending operations
         void (*GGML_CALL synchronize)(ggml_backend_t backend);
 
-        // compute graph with a plan
+        // create a plan for ggml_cgraph and free it
         ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
         void                      (*GGML_CALL graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-        void                      (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
+        // compute graph with a plan
+        enum ggml_status          (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
         // compute graph without a plan (async)
-        bool                      (*GGML_CALL graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+        enum ggml_status          (*GGML_CALL graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
         // check if the backend supports an operation
         bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
39 changes: 18 additions & 21 deletions ggml-backend.c
@@ -262,11 +262,11 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
     backend->iface.graph_plan_free(backend, plan);
 }
 
-void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
-    backend->iface.graph_plan_compute(backend, plan);
+enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    return backend->iface.graph_plan_compute(backend, plan);
 }
 
-bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return backend->iface.graph_compute(backend, cgraph);
 }
 
@@ -732,15 +732,15 @@ GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, g
     GGML_UNUSED(backend);
 }
 
-GGML_CALL static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
 
-    ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
 
     GGML_UNUSED(backend);
 }
 
-GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
 
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
@@ -755,8 +755,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
     cplan.abort_callback      = cpu_ctx->abort_callback;
     cplan.abort_callback_data = cpu_ctx->abort_callback_data;
 
-    ggml_graph_compute(cgraph, &cplan);
-    return true;
+    return ggml_graph_compute(cgraph, &cplan);
 }
 
 GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -1437,7 +1436,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
     return true;
 }
 
-static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
     uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
     uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
 
@@ -1472,8 +1471,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
 
         uint64_t compute_start_us = ggml_time_us();
         if (!sched->callback_eval) {
-            if (!ggml_backend_graph_compute(split_backend, &split->graph)) {
-                return false;
+            enum ggml_status ec = ggml_backend_graph_compute(split_backend, &split->graph);
+            if (ec != GGML_STATUS_SUCCESS) {
+                return ec;
             }
             //ggml_backend_synchronize(split_backend); // necessary to measure compute time
         } else {
@@ -1494,8 +1494,9 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
 
                 struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
 
-                if (!ggml_backend_graph_compute(split_backend, &gv)) {
-                    return false;
+                enum ggml_status ec = ggml_backend_graph_compute(split_backend, &gv);
+                if (ec != GGML_STATUS_SUCCESS) {
+                    return ec;
                 }
 
                 if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
@@ -1519,7 +1520,7 @@ static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
     }
 #endif
 
-    return true;
+    return GGML_STATUS_SUCCESS;
 }
 
 ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
@@ -1581,7 +1582,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     return true;
 }
 
-bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
 
     if (!sched->is_reset) {
@@ -1590,14 +1591,10 @@ bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
 
     ggml_backend_sched_split_graph(sched, graph);
     if (!ggml_backend_sched_alloc_splits(sched)) {
-        return false;
+        return GGML_STATUS_ALLOC_FAILED;
     }
 
-    if (!ggml_backend_sched_compute_splits(sched)) {
-        return false;
-    }
-
-    return true;
+    return ggml_backend_sched_compute_splits(sched);
 }
 
 void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
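Note that the CPU backend above now returns the result of ggml_graph_compute() directly, which implies the core compute entry point changed from its old return type to the new enum as well. That hunk is in ggml.c/ggml.h and is not loaded on this page; the implied prototype would be roughly:

    // assumed new signature of the core compute call used by the CPU backend
    GGML_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);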
31 changes: 16 additions & 15 deletions ggml-backend.h
@@ -66,12 +66,13 @@ extern "C" {
 
     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
 
-    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-    GGML_API void ggml_backend_graph_plan_free   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-    GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-    GGML_API bool ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-    GGML_API bool ggml_backend_supports_op       (ggml_backend_t backend, const struct ggml_tensor * op);
+    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+    GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API enum ggml_status ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+    GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
 
     // tensor copy between different backends
     GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
@@ -157,26 +158,26 @@ extern "C" {
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
     // Initialize a backend scheduler
-    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
-    GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_sched_t  ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
+    GGML_API void                  ggml_backend_sched_free(ggml_backend_sched_t sched);
     // Initialize backend buffers from a measure graph
-    GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+    GGML_API bool                  ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
     // Get the number of splits of the last graph
-    GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+    GGML_API int                   ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
 
-    GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+    GGML_API size_t                ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
 
-    GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
-    GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+    GGML_API void                  ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+    GGML_API ggml_backend_t        ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
     // Allocate and compute graph on the backend scheduler
-    GGML_API bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status      ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
 
     // Reset all assignments and allocators - must be called before changing the node backends
-    GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
+    GGML_API void                  ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
     // Set a callback to be called for each resulting node during graph compute
-    GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+    GGML_API void                  ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
 
     //
     // Utils
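For callers of the public API, the migration is mechanical: test against GGML_STATUS_SUCCESS instead of a bool. A minimal illustrative caller, not part of this commit (compute_graph_checked and its error handling are hypothetical):

    #include <stdio.h>
    #include "ggml-backend.h"

    // run a graph through the scheduler and report failures by name
    static enum ggml_status compute_graph_checked(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
        enum ggml_status status = ggml_backend_sched_graph_compute(sched, graph);
        if (status != GGML_STATUS_SUCCESS) {
            // distinguishes allocation failures (GGML_STATUS_ALLOC_FAILED)
            // from compute failures (GGML_STATUS_FAILED)
            fprintf(stderr, "ggml: graph compute failed: %s\n", ggml_status_to_string(status));
        }
        return status;
    }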
4 changes: 2 additions & 2 deletions ggml-cuda.cu
@@ -12241,7 +12241,7 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     UNUSED(backend);
 }
 
-GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_main_device(cuda_ctx->device);
@@ -12277,7 +12277,7 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg
         GGML_ASSERT(ok);
     }
 
-    return true;
+    return GGML_STATUS_SUCCESS;
 }
 
 GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
4 changes: 2 additions & 2 deletions ggml-kompute.cpp
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
     return ggml_backend_kompute_buffer_type(ctx->device);
 }
 
-static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
     ggml_vk_graph_compute(ctx, cgraph);
-    return true;
+    return GGML_STATUS_SUCCESS;
 }
 
 static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
8 changes: 4 additions & 4 deletions ggml-metal.m
@@ -748,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
     }
 }
 
-static bool ggml_metal_graph_compute(
+static enum ggml_status ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
 
@@ -2484,7 +2484,7 @@ static bool ggml_metal_graph_compute(
             MTLCommandBufferStatus status = [command_buffer status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-                return false;
+                return GGML_STATUS_FAILED;
             }
         }
 
@@ -2493,7 +2493,7 @@ static bool ggml_metal_graph_compute(
         }
 
     }
-    return true;
+    return GGML_STATUS_SUCCESS;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -2795,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
     UNUSED(backend);
 }
 
-GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
 
     return ggml_metal_graph_compute(metal_ctx, cgraph);
4 changes: 2 additions & 2 deletions ggml-opencl.cpp
@@ -2231,7 +2231,7 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg
     GGML_UNUSED(backend);
 }
 
-static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
+static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
     for (int i = 0; i < graph->n_nodes; ++i) {
         ggml_tensor * node = graph->nodes[i];
         switch (node->op) {
@@ -2246,7 +2246,7 @@ static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgrap
         }
     }
 
-    return true;
+    return GGML_STATUS_SUCCESS;
 
     GGML_UNUSED(backend);
 }
4 changes: 2 additions & 2 deletions ggml-sycl.cpp
@@ -15581,7 +15581,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-GGML_CALL static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_sycl_set_main_device(sycl_ctx->device);
 
@@ -15613,7 +15613,7 @@ GGML_CALL static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, gg
         GGML_ASSERT(ok);
     }
 
-    return true;
+    return GGML_STATUS_SUCCESS;
 }
 
 GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
4 changes: 2 additions & 2 deletions ggml-vulkan.cpp
@@ -5092,7 +5092,7 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
     ctx->transfer_ctx = nullptr;
 }
 
-GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -5135,7 +5135,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
 
     ggml_vk_graph_cleanup(ctx);
 
-    return true;
+    return GGML_STATUS_SUCCESS;
 
     UNUSED(backend);
 }