Skip to content

Commit 0fc3efe

Browse files
stochastic gradient descent op
1 parent d9316cc commit 0fc3efe

File tree

4 files changed

+112
-6
lines changed

4 files changed

+112
-6
lines changed

examples/mnist/mnist-common.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
514514
opt_pars.print_backward_graph = false;
515515
opt_pars.n_threads = std::thread::hardware_concurrency();
516516
opt_pars.adam.n_iter = 1; // per call of ggml_opt_resume_g
517-
ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 397510);
517+
ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);
518518

519519
model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
520520

@@ -530,8 +530,10 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
530530
ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images));
531531
ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
532532

533-
enum ggml_opt_result opt_result = ggml_opt_resume_g(model.ctx_compute, &opt_ctx, model.loss, gf, gb, NULL, NULL);
534-
GGML_ASSERT(opt_result == GGML_OPT_RESULT_OK || opt_result == GGML_OPT_RESULT_DID_NOT_CONVERGE);
533+
const float onef = 1.0f;
534+
ggml_backend_graph_compute(model.backend, gf);
535+
ggml_backend_tensor_set(model.loss->grad, &onef, 0, sizeof(float));
536+
ggml_backend_graph_compute(model.backend, gb);
535537

536538
ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss));
537539
ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));

examples/mnist/mnist-common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct mnist_model {
5959
mnist_model() {
6060
// backend = ggml_backend_cuda_init(0);
6161
backend = ggml_backend_cpu_init();
62-
ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency());
62+
ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency()/2);
6363

6464
buf_weight = malloc(size_weight);
6565
{

include/ggml.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ extern "C" {
528528

529529
GGML_OP_CROSS_ENTROPY_LOSS,
530530
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
531+
GGML_OP_OPT_STEP_ADAM,
531532

532533
GGML_OP_COUNT,
533534
};
@@ -2033,6 +2034,11 @@ extern "C" {
20332034
struct ggml_tensor * b,
20342035
struct ggml_tensor * c);
20352036

2037+
GGML_API struct ggml_tensor * ggml_opt_step_adam(
2038+
struct ggml_context * ctx,
2039+
struct ggml_tensor * a,
2040+
float alpha);
2041+
20362042
//
20372043
// automatic differentiation
20382044
//

src/ggml.c

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,9 +2850,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
28502850

28512851
"CROSS_ENTROPY_LOSS",
28522852
"CROSS_ENTROPY_LOSS_BACK",
2853+
"OPT_STEP_ADAM",
28532854
};
28542855

2855-
static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
2856+
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
28562857

28572858
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
28582859
"none",
@@ -2942,9 +2943,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
29422943

29432944
"cross_entropy_loss(x,y)",
29442945
"cross_entropy_loss_back(x,y)",
2946+
"adam(x)",
29452947
};
29462948

2947-
static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
2949+
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
29482950

29492951
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
29502952

@@ -8104,6 +8106,26 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
81048106
return result;
81058107
}
81068108

8109+
// opt_step_adam
8110+
8111+
// Builds a graph node that applies an in-place optimizer update to parameter
// tensor `a` using its gradient `a->grad`, with learning rate `alpha`.
// NOTE(review): despite the op name, only `alpha` is stored in op_params and
// the CPU kernel applies plain SGD (w -= alpha*g); no Adam moments (m, v) or
// beta/eps hyperparameters exist yet — consistent with the commit title
// "stochastic gradient descent op". Consider renaming or extending the params.
struct ggml_tensor * ggml_opt_step_adam(
8112+
struct ggml_context * ctx,
8113+
struct ggml_tensor * a,
8114+
float alpha) {
8115+
// a must be a parameter with a gradient attached (ggml_set_param + backward pass).
GGML_ASSERT(a->grad);
8116+
8117+
// The result is a view of `a`: the op writes `a` in place and the view
// just sequences the update into the graph.
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
8118+
8119+
result->op = GGML_OP_OPT_STEP_ADAM;
8120+
// The update step itself is not differentiated.
result->grad = NULL;
8121+
result->src[0] = a;
8122+
result->src[1] = a->grad;
8123+
8124+
// Learning rate is the only hyperparameter carried by this op.
ggml_set_op_params(result, &alpha, sizeof(alpha));
8125+
8126+
return result;
8127+
}
8128+
8128+
81078129
////////////////////////////////////////////////////////////////////////////////
81088130

81098131
void ggml_set_param(
@@ -17093,6 +17115,62 @@ static void ggml_compute_forward_cross_entropy_loss_back(
1709317115
}
1709417116
}
1709517117

17118+
// F32 kernel for GGML_OP_OPT_STEP_ADAM: in-place parameter update
// weight -= alpha * grad, parallelized over rows across threads.
// NOTE(review): this is plain SGD, not Adam — no first/second moment state
// is read or written (see commit message "stochastic gradient descent op").
static void ggml_compute_forward_opt_step_adam_f32(
17119+
const struct ggml_compute_params * params,
17120+
struct ggml_tensor * dst) {
17121+
17122+
// src[0] = the parameter tensor (updated in place), src[1] = its gradient.
const struct ggml_tensor * src0 = dst->src[0];
17123+
const struct ggml_tensor * src0_grad = dst->src[1];
17124+
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
17125+
17126+
const int ith = params->ith;
17127+
const int nth = params->nth;
17128+
17129+
const int nr = ggml_nrows(src0);
17130+
17131+
GGML_TENSOR_UNARY_OP_LOCALS
17132+
// Innermost dimension must be contiguous f32.
GGML_ASSERT(nb00 == sizeof(float));
17133+
17134+
// rows per thread
17135+
const int dr = (nr + nth - 1)/nth;
17136+
17137+
// row range for this thread
17138+
const int ir0 = dr*ith;
17139+
const int ir1 = MIN(ir0 + dr, nr);
17140+
17141+
// Learning rate, stored by ggml_opt_step_adam via ggml_set_op_params.
const float alpha = ggml_get_op_params_f32(dst, 0);
17142+
17143+
for (int ir = ir0; ir < ir1; ++ir) {
17144+
// Decompose the flat row index into (i03, i02, i01) coordinates.
const int64_t i03 = ir/(ne02*ne01);
17145+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
17146+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
17147+
17148+
// NOTE(review): the same byte offset is used for both tensors — assumes
// the gradient has identical strides to the weights, not just the same
// shape. TODO confirm, or index src0_grad with its own nb0x strides.
const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
17149+
17150+
float * weight_ptr = (float *) ((char *) src0->data + offset);
17151+
const float * grad_ptr = (const float *) ((const char *) src0_grad->data + offset);
17152+
17153+
// w[i] += (-alpha) * g[i]  ==  SGD step over one row of ne00 elements.
ggml_vec_mad_f32(ne00, weight_ptr, grad_ptr, -alpha);
17154+
}
17155+
}
17156+
17157+
// Type-dispatch wrapper for GGML_OP_OPT_STEP_ADAM; only F32 parameters
// are supported, any other tensor type aborts.
static void ggml_compute_forward_opt_step_adam(
17158+
const struct ggml_compute_params * params,
17159+
struct ggml_tensor * dst) {
17160+
17161+
const struct ggml_tensor * src0 = dst->src[0];
17162+
17163+
switch (src0->type) {
17164+
case GGML_TYPE_F32:
17165+
{
17166+
ggml_compute_forward_opt_step_adam_f32(params, dst);
17167+
} break;
17168+
// Quantized / half-precision parameters are not implemented.
default:
17169+
{
17170+
GGML_ABORT("fatal error");
17171+
}
17172+
}
17173+
}
1709617174
/////////////////////////////////
1709717175

1709817176
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17434,6 +17512,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1743417512
ggml_compute_forward_cross_entropy_loss_back(params, tensor);
1743517513
}
1743617514
break;
17515+
case GGML_OP_OPT_STEP_ADAM:
17516+
{
17517+
ggml_compute_forward_opt_step_adam(params, tensor);
17518+
}
17519+
break;
1743717520
case GGML_OP_NONE:
1743817521
{
1743917522
// nop
@@ -18520,6 +18603,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1852018603
{
1852118604
GGML_ABORT("fatal error"); // not supported
1852218605
}
18606+
case GGML_OP_OPT_STEP_ADAM:
18607+
{
18608+
GGML_ABORT("fatal error"); // not supported
18609+
}
1852318610
case GGML_OP_NONE:
1852418611
{
1852518612
// nop
@@ -18652,6 +18739,16 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
1865218739
}
1865318740
}
1865418741

18742+
for (int i = 0; i < gf->n_nodes; i++) {
18743+
struct ggml_tensor * node = gf->nodes[i];
18744+
18745+
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
18746+
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
18747+
struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 0.001f);
18748+
ggml_build_forward_expand(gb, opt_step);
18749+
}
18750+
}
18751+
1865518752
ggml_hash_set_free(&zero_table);
1865618753
}
1865718754

@@ -19107,6 +19204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1910719204
} break;
1910819205
case GGML_OP_CROSS_ENTROPY_LOSS:
1910919206
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
19207+
case GGML_OP_OPT_STEP_ADAM:
1911019208
{
1911119209
n_tasks = n_threads;
1911219210
} break;

0 commit comments

Comments (0)