
Commit aa59aa3

finetune.cpp command-line arg
Add a learning-rate (AdamW alpha) command-line argument for ggml-opt, plus an optimizer enum defaulting to AdamW, including a string->id mapping, in preparation for SGD support. These live in common args as a set of optimizer options active only for the new FINETUNE example (the previous finetune.cpp PERPLEXITY options, which we're told were unused/accidental, are dropped). Perhaps breaking with precedent, the ggml_opt_optimizer_params struct is included directly in the args; if desired, we can instead add just the learning rate and optimizer type to a struct independent of ggml-opt.h, as proposed in #13835.
1 parent e0e3aa2 commit aa59aa3
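
For orientation, a minimal sketch of how the pieces below fit together; the main() scaffolding and the printf are illustrative, not code from this commit:

#include <stdio.h>

#include "common.h"
#include "ggml-opt.h"

int main(int argc, char ** argv) {
    common_params params;
    // parses e.g.:  llama-finetune ... -lr 1e-7 -opt adamw
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }
    // parsed values land directly in the embedded ggml_opt_optimizer_params
    struct ggml_opt_optimizer_params & opt = params.optimize;
    printf("optimizer=%s lr=%g\n", ggml_opt_optimizer_name(opt.optimizer), (double) opt.adamw.alpha);
    return 0;
}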

File tree

5 files changed: +103 -48 lines

common/arg.cpp

Lines changed: 59 additions & 42 deletions
@@ -1085,48 +1085,47 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
     printf("    esac\n");
     printf("}\n\n");
 
-    std::set<std::string> executables = {
-        "llama-batched",
-        "llama-batched-bench",
-        "llama-bench",
-        "llama-cli",
-        "llama-convert-llama2c-to-ggml",
-        "llama-cvector-generator",
-        "llama-embedding",
-        "llama-eval-callback",
-        "llama-export-lora",
-        "llama-gen-docs",
-        "llama-gguf",
-        "llama-gguf-hash",
-        "llama-gguf-split",
-        "llama-gritlm",
-        "llama-imatrix",
-        "llama-infill",
-        "llama-mtmd-cli",
-        "llama-llava-clip-quantize-cli",
-        "llama-lookahead",
-        "llama-lookup",
-        "llama-lookup-create",
-        "llama-lookup-merge",
-        "llama-lookup-stats",
-        "llama-parallel",
-        "llama-passkey",
-        "llama-perplexity",
-        "llama-q8dot",
-        "llama-quantize",
-        "llama-qwen2vl-cli",
-        "llama-retrieval",
-        "llama-run",
-        "llama-save-load-state",
-        "llama-server",
-        "llama-simple",
-        "llama-simple-chat",
-        "llama-speculative",
-        "llama-speculative-simple",
-        "llama-tokenize",
-        "llama-tts",
-        "llama-vdot"
-    };
+    std::set<std::string> executables = { "llama-batched",
+                                          "llama-batched-bench",
+                                          "llama-bench",
+                                          "llama-cli",
+                                          "llama-convert-llama2c-to-ggml",
+                                          "llama-cvector-generator",
+                                          "llama-embedding",
+                                          "llama-eval-callback",
+                                          "llama-export-lora",
+                                          "llama-finetune",
+                                          "llama-gen-docs",
+                                          "llama-gguf",
+                                          "llama-gguf-hash",
+                                          "llama-gguf-split",
+                                          "llama-gritlm",
+                                          "llama-imatrix",
+                                          "llama-infill",
+                                          "llama-mtmd-cli",
+                                          "llama-llava-clip-quantize-cli",
+                                          "llama-lookahead",
+                                          "llama-lookup",
+                                          "llama-lookup-create",
+                                          "llama-lookup-merge",
+                                          "llama-lookup-stats",
+                                          "llama-parallel",
+                                          "llama-passkey",
+                                          "llama-perplexity",
+                                          "llama-q8dot",
+                                          "llama-quantize",
+                                          "llama-qwen2vl-cli",
+                                          "llama-retrieval",
+                                          "llama-run",
+                                          "llama-save-load-state",
+                                          "llama-server",
+                                          "llama-simple",
+                                          "llama-simple-chat",
+                                          "llama-speculative",
+                                          "llama-speculative-simple",
+                                          "llama-tokenize",
+                                          "llama-tts",
+                                          "llama-vdot" };
 
     for (const auto& exe : executables) {
         printf("complete -F _llama_completions %s\n", exe.c_str());
@@ -1238,6 +1237,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     sampler_type_names.pop_back();
 
+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
+    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
 
     /**
      * filter options by example
@@ -2181,6 +2182,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ppl_output_type = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    add_opt(common_arg({ "-lr", "--learning-rate" }, "ALPHA",
+        string_format("adamw optimizer alpha (default: %.1f)", (double) params.optimize.adamw.alpha),
+        [](common_params & params, const std::string & value) {
+            params.optimize.adamw.alpha = std::stof(value);
+        })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or //TODO:sgd",
+        [](common_params & params, const std::string & name) {
+            params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
+            if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
+                throw std::invalid_argument("invalid --optimizer (try adamw)");
+            } else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD) {
+                throw std::invalid_argument("TODO: implement SGD");
+            }
+        })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
     add_opt(common_arg(
         {"-dt", "--defrag-thold"}, "N",
         string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),

common/common.h

Lines changed: 7 additions & 3 deletions
@@ -2,13 +2,14 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
-#include <sstream>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -80,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -349,6 +351,8 @@ struct common_params {
     bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct ggml_opt_optimizer_params optimize;
     // embedding
     bool embedding = false; // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)

examples/training/finetune.cpp

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@ int main(int argc, char ** argv) {
 
     params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }
 
@@ -60,8 +60,8 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
+    struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
+    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double) optimizer_params.adamw.alpha);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
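
Aside: the LOG_INF above prints the optimizer as a raw enum integer and formats the learning rate with %.1f, which renders small values such as 1e-8 as 0.0. A hedged alternative using the name helper added in this commit, not the commit's own code:

LOG_INF("-optimizer %s -lr %g\n",
        ggml_opt_optimizer_name(optimizer_params.optimizer),
        (double) optimizer_params.adamw.alpha);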

ggml/include/ggml-opt.h

Lines changed: 12 additions & 0 deletions
@@ -74,6 +74,17 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT   = 30,
     };
 
+    enum ggml_opt_optimizer {
+        GGML_OPT_OPTIMIZER_ADAMW,
+        GGML_OPT_OPTIMIZER_SGD,
+
+        GGML_OPT_OPTIMIZER_COUNT
+    };
+
+    // "adamw" or "sgd" (case insensitive)
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer);
+    GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(const char *);
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
         // AdamW optimizer parameters
@@ -84,6 +95,7 @@ extern "C" {
             float eps;   // epsilon for numerical stability
             float wd;    // weight decay for AdamW, use 0.0f to disable
         } adamw;
+        enum ggml_opt_optimizer optimizer;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
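
With the new field in place, custom parameters can be built from the existing default helper. A minimal sketch; make_finetune_params is a hypothetical function and the values are illustrative:

#include "ggml-opt.h"

static struct ggml_opt_optimizer_params make_finetune_params(void) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
    p.optimizer   = GGML_OPT_OPTIMIZER_ADAMW; // new field; also the default
    p.adamw.alpha = 1e-7f;                    // learning rate
    return p;
}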

ggml/src/ggml-opt.cpp

Lines changed: 22 additions & 0 deletions
@@ -228,10 +228,32 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.beta2 = 0.999f;
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;
+    result.optimizer   = GGML_OPT_OPTIMIZER_ADAMW;
 
     return result;
 }
 
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
+
+GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(const char * n) {
+    if (!strcasecmp("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_ADAMW;
+    } else if (!strcasecmp("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_SGD;
+    } else {
+        return GGML_OPT_OPTIMIZER_COUNT;
+    }
+}
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
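
Note that strcasecmp is POSIX (declared in <strings.h>; MSVC spells it _stricmp), so portability may need attention. A quick round-trip check of the mapping, as an illustrative test rather than part of the commit:

#include <assert.h>
#include <string.h>

#include "ggml-opt.h"

int main(void) {
    assert(named_ggml_opt_optimizer("AdamW") == GGML_OPT_OPTIMIZER_ADAMW);   // case-insensitive
    assert(named_ggml_opt_optimizer("sgd")   == GGML_OPT_OPTIMIZER_SGD);
    assert(named_ggml_opt_optimizer("rmsprop") == GGML_OPT_OPTIMIZER_COUNT); // unknown -> sentinel
    assert(strcmp(ggml_opt_optimizer_name(GGML_OPT_OPTIMIZER_SGD), "sgd") == 0);
    return 0;
}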
