Skip to content

Commit

Permalink
llama : per-layer KV cache + quantum K cache (#4309)
Browse files Browse the repository at this point in the history
* per-layer KV

* remove unnecessary copies

* less code duplication, offload k and v separately

* llama : offload KV cache per-layer

* llama : offload K shift tensors

* llama : offload for rest of the model arches

* llama : enable offload debug temporarily

* llama : keep the KV related layers on the device

* llama : remove mirrors, perform Device -> Host when partial offload

* common : add command-line arg to disable KV cache offloading

* llama : update session save/load

* llama : support quantum K cache (#4312)

* llama : support quantum K cache (wip)

* metal : add F32 -> Q8_0 copy kernel

* cuda : add F32 -> Q8_0 copy kernel

ggml-ci

* cuda : use mmv kernel for quantum cache ops

* llama : pass KV cache type through API

* llama : fix build

ggml-ci

* metal : add F32 -> Q4_0 copy kernel

* metal : add F32 -> Q4_1 copy kernel

* cuda : wip

* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels

* llama-bench : support type_k/type_v

* metal : use mm kernel only for quantum KV cache

* cuda : add comment

* llama : remove memory_f16 and kv_f16 flags

---------

Co-authored-by: slaren <slarengh@gmail.com>

* readme : add API change notice

---------

Co-authored-by: slaren <slarengh@gmail.com>
  • Loading branch information
ggerganov and slaren committed Dec 7, 2023
1 parent 81bc921 commit bcc0eb4
Show file tree
Hide file tree
Showing 11 changed files with 747 additions and 287 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

### Hot topics

- **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309**
- Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225
- Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
- Collecting Apple Silicon performance stats: https://github.com/ggerganov/llama.cpp/discussions/4167
Expand Down
45 changes: 39 additions & 6 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--samplers") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -510,6 +508,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
params.dump_kv_cache = true;
} else if (arg == "-nkvo" || arg == "--no-kv-offload") {
params.no_kv_offload = true;
} else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
} else if (arg == "-ctv" || arg == "--cache-type-v") {
params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
Expand Down Expand Up @@ -858,8 +862,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
Expand Down Expand Up @@ -900,6 +902,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");
printf(" -nkvo, --no-kv-offload\n");
printf(" disable KV offload\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
printf(" -ctv TYPE, --cache-type-v TYPE\n");
printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
Expand Down Expand Up @@ -1015,6 +1023,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
return mparams;
}

static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
if (s == "q4_0") {
return GGML_TYPE_Q4_0;
}
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}

throw std::runtime_error("Invalid cache type: " + s);
}

struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();

Expand All @@ -1024,7 +1055,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
Expand All @@ -1035,6 +1065,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.offload_kqv = !params.no_kv_offload;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

return cparams;
}
Expand Down Expand Up @@ -1447,7 +1481,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
Expand Down
7 changes: 5 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ struct gpt_params {
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score

bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
Expand All @@ -125,10 +124,14 @@ struct gpt_params {
bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
std::string image = ""; // path to an image file
std::string image = ""; // path to an image file
};

bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
Expand Down
111 changes: 91 additions & 20 deletions examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ static std::vector<T> split(const std::string & str, char delim) {
return values;
}

template<typename T, typename F>
static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
std::vector<std::string> str_values;
std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
return str_values;
}

template<typename T>
static T avg(const std::vector<T> & v) {
if (v.empty()) {
Expand Down Expand Up @@ -126,7 +133,8 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<int> n_batch;
std::vector<bool> f32_kv;
std::vector<ggml_type> type_k;
std::vector<ggml_type> type_v;
std::vector<int> n_threads;
std::vector<int> n_gpu_layers;
std::vector<int> main_gpu;
Expand All @@ -142,7 +150,8 @@ static const cmd_params cmd_params_defaults = {
/* n_prompt */ {512},
/* n_gen */ {128},
/* n_batch */ {512},
/* f32_kv */ {false},
/* type_k */ {GGML_TYPE_F16},
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
/* n_gpu_layers */ {99},
/* main_gpu */ {0},
Expand All @@ -162,7 +171,8 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
Expand All @@ -173,9 +183,32 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
}

static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
if (s == "q4_0") {
return GGML_TYPE_Q4_0;
}
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}

return GGML_TYPE_COUNT;
}


static cmd_params parse_cmd_params(int argc, char ** argv) {
cmd_params params;
std::string arg;
Expand Down Expand Up @@ -224,13 +257,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = split<int>(argv[i], split_delim);
params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
} else if (arg == "--memory-f32") {
} else if (arg == "-ctk" || arg == "--cache-type-k") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<int>(argv[i], split_delim);
params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end());
auto p = split<std::string>(argv[i], split_delim);
std::vector<ggml_type> types;
for (const auto & t : p) {
ggml_type gt = ggml_type_from_name(t);
if (gt == GGML_TYPE_COUNT) {
invalid_param = true;
break;
}
types.push_back(gt);
}
params.type_k.insert(params.type_k.end(), types.begin(), types.end());
} else if (arg == "-ctv" || arg == "--cache-type-v") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<std::string>(argv[i], split_delim);
std::vector<ggml_type> types;
for (const auto & t : p) {
ggml_type gt = ggml_type_from_name(t);
if (gt == GGML_TYPE_COUNT) {
invalid_param = true;
break;
}
types.push_back(gt);
}
params.type_v.insert(params.type_v.end(), types.begin(), types.end());
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -321,7 +379,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
if (params.f32_kv.empty()) { params.f32_kv = cmd_params_defaults.f32_kv; }
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
Expand All @@ -336,7 +395,8 @@ struct cmd_params_instance {
int n_prompt;
int n_gen;
int n_batch;
bool f32_kv;
ggml_type type_k;
ggml_type type_v;
int n_threads;
int n_gpu_layers;
int main_gpu;
Expand Down Expand Up @@ -365,7 +425,8 @@ struct cmd_params_instance {

cparams.n_ctx = n_prompt + n_gen;
cparams.n_batch = n_batch;
cparams.f16_kv = !f32_kv;
cparams.type_k = type_k;
cparams.type_v = type_v;
cparams.mul_mat_q = mul_mat_q;

return cparams;
Expand All @@ -380,15 +441,17 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
for (const auto & mg : params.main_gpu)
for (const auto & ts : params.tensor_split)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
cmd_params_instance instance = {
/* .model = */ m,
/* .n_prompt = */ n_prompt,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
Expand All @@ -410,7 +473,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & mg : params.main_gpu)
for (const auto & ts : params.tensor_split)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
Expand All @@ -422,7 +486,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
Expand All @@ -441,7 +506,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
Expand Down Expand Up @@ -489,7 +555,8 @@ struct test {
uint64_t model_n_params;
int n_batch;
int n_threads;
bool f32_kv;
ggml_type type_k;
ggml_type type_v;
int n_gpu_layers;
int main_gpu;
bool mul_mat_q;
Expand All @@ -508,7 +575,8 @@ struct test {
model_n_params = llama_model_n_params(lmodel);
n_batch = inst.n_batch;
n_threads = inst.n_threads;
f32_kv = inst.f32_kv;
type_k = inst.type_k;
type_v = inst.type_v;
n_gpu_layers = inst.n_gpu_layers;
main_gpu = inst.main_gpu;
mul_mat_q = inst.mul_mat_q;
Expand Down Expand Up @@ -571,7 +639,7 @@ struct test {
"cuda", "opencl", "metal", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "f16_kv",
"n_batch", "n_threads", "type_k", "type_v",
"n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
Expand Down Expand Up @@ -621,7 +689,7 @@ struct test {
std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
Expand Down Expand Up @@ -805,8 +873,11 @@ struct markdown_printer : public printer {
if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
fields.push_back("n_batch");
}
if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) {
fields.push_back("f16_kv");
if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
fields.push_back("type_k");
}
if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
fields.push_back("type_v");
}
if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
fields.push_back("main_gpu");
Expand Down
1 change: 0 additions & 1 deletion examples/quantize-stats/quantize-stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params();
cparams.n_ctx = 256;
cparams.seed = 1;
cparams.f16_kv = false;

ctx = llama_new_context_with_model(model, cparams);

Expand Down
Loading

0 comments on commit bcc0eb4

Please sign in to comment.