
Commit

examples : add llama_init_from_gpt_params() common function (ggml-org#1290)

Signed-off-by: deadprogram <ron@hybridgroup.com>
deadprogram authored May 2, 2023
1 parent 0e6cbff commit 67c7779
Showing 5 changed files with 51 additions and 76 deletions.
examples/common.cpp (31 additions, 0 deletions)

@@ -405,6 +405,37 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx      = params.n_ctx;
+    lparams.n_parts    = params.n_parts;
+    lparams.seed       = params.seed;
+    lparams.f16_kv     = params.memory_f16;
+    lparams.use_mmap   = params.use_mmap;
+    lparams.use_mlock  = params.use_mlock;
+
+    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+
+    if (lctx == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return NULL;
+    }
+
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(lctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return NULL;
+        }
+    }
+
+    return lctx;
+}
+
 /* Keep track of current color of output, and emit ANSI code if it changes. */
 void set_console_color(console_state & con_st, console_color_t color) {
     if (con_st.use_color && con_st.color != color) {
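The new helper folds the previously duplicated steps (filling llama_context_params from gpt_params, calling llama_init_from_file(), and optionally applying a LoRA adapter) into a single call that returns NULL on failure. A minimal sketch of a caller built on top of it follows; it is illustrative rather than part of this commit, and the only calls assumed beyond the diff are gpt_params_parse() from the same common.h and llama_free() from llama.h.

    // hypothetical caller, not part of this commit
    #include "common.h"
    #include "llama.h"

    int main(int argc, char ** argv) {
        gpt_params params;
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }

        // one call now covers the default context params, model loading,
        // and the optional LoRA adapter
        llama_context * ctx = llama_init_from_gpt_params(params);
        if (ctx == NULL) {
            return 1;
        }

        // ... tokenize and evaluate with ctx here ...

        llama_free(ctx);
        return 0;
    }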
examples/common.h (6 additions, 0 deletions)

@@ -77,6 +77,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);

 std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
 
+//
+// Model utils
+//
+
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+
 //
 // Console utils
 //
examples/embedding/embedding.cpp (4 additions, 18 deletions)

@@ -35,24 +35,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
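For readers who prefer the resolved file to the diff, the model-loading section of examples/embedding/embedding.cpp reads as follows after this commit; it is reassembled from the added and context lines of the hunk above, with nothing new introduced.

    llama_context * ctx;

    // load the model
    ctx = llama_init_from_gpt_params(params);
    if (ctx == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }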
examples/main/main.cpp (5 additions, 28 deletions)

@@ -101,34 +101,11 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
     g_ctx = &ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
examples/perplexity/perplexity.cpp (5 additions, 30 deletions)

@@ -122,36 +122,11 @@ int main(int argc, char ** argv) {

     llama_context * ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
