
Commit 48a16bf

examples/finetune -opt SGD (stochastic gradient descent) memory opt
Support the finetune arg -opt SGD (or sgd).

Llama 3.2-1B-F32 result: ~11 GB of GPU RAM observed with SGD versus ~20 GB with AdamW; training easily/quickly reaches 99%+ train accuracy on a tiny Wikipedia training set (~56% token accuracy on held-out eval, which is reasonable).

Note: the objective loss is not directly comparable between AdamW and SGD; check perplexity or accuracy, or compare relative improvements, when judging convergence.

New finetune args: -wd 1e-5 enables weight decay for SGD or AdamW, and -epochs N sets the maximum number of epochs (default 2, as before).

Cache (1 - wd*alpha) in the 'adamw' opt struct, and compute the per-epoch optimizer opts once (formerly they were computed twice per epoch).

Add a unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; it avoids allocating the m, v tensors. Make ggml_opt_init aware of the optimization method. Since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could probably switch between SGD and AdamW on each epoch, but it would need to use AdamW for the first epoch (unconfirmed - no arg to set such a policy yet).
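For reference, the per-element updates this commit implements in the CPU and CUDA kernels can be written as the standalone sketch below. The helper names (sgd_step, adamw_step) are illustrative, not part of the commit; only the arithmetic and the cached keep = 1 - alpha*wd factor come from the diff.

// Illustrative scalar form of the two updates touched by this commit (not ggml API code).
// Both apply decoupled weight decay by pre-scaling the weight with keep = 1 - alpha*wd;
// the commit caches this factor in the optimizer params instead of recomputing it per element.

// Hypothetical helper: SGD step for one weight - only the raw gradient g is needed, no m/v state.
static float sgd_step(float w, float g, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;
    return w * keep - alpha * g;
}

// Hypothetical helper: AdamW step for one weight - mh and vh are the bias-corrected moment
// terms as computed in ggml's kernels (vh already includes the +eps term).
static float adamw_step(float w, float mh, float vh, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;
    return w * keep - alpha * mh / vh;
}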
1 parent aa59aa3 commit 48a16bf

16 files changed, +314 −69 lines changed


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 16 additions & 9 deletions
@@ -1237,8 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     sampler_type_names.pop_back();
 
-    params.optimize             = ggml_opt_get_default_optimizer_params(NULL);
-    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
 
     /**
      * filter options by example
@@ -2182,19 +2181,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ppl_output_type = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
-    add_opt(common_arg({ "-lr", "--learning-rate" }, "ALPHA",
-        string_format("adamw optimizer alpha (default: %.1f)", (double) params.optimize.adamw.alpha),
-        [](common_params & params, const std::string & value) {
-            params.optimize.adamw.alpha = std::stof(value);
-        })
+    add_opt(
+        common_arg(
+            { "-lr", "--learning-rate" }, "ALPHA",
+            string_format("adamw or sgd optimizer alpha (default: %.2g)", (double) params.optimize.adamw.alpha),
+            [](common_params & params, const std::string & value) { params.optimize.adamw.alpha = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format("adamw or sgd optimizer weight decay (0 is off) (default: %.2g)",
+                      (double) params.optimize.adamw.wd),
+        [](common_params & params, const std::string & value) { params.optimize.adamw.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+        string_format("optimizer max # of epochs (default: %d)", params.optimize.epochs),
+        [](common_params & params, int epochs) { params.optimize.epochs = epochs; })
         .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
     add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or //TODO:sgd",
         [](common_params & params, const std::string & name) {
             params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
             if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
                 throw std::invalid_argument("invalid --optimizer (try adamw)");
-            } else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD) {
-                throw std::invalid_argument("TODO: implement SGD");
             }
         })
         .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
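The -opt handler above relies on named_ggml_opt_optimizer to map the user-supplied string onto the ggml_opt_optimizer enum; that helper's body is not among the hunks shown here. A minimal sketch of such a name-to-enum mapping, assuming only the enum values visible elsewhere in this commit, could look like:

#include <strings.h>   // strcasecmp
#include "ggml-opt.h"  // enum ggml_opt_optimizer

// Hypothetical sketch of a case-insensitive name-to-enum helper like the
// named_ggml_opt_optimizer referenced above; the real implementation lives in
// a file not shown in this excerpt and may differ.
static enum ggml_opt_optimizer sketch_named_optimizer(const char * name) {
    if (strcasecmp(name, "adamw") == 0) {
        return GGML_OPT_OPTIMIZER_ADAMW;
    }
    if (strcasecmp(name, "sgd") == 0) {
        return GGML_OPT_OPTIMIZER_SGD;
    }
    return GGML_OPT_OPTIMIZER_COUNT; // caller treats COUNT as "invalid --optimizer"
}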

examples/training/finetune.cpp

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,6 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
     common_init_result llama_init = common_init_from_params(params);
     llama_model_ptr & model = llama_init.model;
@@ -61,7 +60,8 @@
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
     struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
-    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double) optimizer_params.adamw.alpha);
+    LOG_INF("-optimizer %s -lr: %.2g -epochs %d\n", ggml_opt_optimizer_name(optimizer_params.optimizer),
+            (double) optimizer_params.adamw.alpha, optimizer_params.epochs);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
@@ -77,7 +77,7 @@
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval  = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (unsigned epoch = 0; epoch < optimizer_params.epochs; ++epoch) {
         llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
             ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
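The hunks above only show the parsed params.optimize being logged and driving the epoch loop; how it reaches ggml's optimizer is not visible here. Assuming it is routed through the get_opt_pars callback/userdata pair declared in ggml-opt.h (shown in the next file), the glue could look roughly like this sketch:

#include "ggml-opt.h"

// Hypothetical glue, assuming the CLI-parsed ggml_opt_optimizer_params struct is
// passed as the callback userdata; the actual wiring in this commit lives in code
// not shown in this excerpt.
static struct ggml_opt_optimizer_params finetune_opt_pars(void * userdata) {
    // userdata would point at params.optimize, already filled from -lr/-wd/-opt/-epochs
    return *(struct ggml_opt_optimizer_params *) userdata;
}

The callback would then be registered together with &params.optimize as its userdata wherever the optimization context is set up.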

ggml/include/ggml-opt.h

Lines changed: 10 additions & 3 deletions
@@ -90,12 +90,17 @@ extern "C" {
         // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // adamw
+            float beta2; // adamw
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay for SGD or AdamW, use 0.0f to disable
         } adamw;
+
+        // only GGML_OPT_OPTIMIZER_ADAMW allocates m, v per parameter
         enum ggml_opt_optimizer optimizer;
+
+        // affects finetune.cpp only so far:
+        unsigned epochs; // max # of epochs sampling over training data
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -126,6 +131,8 @@ extern "C" {
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
         void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+        struct ggml_opt_optimizer_params
+            opt_params; // holds result of get_opt_pars(get_opt_pars_ud) after ggml_opt_init (could call get_opt_pars repeatedly instead)
     };
 
     // get parameters for an optimization context with defaults set where possible
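A short sketch of filling the extended struct above from user code follows; the numeric values are placeholders, and only the field names, the enum values, and ggml_opt_get_default_optimizer_params come from this commit.

#include "ggml-opt.h"

// Illustrative only: field names come from the header above, the numbers are
// placeholders chosen for the example.
static struct ggml_opt_optimizer_params make_sgd_params(void) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(/*userdata=*/NULL);
    p.optimizer   = GGML_OPT_OPTIMIZER_SGD; // skip allocating m, v per parameter
    p.adamw.alpha = 1e-4f;                  // learning rate; the adamw sub-struct is shared with SGD
    p.adamw.wd    = 1e-5f;                  // decoupled weight decay, 0.0f disables it
    p.epochs      = 2;                      // max epochs, consumed by finetune.cpp
    return p;
}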

ggml/include/ggml.h

Lines changed: 8 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM,       // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,    // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,11 @@ extern "C" {
            struct ggml_tensor  * v,
            struct ggml_tensor  * adamw_params); // parameters such as the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad,
+            struct ggml_tensor  * adamw_params); // parameters: alpha, the learning rate, and wd, weight decay
+
     //
     // automatic differentiation
     //
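Placing the new op in a graph mirrors the existing ggml_opt_step_adamw pattern minus the m/v tensors. The following is a minimal sketch, assuming an 8-element F32 parameter tensor whose layout (alpha at index 0, keep = 1 - alpha*wd at index 7) matches the CPU and CUDA kernels later in this commit; the build_sgd_step helper itself is not part of the commit.

#include "ggml.h"

// Minimal sketch: build an SGD step node for one parameter tensor. The caller
// is assumed to fill `pars` (8 floats, same buffer layout the AdamW step uses)
// before each optimizer step.
static struct ggml_tensor * build_sgd_step(struct ggml_context * ctx,
                                           struct ggml_tensor  * w,
                                           struct ggml_tensor  * grad) {
    struct ggml_tensor * pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    return ggml_opt_step_sgd(ctx, w, grad, pars);
}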

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2057,6 +2057,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2341,6 +2346,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 4 deletions
@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8849,14 +8849,14 @@
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
     const float eps    = adamw_params_ptr[3];
-    const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float   alpha            = adamw_params_ptr[0];
+    const float   keep             = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
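The row partitioning in the new kernel is the same ceil-divide scheme the other ggml CPU ops use: with nr = 10 rows and nth = 4 threads, dr = 3 and the threads own the ranges [0,3), [3,6), [6,9), [9,10). A tiny standalone illustration of that arithmetic (not part of the commit):

#include <algorithm>
#include <cstdio>

// Standalone illustration of the rows-per-thread split used above, with
// concrete numbers instead of the kernel's runtime values.
int main() {
    const int nr  = 10; // rows
    const int nth = 4;  // threads
    const int dr  = (nr + nth - 1) / nth; // 3 rows per thread, rounded up
    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr * ith;
        const int ir1 = std::min(ir0 + dr, nr);
        std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}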

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3256,6 +3260,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;

ggml/src/ggml-cuda/opt-step-adamw.cu

Lines changed: 3 additions & 4 deletions
@@ -17,9 +17,9 @@ static __global__ void opt_step_adamw_f32(
     const float beta1  = pars[1];
     const float beta2  = pars[2];
     const float eps    = pars[3];
-    const float wd     = pars[4];
     const float beta1h = pars[5];
     const float beta2h = pars[6];
+    const float keep   = pars[7];
 
     const float gi  = g[i];
     const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
@@ -31,7 +31,7 @@
     const float mh = gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;
 
-    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+    x[i] = x[i] * keep - alpha * mh / vh;
 }
 
 static void opt_step_adamw_f32_cuda(
@@ -62,14 +62,13 @@ void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     float * src0_d = (float *) src0->data;
     const float * src0_grad_d = (const float *) src0_grad->data;
     float * src0_grad_m_d = (float *) src0_grad_m->data;
     float * src0_grad_v_d = (float *) src0_grad_v->data;
     const float * adamw_params_d = (const float *) adamw_params->data;
-
     cudaStream_t stream = ctx.stream();
 
     const int64_t ne = ggml_nelements(src0);

ggml/src/ggml-cuda/opt-step-sgd.cu

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k)
+        return;
+    x[i] = x[i] * pars[7] - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(adamw_params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    float       * src0_d         = (float       *) src0->data;
+    const float * src0_grad_d    = (const float *) src0_grad->data;
+    const float * adamw_params_d = (const float *) adamw_params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, adamw_params_d, ne, stream);
+}

ggml/src/ggml-cuda/opt-step-sgd.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
