Skip to content

Commit

Permalink
llama.cpp : split llama_context_params into model and context params (g…
Browse files Browse the repository at this point in the history
…gerganov#3301)

* llama.cpp : split llama_context_params into model and context params

ggml-ci

* fix metal build

* fix freq_base/scale default to model value

* llama-bench : keep the same model between tests when possible

* move n_threads to llama_context_params, add n_threads_batch

* fix mpi build

* remove kv_size(), cuda scratch fixes

* remove low-vram option

* add n_threads_batch to system info, refactor to get_system_info()

* add documentation about --threads-batch to the READMEs

* llama-bench fix

* main : fix rope freq/scale warning

* llama.cpp : add llama_get_model
common : add llama_tokenize from model

* remove duplicated ctx/model functions

ggml-ci

* cuda : print total VRAM used
  • Loading branch information
slaren authored Sep 28, 2023
1 parent 0512d66 commit 16bc66d
Show file tree
Hide file tree
Showing 27 changed files with 713 additions and 633 deletions.
112 changes: 72 additions & 40 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency();
}
} else if (arg == "-tb" || arg == "--threads-batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_threads_batch = std::stoi(argv[i]);
if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency();
}
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -451,12 +460,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.mul_mat_q = false;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--low-vram" || arg == "-lv") {
#ifdef GGML_USE_CUBLAS
params.low_vram = true;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
Expand Down Expand Up @@ -630,7 +633,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" (can be specified more than once for multiple prompts).\n");
printf(" --color colorise output to distinguish prompt and user input from generations\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -p PROMPT, --prompt PROMPT\n");
printf(" prompt to start generation with (default: empty)\n");
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
Expand All @@ -645,7 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -f FNAME, --file FNAME\n");
printf(" prompt file to start generation.\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
Expand Down Expand Up @@ -705,7 +710,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
Expand All @@ -726,6 +730,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf("\n");
}

std::string get_system_info(const gpt_params & params) {
std::ostringstream os;

os << "system_info: n_threads = " << params.n_threads;
if (params.n_threads_batch != -1) {
os << " (n_threads_batch = " << params.n_threads_batch << ")";
}
os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();

return os.str();
}

std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
Expand All @@ -749,40 +765,50 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
// Model utils
//

struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
auto mparams = llama_model_default_params();

lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
if (params.n_gpu_layers != -1) {
lparams.n_gpu_layers = params.n_gpu_layers;
mparams.n_gpu_layers = params.n_gpu_layers;
}
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
lparams.low_vram = params.low_vram;
lparams.mul_mat_q = params.mul_mat_q;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.logits_all;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;

return lparams;
mparams.main_gpu = params.main_gpu;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;

return mparams;
}

struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;

return cparams;
}

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto lparams = llama_context_params_from_gpt_params(params);
auto mparams = llama_model_params_from_gpt_params(params);

llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
}

llama_context * lctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_params_from_gpt_params(params);

llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
Expand Down Expand Up @@ -815,7 +841,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
LOG("warming up the model with an empty run\n");

std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
}
Expand All @@ -828,16 +854,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
//

std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const struct llama_context * ctx,
const std::string & text,
bool add_bos) {
return llama_tokenize(llama_get_model(ctx), text, add_bos);
}

std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
Expand All @@ -847,10 +880,10 @@ std::vector<llama_token> llama_tokenize(

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
Expand Down Expand Up @@ -905,7 +938,7 @@ llama_token llama_sample_token(
std::vector<llama_token_data> & candidates,
int idx) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));

const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
Expand Down Expand Up @@ -1191,7 +1224,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
#endif // NDEBUG

fprintf(stream, "model_desc: %s\n", model_desc);
fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx));
fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));

#ifdef __OPTIMIZE__
fprintf(stream, "optimize: true\n");
Expand Down Expand Up @@ -1258,7 +1291,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);
Expand Down
12 changes: 10 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
Expand Down Expand Up @@ -95,7 +96,6 @@ struct gpt_params {
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score

bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
Expand Down Expand Up @@ -126,6 +126,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

void gpt_print_usage(int argc, char ** argv, const gpt_params & params);

std::string get_system_info(const gpt_params & params);

std::string gpt_random_prompt(std::mt19937 & rng);

void process_escapes(std::string& input);
Expand All @@ -135,6 +137,7 @@ void process_escapes(std::string& input);
//

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);

//
Expand All @@ -144,7 +147,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
// tokenizes a string into a vector of tokens
// should work similar to Python's `tokenizer.encode`
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const struct llama_context * ctx,
const std::string & text,
bool add_bos);

std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos);

Expand Down
10 changes: 5 additions & 5 deletions common/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ size_t tokenize_file(
out_tokens.resize(buf.size() + n_max_tokens_overhead);

int n_tokens = llama_tokenize(
lctx,
llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
Expand All @@ -867,7 +867,7 @@ size_t tokenize_file(
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(
lctx,
llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
Expand Down Expand Up @@ -920,7 +920,7 @@ size_t tokenize_file(
size_t found_max_sample_size = 0;

size_t max_token_text_size = 0;
int n_vocab = llama_n_vocab(lctx);
int n_vocab = llama_n_vocab(llama_get_model(lctx));
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
Expand Down Expand Up @@ -961,15 +961,15 @@ size_t tokenize_file(

// tokenize the sample
tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(lctx,
int n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
(int) tok_sample.size(),
false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
n_tokens = llama_tokenize(lctx,
n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
Expand Down
39 changes: 24 additions & 15 deletions examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,43 @@ int main(int argc, char ** argv) {

llama_backend_init(params.numa);

llama_context_params ctx_params = llama_context_default_params();
// initialize the model

ctx_params.seed = 1234;
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
ctx_params.n_batch = std::max(n_len, n_parallel);
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model_params model_params = llama_model_default_params();

llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
// model_params.n_gpu_layers = 99; // offload all layers to the GPU

llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}

// tokenize the prompt

std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(model, params.prompt, true);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;

// initialize the context

llama_context_params ctx_params = llama_context_default_params();

ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}

// tokenize the prompt

std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);

const int n_ctx = llama_n_ctx(ctx);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;

LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);

Expand Down Expand Up @@ -106,7 +115,7 @@ int main(int argc, char ** argv) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch, params.n_threads) != 0) {
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
Expand Down Expand Up @@ -146,7 +155,7 @@ int main(int argc, char ** argv) {
continue;
}

auto n_vocab = llama_n_vocab(ctx);
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);

std::vector<llama_token_data> candidates;
Expand Down Expand Up @@ -210,7 +219,7 @@ int main(int argc, char ** argv) {
n_cur += 1;

// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
Expand Down
Loading

0 comments on commit 16bc66d

Please sign in to comment.