Skip to content

Commit c3a979e

Browse files
committed
dyn graphs : update ggml_graph_import
1 parent f668734 commit c3a979e

File tree

2 files changed

+28
-24
lines changed

2 files changed

+28
-24
lines changed

include/ggml/ggml.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,8 +1729,8 @@ extern "C" {
17291729

17301730
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
17311731

1732-
GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
1733-
GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
1732+
GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
1733+
GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
17341734

17351735
// print info and performance information for the graph
17361736
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);

src/ggml.c

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18639,8 +18639,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1863918639
case GGML_OP_ADD:
1864018640
case GGML_OP_ADD1:
1864118641
case GGML_OP_ACC:
18642-
n_tasks = n_threads;
18643-
break;
18642+
{
18643+
n_tasks = n_threads;
18644+
} break;
1864418645
case GGML_OP_SUB:
1864518646
case GGML_OP_DIV:
1864618647
case GGML_OP_SQR:
@@ -18652,9 +18653,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1865218653
case GGML_OP_ARGMAX:
1865318654
case GGML_OP_REPEAT:
1865418655
case GGML_OP_REPEAT_BACK:
18655-
n_tasks = 1;
18656-
break;
18657-
18656+
{
18657+
n_tasks = 1;
18658+
} break;
1865818659
case GGML_OP_UNARY:
1865918660
switch (ggml_get_unary_op(node)) {
1866018661
case GGML_UNARY_OP_ABS:
@@ -18681,8 +18682,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1868118682
case GGML_OP_RMS_NORM_BACK:
1868218683
case GGML_OP_GROUP_NORM:
1868318684
case GGML_OP_CONCAT:
18684-
n_tasks = n_threads;
18685-
break;
18685+
{
18686+
n_tasks = n_threads;
18687+
} break;
1868618688
case GGML_OP_MUL_MAT:
1868718689
{
1868818690
n_tasks = n_threads;
@@ -19446,12 +19448,12 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
1944619448
const uint32_t magic = GGML_FILE_MAGIC;
1944719449
const uint32_t version = GGML_FILE_VERSION;
1944819450
const uint32_t n_leafs = cgraph->n_leafs;
19449-
const uint32_t nodes = cgraph->n_nodes;
19451+
const uint32_t n_nodes = cgraph->n_nodes;
1945019452

1945119453
fwrite(&magic, sizeof(uint32_t), 1, fout);
1945219454
fwrite(&version, sizeof(uint32_t), 1, fout);
1945319455
fwrite(&n_leafs, sizeof(uint32_t), 1, fout);
19454-
fwrite(&nodes, sizeof(uint32_t), 1, fout);
19456+
fwrite(&n_nodes, sizeof(uint32_t), 1, fout);
1945519457
fwrite(&size_eval, sizeof(uint64_t), 1, fout);
1945619458
}
1945719459

@@ -19565,12 +19567,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
1956519567
}
1956619568
}
1956719569

19568-
#if 0
19569-
struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
19570+
struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
1957019571
assert(*ctx_data == NULL);
1957119572
assert(*ctx_eval == NULL);
1957219573

19573-
struct ggml_cgraph result = { 0 };
19574+
struct ggml_cgraph * result = NULL;
1957419575

1957519576
struct ggml_tensor * data = NULL;
1957619577

@@ -19642,13 +19643,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1964219643
const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs);
1964319644
const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes);
1964419645
const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval);
19645-
19646-
result.n_leafs = n_leafs;
19647-
result.n_nodes = n_nodes;
19646+
const int graph_size = MAX(n_leafs, n_nodes);
1964819647

1964919648
// create the data context
1965019649
{
19651-
const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead();
19650+
const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead(graph_size);
1965219651

1965319652
struct ggml_init_params params = {
1965419653
.mem_size = size_eval + overhead,
@@ -19664,6 +19663,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1966419663
}
1966519664
}
1966619665

19666+
result = ggml_new_graph(*ctx_eval, graph_size);
19667+
19668+
result->n_leafs = n_leafs;
19669+
result->n_nodes = n_nodes;
19670+
19671+
1966719672
// leafs
1966819673
{
1966919674
uint32_t type;
@@ -19702,7 +19707,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1970219707
tensor->nb[j] = nb[j];
1970319708
}
1970419709

19705-
result.leafs[i] = tensor;
19710+
result->leafs[i] = tensor;
1970619711

1970719712
ptr += ggml_nbytes(tensor);
1970819713

@@ -19754,10 +19759,10 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1975419759
continue;
1975519760
}
1975619761

19757-
if (arg_idx < result.n_leafs) {
19758-
args[j] = result.leafs[arg_idx];
19762+
if (arg_idx < result->n_leafs) {
19763+
args[j] = result->leafs[arg_idx];
1975919764
} else {
19760-
args[j] = result.nodes[arg_idx - result.n_leafs];
19765+
args[j] = result->nodes[arg_idx - result->n_leafs];
1976119766
}
1976219767
}
1976319768

@@ -19809,7 +19814,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1980919814
tensor->src[j] = args[j];
1981019815
}
1981119816

19812-
result.nodes[i] = tensor;
19817+
result->nodes[i] = tensor;
1981319818

1981419819
fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor));
1981519820
}
@@ -19818,7 +19823,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1981819823

1981919824
return result;
1982019825
}
19821-
#endif
1982219826

1982319827
void ggml_graph_print(const struct ggml_cgraph * cgraph) {
1982419828
int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};

0 commit comments

Comments
 (0)